From 40360e44ff3dbd5a43d64684beff6f937eecb48a Mon Sep 17 00:00:00 2001 From: Bohan Qu Date: Mon, 29 Apr 2024 20:21:53 +0800 Subject: [PATCH] feat: add required flag for prompt builder inputs (#7553) --- .../components/builders/prompt_builder.py | 22 ++++++++++++++----- ...rompt-builder-inputs-f5d3ffb3cb7df8d0.yaml | 4 ++++ .../builders/test_prompt_builder.py | 10 +++++++++ 3 files changed, 31 insertions(+), 5 deletions(-) create mode 100644 releasenotes/notes/add-required-flag-for-prompt-builder-inputs-f5d3ffb3cb7df8d0.yaml diff --git a/haystack/components/builders/prompt_builder.py b/haystack/components/builders/prompt_builder.py index cade8fd56..0713b87bd 100644 --- a/haystack/components/builders/prompt_builder.py +++ b/haystack/components/builders/prompt_builder.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, List, Optional from jinja2 import Template, meta @@ -10,8 +10,9 @@ class PromptBuilder: """ PromptBuilder is a component that renders a prompt from a template string using Jinja2 templates. - The template variables found in the template string are used as input types for the component and are all optional. - If a template variable is not provided as an input, it will be replaced with an empty string in the rendered prompt. + The template variables found in the template string are used as input types for the component and are all optional, + unless explicitly specified. If an optional template variable is not provided as an input, it will be replaced with + an empty string in the rendered prompt. Usage example: ```python @@ -21,18 +22,24 @@ class PromptBuilder: ``` """ - def __init__(self, template: str): + def __init__(self, template: str, required_variables: Optional[List[str]] = None): """ Constructs a PromptBuilder component. :param template: A Jinja2 template string, e.g. "Summarize this document: {documents}\\nSummary:" + :param required_variables: An optional list of input variables that must be provided at all times. """ self._template_string = template self.template = Template(template) + self.required_variables = required_variables or [] ast = self.template.environment.parse(template) template_variables = meta.find_undeclared_variables(ast) + for var in template_variables: - component.set_input_type(self, var, Any, "") + if var in self.required_variables: + component.set_input_type(self, var, Any) + else: + component.set_input_type(self, var, Any, "") def to_dict(self) -> Dict[str, Any]: """ @@ -54,4 +61,9 @@ class PromptBuilder: :returns: A dictionary with the following keys: - `prompt`: The updated prompt text after rendering the prompt template. """ + missing_variables = [var for var in self.required_variables if var not in kwargs] + if missing_variables: + missing_vars_str = ", ".join(missing_variables) + raise ValueError(f"Missing required input variables in PromptBuilder: {missing_vars_str}") + return {"prompt": self.template.render(kwargs)} diff --git a/releasenotes/notes/add-required-flag-for-prompt-builder-inputs-f5d3ffb3cb7df8d0.yaml b/releasenotes/notes/add-required-flag-for-prompt-builder-inputs-f5d3ffb3cb7df8d0.yaml new file mode 100644 index 000000000..9aad19db9 --- /dev/null +++ b/releasenotes/notes/add-required-flag-for-prompt-builder-inputs-f5d3ffb3cb7df8d0.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Enhanced PromptBuilder to specify and enforce required variables in prompt templates. diff --git a/test/components/builders/test_prompt_builder.py b/test/components/builders/test_prompt_builder.py index e1c86f7bd..321fa5915 100644 --- a/test/components/builders/test_prompt_builder.py +++ b/test/components/builders/test_prompt_builder.py @@ -33,3 +33,13 @@ def test_run_with_missing_input(): builder = PromptBuilder(template="This is a {{ variable }}") res = builder.run() assert res == {"prompt": "This is a "} + + +def test_run_with_missing_required_input(): + builder = PromptBuilder(template="This is a {{ foo }}, not a {{ bar }}", required_variables=["foo", "bar"]) + with pytest.raises(ValueError, match="foo"): + builder.run(bar="bar") + with pytest.raises(ValueError, match="bar"): + builder.run(foo="foo") + with pytest.raises(ValueError, match="foo, bar"): + builder.run()