diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py index a5e8be2f9..060bec69c 100644 --- a/haystack/nodes/prompt/prompt_node.py +++ b/haystack/nodes/prompt/prompt_node.py @@ -159,7 +159,7 @@ class PromptNode(BaseComponent): if template_to_fill: # prompt template used, yield prompts from inputs args for prompt in template_to_fill.fill(*args, **kwargs): - kwargs_copy = copy.copy(kwargs) + kwargs_copy = template_to_fill.remove_template_params(copy.copy(kwargs)) # and pass the prepared prompt and kwargs copy to the model prompt = self.prompt_model._ensure_token_limit(prompt) prompt_collector.append(prompt) diff --git a/haystack/nodes/prompt/prompt_template.py b/haystack/nodes/prompt/prompt_template.py index e757a0aea..7717885be 100644 --- a/haystack/nodes/prompt/prompt_template.py +++ b/haystack/nodes/prompt/prompt_template.py @@ -578,5 +578,19 @@ class PromptTemplate(BasePromptTemplate, ABC): ) yield prompt_prepared + def remove_template_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """ + Removes template parameters from kwargs. + + :param kwargs: Keyword arguments to remove template parameters from. + :return: A modified dictionary with the template parameters removed. + """ + if kwargs: + for param in self.prompt_params: + kwargs.pop(param, None) + return kwargs + else: + return {} + def __repr__(self): return f"PromptTemplate(name={self.name}, prompt_text={self.prompt_text}, prompt_params={self.prompt_params})" diff --git a/releasenotes/notes/remove-template-vars-invocation-kwargs-060f186fd1250fe4.yaml b/releasenotes/notes/remove-template-vars-invocation-kwargs-060f186fd1250fe4.yaml new file mode 100644 index 000000000..d2c6276ab --- /dev/null +++ b/releasenotes/notes/remove-template-vars-invocation-kwargs-060f186fd1250fe4.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Remove template variables from invocation layer kwargs diff --git a/test/prompt/test_prompt_node.py b/test/prompt/test_prompt_node.py index 1a6c4d61c..c305b799b 100644 --- a/test/prompt/test_prompt_node.py +++ b/test/prompt/test_prompt_node.py @@ -233,19 +233,30 @@ def test_azure_vs_open_ai_invocation_layer_selection(): ) -@pytest.mark.skip @pytest.mark.integration -@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True) +@pytest.mark.parametrize("prompt_model", ["hf"], indirect=True) def test_simple_pipeline(prompt_model): - # TODO: This can be another unit test? - skip_test_for_invalid_key(prompt_model) - - node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out") + """ + Tests that a pipeline with a prompt node and prompt template has the right output structure + """ + output_variable_name = "out" + node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable=output_variable_name) pipe = Pipeline() pipe.add_node(component=node, name="prompt_node", inputs=["Query"]) result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")]) - assert "positive" in result["out"][0].casefold() + + # validate output variable present + assert output_variable_name in result + assert len(result[output_variable_name]) == 1 + + # validate pipeline parameters are present + assert "query" in result + assert "documents" in result + + # and that so-called invocation context contains the right keys + assert "invocation_context" in result + assert all(item in result["invocation_context"] for item in ["query", "documents", output_variable_name, "prompts"]) @pytest.mark.skip diff --git a/test/prompt/test_prompt_template.py b/test/prompt/test_prompt_template.py index ff627738b..cf8254d3d 100644 --- a/test/prompt/test_prompt_template.py +++ b/test/prompt/test_prompt_template.py @@ -468,3 +468,26 @@ class TestPromptTemplateSyntax: prompt_template = PromptTemplate(prompt_text) prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)] assert prompts == expected_prompts + + def test_prompt_template_remove_template_params(self): + kwargs = {"query": "query", "documents": "documents", "other": "other"} + expected_kwargs = {"other": "other"} + prompt_text = "Here is prompt text with two variables that are also in kwargs: {query} and {documents}" + prompt_template = PromptTemplate(prompt_text) + assert prompt_template.remove_template_params(kwargs) == expected_kwargs + + def test_prompt_template_remove_template_params_edge_cases(self): + """ + Test that the function works with a variety of edge cases + """ + kwargs = {"query": "query", "documents": "documents"} + prompt_text = "Here is prompt text with two variables that are also in kwargs: {query} and {documents}" + prompt_template = PromptTemplate(prompt_text) + assert prompt_template.remove_template_params(kwargs) == {} + + assert prompt_template.remove_template_params({}) == {} + + assert prompt_template.remove_template_params(None) == {} + + totally_unrelated = {"totally_unrelated": "totally_unrelated"} + assert prompt_template.remove_template_params(totally_unrelated) == totally_unrelated