feat: Add dynamic per-user ChatMessage templating support (#6161)

* Add dynamic per-user ChatMessage templating support

* Add unit tests for dynamic templating

* Update add-dynamic-per-message-templating-908468226c5e3d45.yaml

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Proper init ValueError raising, unit tests

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
Vladimir Blagojevic 2023-10-24 16:50:45 +02:00 committed by GitHub
parent dd24210908
commit b9b7d7666d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 250 additions and 16 deletions

View File

@ -1,41 +1,140 @@
from typing import Dict, Any
from typing import Dict, Any, Optional, List
from jinja2 import Template, meta
from haystack.preview import component
from haystack.preview import default_to_dict
from haystack.preview.dataclasses.chat_message import ChatMessage, ChatRole
@component
class PromptBuilder:
"""
PromptBuilder is a component that renders a prompt from a template string using Jinja2 engine.
The template variables found in the template string are used as input types for the component and are all required.
A component for building prompts using template strings or template variables.
The `PromptBuilder` can be initialized with a template string or template variables to dynamically build
a prompt. There are two distinct use cases:
1. **Static Template**: When initialized with a `template` string, the `PromptBuilder` will always use
this template for rendering prompts throughout its lifetime.
2. **Dynamic Templates**: When initialized with `template_variables` and receiving messages in
`ChatMessage` format, it allows for different templates for each message, enabling prompt templating
on a per-user message basis.
:param template: (Optional) A template string to be rendered using Jinja2 syntax, e.g.,
"What's the weather like in {{ location }}?". This template will be used for all
prompts. Defaults to None.
:param template_variables: (Optional) A list of all template variables to be used as input.
This parameter enables dynamic templating based on user messages. Defaults to None.
:raises ValueError: If neither `template` nor `template_variables` are provided.
Usage (static templating):
Usage:
```python
template = "Translate the following context to {{ target_language }}. Context: {{ snippet }}; Translation:"
builder = PromptBuilder(template=template)
builder.run(target_language="spanish", snippet="I can't speak spanish.")
```
The above template is used for all messages that are passed to the `PromptBuilder`.
Usage (dynamic templating):
```python
from haystack.preview.dataclasses.chat_message import ChatMessage
template = "What's the weather like in {{ location }}?"
prompt_builder = PromptBuilder(template_variables=["location", "time"])
messages = [ChatMessage.from_system("Always start response to user with Herr Blagojevic.
Respond in German even if some input data is in other languages"),
ChatMessage.from_user(template)]
response = pipe.run(data={"prompt_builder": {"location": location, "messages": messages}})
```
In this example, only the last user message is templated. The template_variables parameter in the PromptBuilder
initialization specifies all potential template variables, yet only the variables utilized in the template are
required in the run method. For instance, since the time variable isn't used in the template, it's not
necessary in the run method invocation above.
Note:
The behavior of `PromptBuilder` is determined by the initialization parameters. A static template
provides a consistent prompt structure, while template variables offer dynamic templating per user
message.
"""
def __init__(self, template: str):
def __init__(self, template: Optional[str] = None, template_variables: Optional[List[str]] = None):
"""
Initialize the component with a template string.
Initialize the component with either a template string or template variables.
If template is given PromptBuilder will parse the template string and use the template variables
as input types. Conversely, if template_variables are given, PromptBuilder will directly use
them as input variables.
:param template: Jinja2 template string, e.g. "Summarize this document: {documents}\nSummary:"
:type template: str
If neither template nor template_variables are provided, an error will be raised. If both are provided,
an error will be raised as well.
:param template: Template string to be rendered.
:param template_variables: List of template variables to be used as input types.
"""
if template_variables and template:
raise ValueError("template and template_variables cannot be provided at the same time.")
# dynamic per-user message templating
if template_variables:
# treat vars as optional input slots
dynamic_input_slots = {var: Optional[Any] for var in template_variables}
self.template = None
# static templating
else:
if not template:
raise ValueError("Either template or template_variables must be provided.")
self.template = Template(template)
ast = self.template.environment.parse(template)
static_template_variables = meta.find_undeclared_variables(ast)
# treat vars as required input slots - as per design
dynamic_input_slots = {var: Any for var in static_template_variables}
# always provide all serialized vars, so we can serialize
# the component regardless of the initialization method (static vs. dynamic)
self.template_variables = template_variables
self._template_string = template
self.template = Template(template)
ast = self.template.environment.parse(template)
template_variables = meta.find_undeclared_variables(ast)
component.set_input_types(self, **{var: Any for var in template_variables})
optional_input_slots = {"messages": Optional[List[ChatMessage]]}
component.set_input_types(self, **optional_input_slots, **dynamic_input_slots)
def to_dict(self) -> Dict[str, Any]:
return default_to_dict(self, template=self._template_string)
return default_to_dict(self, template=self._template_string, template_variables=self.template_variables)
@component.output_types(prompt=str)
def run(self, **kwargs):
return {"prompt": self.template.render(kwargs)}
def run(self, messages: Optional[List[ChatMessage]] = None, **kwargs):
"""
Build and return the prompt based on the provided messages and template or template variables.
If `messages` are provided, the template will be applied to the last user message.
:param messages: (Optional) List of `ChatMessage` instances, used for dynamic templating
when `template_variables` are provided.
:param kwargs: Additional keyword arguments representing template variables.
"""
if messages:
# apply the template to the last user message only
last_message: ChatMessage = messages[-1]
if last_message.is_from(ChatRole.USER):
template = Template(last_message.content)
return {"prompt": messages[:-1] + [ChatMessage.from_user(template.render(kwargs))]}
else:
return {"prompt": messages}
else:
if self.template:
return {"prompt": self.template.render(kwargs)}
else:
raise ValueError(
"PromptBuilder was initialized with template_variables, but no ChatMessage(s) were provided."
)

View File

@ -0,0 +1,4 @@
---
preview:
- |
Adds `ChatMessage` templating in `PromptBuilder`

View File

@ -1,6 +1,7 @@
import pytest
from haystack.preview.components.builders.prompt_builder import PromptBuilder
from haystack.preview.dataclasses import ChatMessage
@pytest.mark.unit
@ -13,7 +14,10 @@ def test_init():
def test_to_dict():
builder = PromptBuilder(template="This is a {{ variable }}")
res = builder.to_dict()
assert res == {"type": "PromptBuilder", "init_parameters": {"template": "This is a {{ variable }}"}}
assert res == {
"type": "PromptBuilder",
"init_parameters": {"template": "This is a {{ variable }}", "template_variables": None},
}
@pytest.mark.unit
@ -35,3 +39,130 @@ def test_run_with_missing_input():
builder = PromptBuilder(template="This is a {{ variable }}")
res = builder.run()
assert res == {"prompt": "This is a "}
@pytest.mark.unit
def test_init_with_template_and_template_variables():
# Initialize the PromptBuilder object with both template and template_variables
with pytest.raises(ValueError, match="template and template_variables cannot be provided at the same time."):
PromptBuilder(template="This is a {{ variable }}", template_variables=["variable"])
@pytest.mark.unit
def test_init_with_no_template_and_no_template_variables():
# Initialize the PromptBuilder object with no template and no template_variables
with pytest.raises(ValueError, match="Either template or template_variables must be provided."):
PromptBuilder()
@pytest.mark.unit
def test_dynamic_template_with_input_variables_no_messages():
# Initialize the PromptBuilder object with dynamic template variables
template_variables = ["location", "time"]
builder = PromptBuilder(template_variables=template_variables)
# Call the run method with input variables
with pytest.raises(ValueError, match="PromptBuilder was initialized with template_variables"):
builder.run(location="New York", time="tomorrow")
@pytest.mark.unit
def test_dynamic_template_with_input_variables_and_messages():
# Initialize the PromptBuilder object with dynamic template variables
template_variables = ["location", "time"]
builder = PromptBuilder(template_variables=template_variables)
system_message = (
"Always start response to user with Herr Blagojevic. "
"Respond in German even if some input data is in other languages"
)
# Create a list of ChatMessage objects
messages = [
ChatMessage.from_system(system_message),
ChatMessage.from_user("What's the weather like in {{ location }}?"),
]
# Call the run method with input variables and messages
result = builder.run(messages=messages, location="New York", time="tomorrow")
# Assert that the prompt is generated correctly
assert result["prompt"] == [
ChatMessage.from_system(system_message),
ChatMessage.from_user("What's the weather like in New York?"),
]
@pytest.mark.unit
def test_static_template_without_input_variables():
# Initialize the PromptBuilder object with a static template and no input variables
template = "Translate the following context to Spanish."
builder = PromptBuilder(template=template)
# Call the run method without input variables
result = builder.run()
# Assert that the prompt is generated correctly
assert result["prompt"] == "Translate the following context to Spanish."
@pytest.mark.unit
def test_dynamic_template_without_input_variables():
# Initialize the PromptBuilder object with dynamic template variables
template_variables = ["location", "time"]
builder = PromptBuilder(template_variables=template_variables)
messages = [ChatMessage.from_user("What's LLM?")]
# Call the run method without input variables
result = builder.run(messages=messages)
# Assert that the prompt is generated correctly
assert result["prompt"] == [ChatMessage.from_user("What's LLM?")]
@pytest.mark.unit
def test_dynamic_template_with_input_variables_and_multiple_user_messages():
# Initialize the PromptBuilder object with dynamic template variables
template_variables = ["location", "time"]
builder = PromptBuilder(template_variables=template_variables)
system_message = (
"Always start response to user with Herr Blagojevic. "
"Respond in German even if some input data is in other languages"
)
# Create a list of ChatMessage objects with multiple user messages
messages = [
ChatMessage.from_system(system_message),
ChatMessage.from_user("Here is improper use of {{ location }} as it is not the last message"),
ChatMessage.from_user("What's the weather like in {{ location }}?"),
]
result = builder.run(messages=messages, location="New York", time="tomorrow")
assert result["prompt"] == [
ChatMessage.from_system(system_message),
ChatMessage.from_user("Here is improper use of {{ location }} as it is not the last message"),
ChatMessage.from_user("What's the weather like in New York?"),
]
def test_dynamic_template_with_invalid_input_variables_and_messages():
# Initialize the PromptBuilder object with dynamic template variables
template_variables = ["location", "time"]
builder = PromptBuilder(template_variables=template_variables)
system_message = (
"Always start response to user with Herr Blagojevic. "
"Respond in German even if some input data is in other languages"
)
# Create a list of ChatMessage objects
messages = [ChatMessage.from_system(system_message), ChatMessage.from_user("What is {{ topic }}?")]
# Call the run method with input variables and messages
result = builder.run(messages=messages, location="New York", time="tomorrow")
# same behaviour as for static template
# as topic is not a template variable, it is ignored
assert result["prompt"] == [ChatMessage.from_system(system_message), ChatMessage.from_user("What is ?")]