feat: Add Literal["*"] option to required_variables in ChatPrompBuilder and PromptBuilder (#8572)

* Add new option for required_variables in PromptBuilder and ChatPromptBuilder

* Add reno note

* Add tests
This commit is contained in:
Sebastian Husch Lee 2024-11-22 16:27:50 +01:00 committed by GitHub
parent b5a2fad642
commit eace2a99e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 56 additions and 17 deletions

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Literal, Optional, Set, Union
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
@ -100,7 +100,7 @@ class ChatPromptBuilder:
def __init__(
self,
template: Optional[List[ChatMessage]] = None,
required_variables: Optional[List[str]] = None,
required_variables: Optional[Union[List[str], Literal["*"]]] = None,
variables: Optional[List[str]] = None,
):
"""
@ -112,7 +112,8 @@ class ChatPromptBuilder:
the `init` method` or the `run` method.
:param required_variables:
List variables that must be provided as input to ChatPromptBuilder.
If a variable listed as required is not provided, an exception is raised. Optional.
If a variable listed as required is not provided, an exception is raised.
If set to "*", all variables found in the prompt are required. Optional.
:param variables:
List input variables to use in prompt templates instead of the ones inferred from the
`template` parameter. For example, to use more variables during prompt engineering than the ones present
@ -127,14 +128,15 @@ class ChatPromptBuilder:
if template and not variables:
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
# infere variables from template
# infer variables from template
ast = self._env.parse(message.content)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
self.variables = variables
# setup inputs
for var in variables:
if var in self.required_variables:
for var in self.variables:
if self.required_variables == "*" or var in self.required_variables:
component.set_input_type(self, var, Any)
else:
component.set_input_type(self, var, Any, "")
@ -211,12 +213,16 @@ class ChatPromptBuilder:
:raises ValueError:
If no template is provided or if all the required template variables are not provided.
"""
missing_variables = [var for var in self.required_variables if var not in provided_variables]
if self.required_variables == "*":
required_variables = sorted(self.variables)
else:
required_variables = self.required_variables
missing_variables = [var for var in required_variables if var not in provided_variables]
if missing_variables:
missing_vars_str = ", ".join(missing_variables)
raise ValueError(
f"Missing required input variables in ChatPromptBuilder: {missing_vars_str}. "
f"Required variables: {self.required_variables}. Provided variables: {provided_variables}."
f"Required variables: {required_variables}. Provided variables: {provided_variables}."
)
def to_dict(self) -> Dict[str, Any]:

View File

@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Literal, Optional, Set, Union
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
@ -137,7 +137,10 @@ class PromptBuilder:
"""
def __init__(
self, template: str, required_variables: Optional[List[str]] = None, variables: Optional[List[str]] = None
self,
template: str,
required_variables: Optional[Union[List[str], Literal["*"]]] = None,
variables: Optional[List[str]] = None,
):
"""
Constructs a PromptBuilder component.
@ -150,7 +153,8 @@ class PromptBuilder:
unless explicitly specified.
If an optional variable is not provided, it's replaced with an empty string in the rendered prompt.
:param required_variables: List variables that must be provided as input to PromptBuilder.
If a variable listed as required is not provided, an exception is raised. Optional.
If a variable listed as required is not provided, an exception is raised.
If set to "*", all variables found in the prompt are required. Optional.
:param variables:
List input variables to use in prompt templates instead of the ones inferred from the
`template` parameter. For example, to use more variables during prompt engineering than the ones present
@ -173,12 +177,12 @@ class PromptBuilder:
ast = self._env.parse(template)
template_variables = meta.find_undeclared_variables(ast)
variables = list(template_variables)
variables = variables or []
self.variables = variables
# setup inputs
for var in variables:
if var in self.required_variables:
for var in self.variables:
if self.required_variables == "*" or var in self.required_variables:
component.set_input_type(self, var, Any)
else:
component.set_input_type(self, var, Any, "")
@ -238,10 +242,14 @@ class PromptBuilder:
:raises ValueError:
If any of the required template variables is not provided.
"""
missing_variables = [var for var in self.required_variables if var not in provided_variables]
if self.required_variables == "*":
required_variables = sorted(self.variables)
else:
required_variables = self.required_variables
missing_variables = [var for var in required_variables if var not in provided_variables]
if missing_variables:
missing_vars_str = ", ".join(missing_variables)
raise ValueError(
f"Missing required input variables in PromptBuilder: {missing_vars_str}. "
f"Required variables: {self.required_variables}. Provided variables: {provided_variables}."
f"Required variables: {required_variables}. Provided variables: {provided_variables}."
)

View File

@ -0,0 +1,5 @@
---
features:
- |
Added a new option to the required_variables parameter to the PromptBuilder and ChatPromptBuilder.
By passing `required_variables="*"` you can automatically set all variables in the prompt to be required.

View File

@ -137,6 +137,17 @@ class TestChatPromptBuilder:
with pytest.raises(ValueError, match="foo, bar"):
builder.run()
def test_run_with_missing_required_input_using_star(self):
builder = ChatPromptBuilder(
template=[ChatMessage.from_user("This is a {{ foo }}, not a {{ bar }}")], required_variables="*"
)
with pytest.raises(ValueError, match="foo"):
builder.run(bar="bar")
with pytest.raises(ValueError, match="bar"):
builder.run(foo="foo")
with pytest.raises(ValueError, match="bar, foo"):
builder.run()
def test_run_with_variables(self):
variables = ["var1", "var2", "var3"]
template = [ChatMessage.from_user("Hello, {{ name }}! {{ var1 }}")]

View File

@ -143,6 +143,15 @@ class TestPromptBuilder:
with pytest.raises(ValueError, match="foo, bar"):
builder.run()
def test_run_with_missing_required_input_using_star(self):
builder = PromptBuilder(template="This is a {{ foo }}, not a {{ bar }}", required_variables="*")
with pytest.raises(ValueError, match="foo"):
builder.run(bar="bar")
with pytest.raises(ValueError, match="bar"):
builder.run(foo="foo")
with pytest.raises(ValueError, match="bar, foo"):
builder.run()
def test_run_with_variables(self):
variables = ["var1", "var2", "var3"]
template = "Hello, {{ name }}! {{ var1 }}"
@ -296,7 +305,7 @@ class TestPromptBuilder:
assert now_plus_2 == result
def test_date_with_substraction_offset(self) -> None:
def test_date_with_subtraction_offset(self) -> None:
template = "Time after 12 days is: {% now 'UTC' - 'days=12' %}"
builder = PromptBuilder(template=template)