mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-30 18:47:25 +00:00
feat: add required flag for prompt builder inputs (#7553)
This commit is contained in:
parent
d2c87b2fd9
commit
40360e44ff
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from jinja2 import Template, meta
|
||||
|
||||
@ -10,8 +10,9 @@ class PromptBuilder:
|
||||
"""
|
||||
PromptBuilder is a component that renders a prompt from a template string using Jinja2 templates.
|
||||
|
||||
The template variables found in the template string are used as input types for the component and are all optional.
|
||||
If a template variable is not provided as an input, it will be replaced with an empty string in the rendered prompt.
|
||||
The template variables found in the template string are used as input types for the component and are all optional,
|
||||
unless explicitly specified. If an optional template variable is not provided as an input, it will be replaced with
|
||||
an empty string in the rendered prompt.
|
||||
|
||||
Usage example:
|
||||
```python
|
||||
@ -21,18 +22,24 @@ class PromptBuilder:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
def __init__(self, template: str, required_variables: Optional[List[str]] = None):
|
||||
"""
|
||||
Constructs a PromptBuilder component.
|
||||
|
||||
:param template: A Jinja2 template string, e.g. "Summarize this document: {documents}\\nSummary:"
|
||||
:param required_variables: An optional list of input variables that must be provided at all times.
|
||||
"""
|
||||
self._template_string = template
|
||||
self.template = Template(template)
|
||||
self.required_variables = required_variables or []
|
||||
ast = self.template.environment.parse(template)
|
||||
template_variables = meta.find_undeclared_variables(ast)
|
||||
|
||||
for var in template_variables:
|
||||
component.set_input_type(self, var, Any, "")
|
||||
if var in self.required_variables:
|
||||
component.set_input_type(self, var, Any)
|
||||
else:
|
||||
component.set_input_type(self, var, Any, "")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -54,4 +61,9 @@ class PromptBuilder:
|
||||
:returns: A dictionary with the following keys:
|
||||
- `prompt`: The updated prompt text after rendering the prompt template.
|
||||
"""
|
||||
missing_variables = [var for var in self.required_variables if var not in kwargs]
|
||||
if missing_variables:
|
||||
missing_vars_str = ", ".join(missing_variables)
|
||||
raise ValueError(f"Missing required input variables in PromptBuilder: {missing_vars_str}")
|
||||
|
||||
return {"prompt": self.template.render(kwargs)}
|
||||
|
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Enhanced PromptBuilder to specify and enforce required variables in prompt templates.
|
@ -33,3 +33,13 @@ def test_run_with_missing_input():
|
||||
builder = PromptBuilder(template="This is a {{ variable }}")
|
||||
res = builder.run()
|
||||
assert res == {"prompt": "This is a "}
|
||||
|
||||
|
||||
def test_run_with_missing_required_input():
|
||||
builder = PromptBuilder(template="This is a {{ foo }}, not a {{ bar }}", required_variables=["foo", "bar"])
|
||||
with pytest.raises(ValueError, match="foo"):
|
||||
builder.run(bar="bar")
|
||||
with pytest.raises(ValueError, match="bar"):
|
||||
builder.run(foo="foo")
|
||||
with pytest.raises(ValueError, match="foo, bar"):
|
||||
builder.run()
|
||||
|
Loading…
x
Reference in New Issue
Block a user