feat: DynamicChatPromptBuilder add templating to all user/system messages (#7423)

This commit is contained in:
Vladimir Blagojevic 2024-03-27 15:34:50 +01:00 committed by GitHub
parent 7894024e6f
commit ce8e114769
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 193 additions and 58 deletions

View File

@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
class DynamicChatPromptBuilder:
"""
DynamicChatPromptBuilder is designed to construct dynamic prompts from a list of `ChatMessage` instances. It
integrates with Jinja2 templating for dynamic prompt generation. It assumes that the last user message in the list
contains a template and renders it with variables provided to the constructor. Additional template variables
can be feed into the pipeline `run` method and will be merged before rendering the template.
integrates with Jinja2 templating for dynamic prompt generation. It considers any user or system message in the list
potentially containing a template and renders it with variables provided to the constructor. Additional template
variables can be feed into the component/pipeline `run` method and will be merged before rendering the template.
Usage example:
```python
@ -34,11 +34,12 @@ class DynamicChatPromptBuilder:
pipe.connect("prompt_builder.prompt", "llm.messages")
location = "Berlin"
system_message = ChatMessage.from_system("You are a helpful assistant giving out valuable information to tourists.")
language = "English"
system_message = ChatMessage.from_system("You are an assistant giving information to tourists in {{language}}")
messages = [system_message, ChatMessage.from_user("Tell me about {{location}}")]
res = pipe.run(data={"prompt_builder": {"template_variables": {"location": location}, "prompt_source": messages}})
res = pipe.run(data={"prompt_builder": {"template_variables": {"location": location, "language": language},
"prompt_source": messages}})
print(res)
>> {'llm': {'replies': [ChatMessage(content="Berlin is the capital city of Germany and one of the most vibrant
@ -91,48 +92,22 @@ class DynamicChatPromptBuilder:
def run(self, prompt_source: List[ChatMessage], template_variables: Optional[Dict[str, Any]] = None, **kwargs):
"""
Executes the dynamic prompt building process by processing a list of `ChatMessage` instances.
The last user message is treated as a template and rendered with the variables provided to the constructor.
You can provide additional template variables directly to this method, which are then merged with the variables
provided to the constructor.
Any user message or system message is inspected for templates and rendered with the variables provided to the
constructor. You can provide additional template variables directly to this method, which are then merged with
the variables provided to the constructor.
:param prompt_source:
A list of `ChatMessage` instances. We make an assumption that the last user message has
the template for the chat prompt
A list of `ChatMessage` instances. All user and system messages are treated as potentially having templates
and are rendered with the provided template variables - if templates are found.
:param template_variables:
A dictionary of template variables. Template variables provided at initialization are required
to resolve pipeline variables, and these are additional variables users can provide directly to this method.
:param kwargs:
Additional keyword arguments, typically resolved from a pipeline, which are merged with the provided template variables.
Additional keyword arguments, typically resolved from a pipeline, which are merged with the provided
template variables.
:returns: A dictionary with the following keys:
- `prompt`: The updated list of `ChatMessage` instances after rendering the string template.
"""
kwargs = kwargs or {}
template_variables = template_variables or {}
template_variables_combined = {**kwargs, **template_variables}
if not template_variables_combined:
logger.warning(
"The DynamicChatPromptBuilder run method requires template variables, but none were provided. "
"Please provide an appropriate template variable to enable correct prompt generation."
)
result: List[ChatMessage] = self._process_chat_messages(prompt_source, template_variables_combined)
return {"prompt": result}
def _process_chat_messages(self, prompt_source: List[ChatMessage], template_variables: Dict[str, Any]):
"""
Processes a list of :class:`ChatMessage` instances to generate a chat prompt.
It takes the last user message in the list, treats it as a template, and renders it with the provided
template variables. The resulting message replaces the last user message in the list, forming a complete,
templated chat prompt.
:param prompt_source:
A list of `ChatMessage` instances to be processed. The last message is expected
to be from a user and is treated as a template.
:param template_variables:
A dictionary of template variables used for rendering the last user message.
:returns:
A list of `ChatMessage` instances, where the last user message has been replaced with its
- `prompt`: The updated list of `ChatMessage` instances after rendering the found templates.
:raises ValueError:
If `chat_messages` is empty or contains elements that are not instances of `ChatMessage`.
:raises ValueError:
@ -150,17 +125,28 @@ class DynamicChatPromptBuilder:
f"are ChatMessage instances."
)
last_message: ChatMessage = prompt_source[-1]
if last_message.is_from(ChatRole.USER):
template = self._validate_template(last_message.content, set(template_variables.keys()))
templated_user_message = ChatMessage.from_user(template.render(template_variables))
return prompt_source[:-1] + [templated_user_message]
else:
kwargs = kwargs or {}
template_variables = template_variables or {}
template_variables = {**kwargs, **template_variables}
if not template_variables:
logger.warning(
"DynamicChatPromptBuilder was not provided with a user message as the last message in "
"chat conversation, no templating will be applied."
"The DynamicChatPromptBuilder run method requires template variables, but none were provided. "
"Please provide an appropriate template variable to enable correct prompt generation."
)
return prompt_source
processed_messages = []
for message in prompt_source:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
template = self._validate_template(message.content, set(template_variables.keys()))
rendered_content = template.render(template_variables)
rendered_message = (
ChatMessage.from_user(rendered_content)
if message.is_from(ChatRole.USER)
else ChatMessage.from_system(rendered_content)
)
processed_messages.append(rendered_message)
else:
processed_messages.append(message)
return {"prompt": processed_messages}
def _validate_template(self, template_text: str, provided_variables: Set[str]):
"""

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Enhanced DynamicChatPromptBuilder's capabilities by allowing all user and system messages to be templated with provided variables. This update ensures a more versatile and dynamic templating process, making chat prompt generation more efficient and customized to user needs.

View File

@ -2,7 +2,7 @@ from typing import List
import pytest
from haystack import Pipeline
from haystack import Pipeline, component
from haystack.components.builders import DynamicChatPromptBuilder
from haystack.dataclasses import ChatMessage
@ -42,23 +42,21 @@ class TestDynamicChatPromptBuilder:
prompt_source = [ChatMessage.from_user(content="Hello, {{ who }}!")]
template_variables = {"who": "World"}
result = prompt_builder._process_chat_messages(prompt_source, template_variables)
result = prompt_builder.run(prompt_source, template_variables)
assert result == [ChatMessage.from_user(content="Hello, World!")]
assert result == {"prompt": [ChatMessage.from_user(content="Hello, World!")]}
def test_empty_chat_message_list(self):
prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"])
with pytest.raises(ValueError):
prompt_builder._process_chat_messages(prompt_source=[], template_variables={})
prompt_builder.run(prompt_source=[], template_variables={})
def test_chat_message_list_with_mixed_object_list(self):
prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"])
with pytest.raises(ValueError):
prompt_builder._process_chat_messages(
prompt_source=[ChatMessage.from_user("Hello"), "there world"], template_variables={}
)
prompt_builder.run(prompt_source=[ChatMessage.from_user("Hello"), "there world"], template_variables={})
def test_chat_message_list_with_missing_variables(self):
prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"])
@ -66,7 +64,7 @@ class TestDynamicChatPromptBuilder:
# Call the _process_chat_messages method and expect a ValueError
with pytest.raises(ValueError):
prompt_builder._process_chat_messages(prompt_source, template_variables={})
prompt_builder.run(prompt_source, template_variables={})
def test_missing_template_variables(self):
prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"])
@ -92,8 +90,69 @@ class TestDynamicChatPromptBuilder:
# provided variables are a superset of the required variables
prompt_builder._validate_template("Hello, I'm {{ name }}, and I live in {{ city }}.", {"name", "city", "age"})
def test_multiple_templated_chat_messages(self):
prompt_builder = DynamicChatPromptBuilder()
language = "French"
location = "Berlin"
messages = [
ChatMessage.from_system("Write your response in this language:{{language}}"),
ChatMessage.from_user("Tell me about {{location}}"),
]
result = prompt_builder.run(
template_variables={"language": language, "location": location}, prompt_source=messages
)
assert result["prompt"] == [
ChatMessage.from_system("Write your response in this language:French"),
ChatMessage.from_user("Tell me about Berlin"),
], "The templated messages should match the expected output."
def test_multiple_templated_chat_messages_in_place(self):
prompt_builder = DynamicChatPromptBuilder()
language = "French"
location = "Berlin"
messages = [
ChatMessage.from_system("Write your response ins this language:{{language}}"),
ChatMessage.from_user("Tell me about {{location}}"),
]
res = prompt_builder.run(
template_variables={"language": language, "location": location}, prompt_source=messages
)
assert res == {
"prompt": [
ChatMessage.from_system("Write your response ins this language:French"),
ChatMessage.from_user("Tell me about Berlin"),
]
}, "The templated messages should match the expected output."
def test_some_templated_chat_messages(self):
prompt_builder = DynamicChatPromptBuilder()
language = "English"
location = "Paris"
messages = [
ChatMessage.from_system("Please, respond in the following language: {{language}}."),
ChatMessage.from_user("I would like to learn more about {{location}}."),
ChatMessage.from_assistant("Yes, I can help you with that {{subject}}"),
ChatMessage.from_user("Ok so do so please, be elaborate."),
]
result = prompt_builder.run(
template_variables={"language": language, "location": location}, prompt_source=messages
)
expected_messages = [
ChatMessage.from_system("Please, respond in the following language: English."),
ChatMessage.from_user("I would like to learn more about Paris."),
ChatMessage.from_assistant(
"Yes, I can help you with that {{subject}}"
), # assistant message should not be templated
ChatMessage.from_user("Ok so do so please, be elaborate."),
]
assert result["prompt"] == expected_messages, "The templated messages should match the expected output."
def test_example_in_pipeline(self):
# no parameter init, we don't use any runtime template variables
prompt_builder = DynamicChatPromptBuilder()
pipe = Pipeline()
@ -138,3 +197,89 @@ class TestDynamicChatPromptBuilder:
]
}
}
def test_example_in_pipeline_with_multiple_templated_messages(self):
# no parameter init, we don't use any runtime template variables
prompt_builder = DynamicChatPromptBuilder()
pipe = Pipeline()
pipe.add_component("prompt_builder", prompt_builder)
location = "Berlin"
system_message = ChatMessage.from_system(
"You are a helpful assistant giving out valuable information to tourists in {{language}}."
)
messages = [system_message, ChatMessage.from_user("Tell me about {{location}}")]
res = pipe.run(
data={
"prompt_builder": {
"template_variables": {"location": location, "language": "German"},
"prompt_source": messages,
}
}
)
assert res == {
"prompt_builder": {
"prompt": [
ChatMessage.from_system(
"You are a helpful assistant giving out valuable information to tourists in German."
),
ChatMessage.from_user("Tell me about Berlin"),
]
}
}
messages = [
system_message,
ChatMessage.from_user("What's the weather forecast for {{location}} in the next {{day_count}} days?"),
]
res = pipe.run(
data={
"prompt_builder": {
"template_variables": {"location": location, "day_count": "5", "language": "English"},
"prompt_source": messages,
}
}
)
assert res == {
"prompt_builder": {
"prompt": [
ChatMessage.from_system(
"You are a helpful assistant giving out valuable information to tourists in English."
),
ChatMessage.from_user("What's the weather forecast for Berlin in the next 5 days?"),
]
}
}
def test_pipeline_complex(self):
@component
class ValueProducer:
def __init__(self, value_to_produce: str):
self.value_to_produce = value_to_produce
@component.output_types(value_output=str)
def run(self):
return {"value_output": self.value_to_produce}
pipe = Pipeline()
pipe.add_component("prompt_builder", DynamicChatPromptBuilder(runtime_variables=["value_output"]))
pipe.add_component("value_producer", ValueProducer(value_to_produce="Berlin"))
pipe.connect("value_producer.value_output", "prompt_builder")
messages = [
ChatMessage.from_system("You give valuable information to tourists."),
ChatMessage.from_user("Tell me about {{value_output}}"),
]
res = pipe.run(data={"prompt_source": messages})
assert res == {
"prompt_builder": {
"prompt": [
ChatMessage.from_system("You give valuable information to tourists."),
ChatMessage.from_user("Tell me about Berlin"),
]
}
}