mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-09 05:37:25 +00:00
feat: Add Literal["*"] option to required_variables in ChatPrompBuilder and PromptBuilder (#8572)
* Add new option for required_variables in PromptBuilder and ChatPromptBuilder * Add reno note * Add tests
This commit is contained in:
parent
b5a2fad642
commit
eace2a99e5
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from jinja2 import meta
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
@ -100,7 +100,7 @@ class ChatPromptBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
template: Optional[List[ChatMessage]] = None,
|
||||
required_variables: Optional[List[str]] = None,
|
||||
required_variables: Optional[Union[List[str], Literal["*"]]] = None,
|
||||
variables: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
@ -112,7 +112,8 @@ class ChatPromptBuilder:
|
||||
the `init` method` or the `run` method.
|
||||
:param required_variables:
|
||||
List variables that must be provided as input to ChatPromptBuilder.
|
||||
If a variable listed as required is not provided, an exception is raised. Optional.
|
||||
If a variable listed as required is not provided, an exception is raised.
|
||||
If set to "*", all variables found in the prompt are required. Optional.
|
||||
:param variables:
|
||||
List input variables to use in prompt templates instead of the ones inferred from the
|
||||
`template` parameter. For example, to use more variables during prompt engineering than the ones present
|
||||
@ -127,14 +128,15 @@ class ChatPromptBuilder:
|
||||
if template and not variables:
|
||||
for message in template:
|
||||
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
|
||||
# infere variables from template
|
||||
# infer variables from template
|
||||
ast = self._env.parse(message.content)
|
||||
template_variables = meta.find_undeclared_variables(ast)
|
||||
variables += list(template_variables)
|
||||
self.variables = variables
|
||||
|
||||
# setup inputs
|
||||
for var in variables:
|
||||
if var in self.required_variables:
|
||||
for var in self.variables:
|
||||
if self.required_variables == "*" or var in self.required_variables:
|
||||
component.set_input_type(self, var, Any)
|
||||
else:
|
||||
component.set_input_type(self, var, Any, "")
|
||||
@ -211,12 +213,16 @@ class ChatPromptBuilder:
|
||||
:raises ValueError:
|
||||
If no template is provided or if all the required template variables are not provided.
|
||||
"""
|
||||
missing_variables = [var for var in self.required_variables if var not in provided_variables]
|
||||
if self.required_variables == "*":
|
||||
required_variables = sorted(self.variables)
|
||||
else:
|
||||
required_variables = self.required_variables
|
||||
missing_variables = [var for var in required_variables if var not in provided_variables]
|
||||
if missing_variables:
|
||||
missing_vars_str = ", ".join(missing_variables)
|
||||
raise ValueError(
|
||||
f"Missing required input variables in ChatPromptBuilder: {missing_vars_str}. "
|
||||
f"Required variables: {self.required_variables}. Provided variables: {provided_variables}."
|
||||
f"Required variables: {required_variables}. Provided variables: {provided_variables}."
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from jinja2 import meta
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
@ -137,7 +137,10 @@ class PromptBuilder:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, template: str, required_variables: Optional[List[str]] = None, variables: Optional[List[str]] = None
|
||||
self,
|
||||
template: str,
|
||||
required_variables: Optional[Union[List[str], Literal["*"]]] = None,
|
||||
variables: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Constructs a PromptBuilder component.
|
||||
@ -150,7 +153,8 @@ class PromptBuilder:
|
||||
unless explicitly specified.
|
||||
If an optional variable is not provided, it's replaced with an empty string in the rendered prompt.
|
||||
:param required_variables: List variables that must be provided as input to PromptBuilder.
|
||||
If a variable listed as required is not provided, an exception is raised. Optional.
|
||||
If a variable listed as required is not provided, an exception is raised.
|
||||
If set to "*", all variables found in the prompt are required. Optional.
|
||||
:param variables:
|
||||
List input variables to use in prompt templates instead of the ones inferred from the
|
||||
`template` parameter. For example, to use more variables during prompt engineering than the ones present
|
||||
@ -173,12 +177,12 @@ class PromptBuilder:
|
||||
ast = self._env.parse(template)
|
||||
template_variables = meta.find_undeclared_variables(ast)
|
||||
variables = list(template_variables)
|
||||
|
||||
variables = variables or []
|
||||
self.variables = variables
|
||||
|
||||
# setup inputs
|
||||
for var in variables:
|
||||
if var in self.required_variables:
|
||||
for var in self.variables:
|
||||
if self.required_variables == "*" or var in self.required_variables:
|
||||
component.set_input_type(self, var, Any)
|
||||
else:
|
||||
component.set_input_type(self, var, Any, "")
|
||||
@ -238,10 +242,14 @@ class PromptBuilder:
|
||||
:raises ValueError:
|
||||
If any of the required template variables is not provided.
|
||||
"""
|
||||
missing_variables = [var for var in self.required_variables if var not in provided_variables]
|
||||
if self.required_variables == "*":
|
||||
required_variables = sorted(self.variables)
|
||||
else:
|
||||
required_variables = self.required_variables
|
||||
missing_variables = [var for var in required_variables if var not in provided_variables]
|
||||
if missing_variables:
|
||||
missing_vars_str = ", ".join(missing_variables)
|
||||
raise ValueError(
|
||||
f"Missing required input variables in PromptBuilder: {missing_vars_str}. "
|
||||
f"Required variables: {self.required_variables}. Provided variables: {provided_variables}."
|
||||
f"Required variables: {required_variables}. Provided variables: {provided_variables}."
|
||||
)
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Added a new option to the required_variables parameter to the PromptBuilder and ChatPromptBuilder.
|
||||
By passing `required_variables="*"` you can automatically set all variables in the prompt to be required.
|
||||
@ -137,6 +137,17 @@ class TestChatPromptBuilder:
|
||||
with pytest.raises(ValueError, match="foo, bar"):
|
||||
builder.run()
|
||||
|
||||
def test_run_with_missing_required_input_using_star(self):
|
||||
builder = ChatPromptBuilder(
|
||||
template=[ChatMessage.from_user("This is a {{ foo }}, not a {{ bar }}")], required_variables="*"
|
||||
)
|
||||
with pytest.raises(ValueError, match="foo"):
|
||||
builder.run(bar="bar")
|
||||
with pytest.raises(ValueError, match="bar"):
|
||||
builder.run(foo="foo")
|
||||
with pytest.raises(ValueError, match="bar, foo"):
|
||||
builder.run()
|
||||
|
||||
def test_run_with_variables(self):
|
||||
variables = ["var1", "var2", "var3"]
|
||||
template = [ChatMessage.from_user("Hello, {{ name }}! {{ var1 }}")]
|
||||
|
||||
@ -143,6 +143,15 @@ class TestPromptBuilder:
|
||||
with pytest.raises(ValueError, match="foo, bar"):
|
||||
builder.run()
|
||||
|
||||
def test_run_with_missing_required_input_using_star(self):
|
||||
builder = PromptBuilder(template="This is a {{ foo }}, not a {{ bar }}", required_variables="*")
|
||||
with pytest.raises(ValueError, match="foo"):
|
||||
builder.run(bar="bar")
|
||||
with pytest.raises(ValueError, match="bar"):
|
||||
builder.run(foo="foo")
|
||||
with pytest.raises(ValueError, match="bar, foo"):
|
||||
builder.run()
|
||||
|
||||
def test_run_with_variables(self):
|
||||
variables = ["var1", "var2", "var3"]
|
||||
template = "Hello, {{ name }}! {{ var1 }}"
|
||||
@ -296,7 +305,7 @@ class TestPromptBuilder:
|
||||
|
||||
assert now_plus_2 == result
|
||||
|
||||
def test_date_with_substraction_offset(self) -> None:
|
||||
def test_date_with_subtraction_offset(self) -> None:
|
||||
template = "Time after 12 days is: {% now 'UTC' - 'days=12' %}"
|
||||
builder = PromptBuilder(template=template)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user