feat: Remove template variables from PromptNode invocation kwargs (#5526)

* Remove template params from kwargs before passing kwargs to invocation layer

* More unit tests

* Add release note

* Enable simple prompt node pipeline integration test use case
This commit is contained in:
Vladimir Blagojevic 2023-08-08 16:40:23 +02:00 committed by GitHub
parent 84ed954c8c
commit 227bf6ca39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 8 deletions

View File

@ -159,7 +159,7 @@ class PromptNode(BaseComponent):
if template_to_fill:
# prompt template used, yield prompts from inputs args
for prompt in template_to_fill.fill(*args, **kwargs):
kwargs_copy = copy.copy(kwargs)
kwargs_copy = template_to_fill.remove_template_params(copy.copy(kwargs))
# and pass the prepared prompt and kwargs copy to the model
prompt = self.prompt_model._ensure_token_limit(prompt)
prompt_collector.append(prompt)

View File

@ -578,5 +578,19 @@ class PromptTemplate(BasePromptTemplate, ABC):
)
yield prompt_prepared
def remove_template_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""
Removes template parameters from kwargs.
:param kwargs: Keyword arguments to remove template parameters from.
:return: A modified dictionary with the template parameters removed.
"""
if kwargs:
for param in self.prompt_params:
kwargs.pop(param, None)
return kwargs
else:
return {}
def __repr__(self):
return f"PromptTemplate(name={self.name}, prompt_text={self.prompt_text}, prompt_params={self.prompt_params})"

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Remove template variables from invocation layer kwargs

View File

@ -233,19 +233,30 @@ def test_azure_vs_open_ai_invocation_layer_selection():
)
@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
@pytest.mark.parametrize("prompt_model", ["hf"], indirect=True)
def test_simple_pipeline(prompt_model):
# TODO: This can be another unit test?
skip_test_for_invalid_key(prompt_model)
node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")
"""
Tests that a pipeline with a prompt node and prompt template has the right output structure
"""
output_variable_name = "out"
node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable=output_variable_name)
pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
assert "positive" in result["out"][0].casefold()
# validate output variable present
assert output_variable_name in result
assert len(result[output_variable_name]) == 1
# validate pipeline parameters are present
assert "query" in result
assert "documents" in result
# and that so-called invocation context contains the right keys
assert "invocation_context" in result
assert all(item in result["invocation_context"] for item in ["query", "documents", output_variable_name, "prompts"])
@pytest.mark.skip

View File

@ -468,3 +468,26 @@ class TestPromptTemplateSyntax:
prompt_template = PromptTemplate(prompt_text)
prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
assert prompts == expected_prompts
def test_prompt_template_remove_template_params(self):
kwargs = {"query": "query", "documents": "documents", "other": "other"}
expected_kwargs = {"other": "other"}
prompt_text = "Here is prompt text with two variables that are also in kwargs: {query} and {documents}"
prompt_template = PromptTemplate(prompt_text)
assert prompt_template.remove_template_params(kwargs) == expected_kwargs
def test_prompt_template_remove_template_params_edge_cases(self):
"""
Test that the function works with a variety of edge cases
"""
kwargs = {"query": "query", "documents": "documents"}
prompt_text = "Here is prompt text with two variables that are also in kwargs: {query} and {documents}"
prompt_template = PromptTemplate(prompt_text)
assert prompt_template.remove_template_params(kwargs) == {}
assert prompt_template.remove_template_params({}) == {}
assert prompt_template.remove_template_params(None) == {}
totally_unrelated = {"totally_unrelated": "totally_unrelated"}
assert prompt_template.remove_template_params(totally_unrelated) == totally_unrelated