feat: add required flag for prompt builder inputs (#7553)

This commit is contained in:
Bohan Qu 2024-04-29 20:21:53 +08:00 committed by GitHub
parent d2c87b2fd9
commit 40360e44ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 5 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List, Optional
from jinja2 import Template, meta
@ -10,8 +10,9 @@ class PromptBuilder:
"""
PromptBuilder is a component that renders a prompt from a template string using Jinja2 templates.
The template variables found in the template string are used as input types for the component and are all optional.
If a template variable is not provided as an input, it will be replaced with an empty string in the rendered prompt.
The template variables found in the template string are used as input types for the component and are all optional,
unless explicitly specified. If an optional template variable is not provided as an input, it will be replaced with
an empty string in the rendered prompt.
Usage example:
```python
@ -21,17 +22,23 @@ class PromptBuilder:
```
"""
def __init__(self, template: str):
def __init__(self, template: str, required_variables: Optional[List[str]] = None):
"""
Constructs a PromptBuilder component.
:param template: A Jinja2 template string, e.g. "Summarize this document: {documents}\\nSummary:"
:param required_variables: An optional list of input variables that must be provided at all times.
"""
self._template_string = template
self.template = Template(template)
self.required_variables = required_variables or []
ast = self.template.environment.parse(template)
template_variables = meta.find_undeclared_variables(ast)
for var in template_variables:
if var in self.required_variables:
component.set_input_type(self, var, Any)
else:
component.set_input_type(self, var, Any, "")
def to_dict(self) -> Dict[str, Any]:
@ -54,4 +61,9 @@ class PromptBuilder:
:returns: A dictionary with the following keys:
- `prompt`: The updated prompt text after rendering the prompt template.
"""
missing_variables = [var for var in self.required_variables if var not in kwargs]
if missing_variables:
missing_vars_str = ", ".join(missing_variables)
raise ValueError(f"Missing required input variables in PromptBuilder: {missing_vars_str}")
return {"prompt": self.template.render(kwargs)}

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Enhanced PromptBuilder to specify and enforce required variables in prompt templates.

View File

@ -33,3 +33,13 @@ def test_run_with_missing_input():
builder = PromptBuilder(template="This is a {{ variable }}")
res = builder.run()
assert res == {"prompt": "This is a "}
def test_run_with_missing_required_input():
builder = PromptBuilder(template="This is a {{ foo }}, not a {{ bar }}", required_variables=["foo", "bar"])
with pytest.raises(ValueError, match="foo"):
builder.run(bar="bar")
with pytest.raises(ValueError, match="bar"):
builder.run(foo="foo")
with pytest.raises(ValueError, match="foo, bar"):
builder.run()