mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 12:37:27 +00:00
feat: Remove template variables from PromptNode invocation kwargs (#5526)
* Remove template params from kwargs before passing kwargs to invocation layer * More unit tests * Add release note * Enable simple prompt node pipeline integration test use case
This commit is contained in:
parent
84ed954c8c
commit
227bf6ca39
@ -159,7 +159,7 @@ class PromptNode(BaseComponent):
|
||||
if template_to_fill:
|
||||
# prompt template used, yield prompts from inputs args
|
||||
for prompt in template_to_fill.fill(*args, **kwargs):
|
||||
kwargs_copy = copy.copy(kwargs)
|
||||
kwargs_copy = template_to_fill.remove_template_params(copy.copy(kwargs))
|
||||
# and pass the prepared prompt and kwargs copy to the model
|
||||
prompt = self.prompt_model._ensure_token_limit(prompt)
|
||||
prompt_collector.append(prompt)
|
||||
|
||||
@ -578,5 +578,19 @@ class PromptTemplate(BasePromptTemplate, ABC):
|
||||
)
|
||||
yield prompt_prepared
|
||||
|
||||
def remove_template_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Removes template parameters from kwargs.
|
||||
|
||||
:param kwargs: Keyword arguments to remove template parameters from.
|
||||
:return: A modified dictionary with the template parameters removed.
|
||||
"""
|
||||
if kwargs:
|
||||
for param in self.prompt_params:
|
||||
kwargs.pop(param, None)
|
||||
return kwargs
|
||||
else:
|
||||
return {}
|
||||
|
||||
def __repr__(self):
|
||||
return f"PromptTemplate(name={self.name}, prompt_text={self.prompt_text}, prompt_params={self.prompt_params})"
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Remove template variables from invocation layer kwargs
|
||||
@ -233,19 +233,30 @@ def test_azure_vs_open_ai_invocation_layer_selection():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
|
||||
@pytest.mark.parametrize("prompt_model", ["hf"], indirect=True)
|
||||
def test_simple_pipeline(prompt_model):
|
||||
# TODO: This can be another unit test?
|
||||
skip_test_for_invalid_key(prompt_model)
|
||||
|
||||
node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")
|
||||
"""
|
||||
Tests that a pipeline with a prompt node and prompt template has the right output structure
|
||||
"""
|
||||
output_variable_name = "out"
|
||||
node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable=output_variable_name)
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
|
||||
result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
|
||||
assert "positive" in result["out"][0].casefold()
|
||||
|
||||
# validate output variable present
|
||||
assert output_variable_name in result
|
||||
assert len(result[output_variable_name]) == 1
|
||||
|
||||
# validate pipeline parameters are present
|
||||
assert "query" in result
|
||||
assert "documents" in result
|
||||
|
||||
# and that so-called invocation context contains the right keys
|
||||
assert "invocation_context" in result
|
||||
assert all(item in result["invocation_context"] for item in ["query", "documents", output_variable_name, "prompts"])
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
|
||||
@ -468,3 +468,26 @@ class TestPromptTemplateSyntax:
|
||||
prompt_template = PromptTemplate(prompt_text)
|
||||
prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
|
||||
assert prompts == expected_prompts
|
||||
|
||||
def test_prompt_template_remove_template_params(self):
|
||||
kwargs = {"query": "query", "documents": "documents", "other": "other"}
|
||||
expected_kwargs = {"other": "other"}
|
||||
prompt_text = "Here is prompt text with two variables that are also in kwargs: {query} and {documents}"
|
||||
prompt_template = PromptTemplate(prompt_text)
|
||||
assert prompt_template.remove_template_params(kwargs) == expected_kwargs
|
||||
|
||||
def test_prompt_template_remove_template_params_edge_cases(self):
|
||||
"""
|
||||
Test that the function works with a variety of edge cases
|
||||
"""
|
||||
kwargs = {"query": "query", "documents": "documents"}
|
||||
prompt_text = "Here is prompt text with two variables that are also in kwargs: {query} and {documents}"
|
||||
prompt_template = PromptTemplate(prompt_text)
|
||||
assert prompt_template.remove_template_params(kwargs) == {}
|
||||
|
||||
assert prompt_template.remove_template_params({}) == {}
|
||||
|
||||
assert prompt_template.remove_template_params(None) == {}
|
||||
|
||||
totally_unrelated = {"totally_unrelated": "totally_unrelated"}
|
||||
assert prompt_template.remove_template_params(totally_unrelated) == totally_unrelated
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user