mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-01 11:06:45 +00:00
feat: add required flag for prompt builder inputs (#7553)
This commit is contained in:
parent
d2c87b2fd9
commit
40360e44ff
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from jinja2 import Template, meta
|
from jinja2 import Template, meta
|
||||||
|
|
||||||
@ -10,8 +10,9 @@ class PromptBuilder:
|
|||||||
"""
|
"""
|
||||||
PromptBuilder is a component that renders a prompt from a template string using Jinja2 templates.
|
PromptBuilder is a component that renders a prompt from a template string using Jinja2 templates.
|
||||||
|
|
||||||
The template variables found in the template string are used as input types for the component and are all optional.
|
The template variables found in the template string are used as input types for the component and are all optional,
|
||||||
If a template variable is not provided as an input, it will be replaced with an empty string in the rendered prompt.
|
unless explicitly specified. If an optional template variable is not provided as an input, it will be replaced with
|
||||||
|
an empty string in the rendered prompt.
|
||||||
|
|
||||||
Usage example:
|
Usage example:
|
||||||
```python
|
```python
|
||||||
@ -21,18 +22,24 @@ class PromptBuilder:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, template: str):
|
def __init__(self, template: str, required_variables: Optional[List[str]] = None):
|
||||||
"""
|
"""
|
||||||
Constructs a PromptBuilder component.
|
Constructs a PromptBuilder component.
|
||||||
|
|
||||||
:param template: A Jinja2 template string, e.g. "Summarize this document: {documents}\\nSummary:"
|
:param template: A Jinja2 template string, e.g. "Summarize this document: {documents}\\nSummary:"
|
||||||
|
:param required_variables: An optional list of input variables that must be provided at all times.
|
||||||
"""
|
"""
|
||||||
self._template_string = template
|
self._template_string = template
|
||||||
self.template = Template(template)
|
self.template = Template(template)
|
||||||
|
self.required_variables = required_variables or []
|
||||||
ast = self.template.environment.parse(template)
|
ast = self.template.environment.parse(template)
|
||||||
template_variables = meta.find_undeclared_variables(ast)
|
template_variables = meta.find_undeclared_variables(ast)
|
||||||
|
|
||||||
for var in template_variables:
|
for var in template_variables:
|
||||||
component.set_input_type(self, var, Any, "")
|
if var in self.required_variables:
|
||||||
|
component.set_input_type(self, var, Any)
|
||||||
|
else:
|
||||||
|
component.set_input_type(self, var, Any, "")
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@ -54,4 +61,9 @@ class PromptBuilder:
|
|||||||
:returns: A dictionary with the following keys:
|
:returns: A dictionary with the following keys:
|
||||||
- `prompt`: The updated prompt text after rendering the prompt template.
|
- `prompt`: The updated prompt text after rendering the prompt template.
|
||||||
"""
|
"""
|
||||||
|
missing_variables = [var for var in self.required_variables if var not in kwargs]
|
||||||
|
if missing_variables:
|
||||||
|
missing_vars_str = ", ".join(missing_variables)
|
||||||
|
raise ValueError(f"Missing required input variables in PromptBuilder: {missing_vars_str}")
|
||||||
|
|
||||||
return {"prompt": self.template.render(kwargs)}
|
return {"prompt": self.template.render(kwargs)}
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
---
|
||||||
|
enhancements:
|
||||||
|
- |
|
||||||
|
Enhanced PromptBuilder to specify and enforce required variables in prompt templates.
|
@ -33,3 +33,13 @@ def test_run_with_missing_input():
|
|||||||
builder = PromptBuilder(template="This is a {{ variable }}")
|
builder = PromptBuilder(template="This is a {{ variable }}")
|
||||||
res = builder.run()
|
res = builder.run()
|
||||||
assert res == {"prompt": "This is a "}
|
assert res == {"prompt": "This is a "}
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_with_missing_required_input():
|
||||||
|
builder = PromptBuilder(template="This is a {{ foo }}, not a {{ bar }}", required_variables=["foo", "bar"])
|
||||||
|
with pytest.raises(ValueError, match="foo"):
|
||||||
|
builder.run(bar="bar")
|
||||||
|
with pytest.raises(ValueError, match="bar"):
|
||||||
|
builder.run(foo="foo")
|
||||||
|
with pytest.raises(ValueError, match="foo, bar"):
|
||||||
|
builder.run()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user