feat: add required flag for prompt builder inputs (#7553)

This commit is contained in:
Bohan Qu 2024-04-29 20:21:53 +08:00 committed by GitHub
parent d2c87b2fd9
commit 40360e44ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 5 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict from typing import Any, Dict, List, Optional
from jinja2 import Template, meta from jinja2 import Template, meta
@ -10,8 +10,9 @@ class PromptBuilder:
""" """
PromptBuilder is a component that renders a prompt from a template string using Jinja2 templates. PromptBuilder is a component that renders a prompt from a template string using Jinja2 templates.
The template variables found in the template string are used as input types for the component and are all optional. The template variables found in the template string are used as input types for the component and are all optional,
If a template variable is not provided as an input, it will be replaced with an empty string in the rendered prompt. unless explicitly specified. If an optional template variable is not provided as an input, it will be replaced with
an empty string in the rendered prompt.
Usage example: Usage example:
```python ```python
@ -21,17 +22,23 @@ class PromptBuilder:
``` ```
""" """
def __init__(self, template: str): def __init__(self, template: str, required_variables: Optional[List[str]] = None):
""" """
Constructs a PromptBuilder component. Constructs a PromptBuilder component.
:param template: A Jinja2 template string, e.g. "Summarize this document: {documents}\\nSummary:" :param template: A Jinja2 template string, e.g. "Summarize this document: {documents}\\nSummary:"
:param required_variables: An optional list of input variables that must be provided at all times.
""" """
self._template_string = template self._template_string = template
self.template = Template(template) self.template = Template(template)
self.required_variables = required_variables or []
ast = self.template.environment.parse(template) ast = self.template.environment.parse(template)
template_variables = meta.find_undeclared_variables(ast) template_variables = meta.find_undeclared_variables(ast)
for var in template_variables: for var in template_variables:
if var in self.required_variables:
component.set_input_type(self, var, Any)
else:
component.set_input_type(self, var, Any, "") component.set_input_type(self, var, Any, "")
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
@ -54,4 +61,9 @@ class PromptBuilder:
:returns: A dictionary with the following keys: :returns: A dictionary with the following keys:
- `prompt`: The updated prompt text after rendering the prompt template. - `prompt`: The updated prompt text after rendering the prompt template.
""" """
missing_variables = [var for var in self.required_variables if var not in kwargs]
if missing_variables:
missing_vars_str = ", ".join(missing_variables)
raise ValueError(f"Missing required input variables in PromptBuilder: {missing_vars_str}")
return {"prompt": self.template.render(kwargs)} return {"prompt": self.template.render(kwargs)}

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Enhanced PromptBuilder to specify and enforce required variables in prompt templates.

View File

@ -33,3 +33,13 @@ def test_run_with_missing_input():
builder = PromptBuilder(template="This is a {{ variable }}") builder = PromptBuilder(template="This is a {{ variable }}")
res = builder.run() res = builder.run()
assert res == {"prompt": "This is a "} assert res == {"prompt": "This is a "}
def test_run_with_missing_required_input():
builder = PromptBuilder(template="This is a {{ foo }}, not a {{ bar }}", required_variables=["foo", "bar"])
with pytest.raises(ValueError, match="foo"):
builder.run(bar="bar")
with pytest.raises(ValueError, match="bar"):
builder.run(foo="foo")
with pytest.raises(ValueError, match="foo, bar"):
builder.run()