feat: ChatPromptBuilder copies entire ChatMessage rather than copying content field only (#8317)

* Initial implementation of ChatMessage copy and deepcopy

* Add reno release note

* Satisfy hawkeye

* Remove copy and deepcopy, no need to complicate things

* Add new reno note

* Add unit test
This commit is contained in:
Vladimir Blagojevic 2024-09-02 17:06:38 +01:00 committed by GitHub
parent 9c1ad8e8ea
commit b2c19a8c7a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 6 deletions

View File

@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
from typing import Any, Dict, List, Optional, Set
from jinja2 import meta
@ -194,11 +195,9 @@ class ChatPromptBuilder:
compiled_template = self._env.from_string(message.content)
rendered_content = compiled_template.render(template_variables_combined)
rendered_message = (
ChatMessage.from_user(rendered_content)
if message.is_from(ChatRole.USER)
else ChatMessage.from_system(rendered_content)
)
# deep copy the message to avoid modifying the original message
rendered_message: ChatMessage = deepcopy(message)
rendered_message.content = rendered_content
processed_messages.append(rendered_message)
else:
processed_messages.append(message)

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Adapts how ChatPromptBuilder creates ChatMessages. Messages are deep copied to ensure all meta fields are copied correctly.

View File

@ -5,7 +5,7 @@ import pytest
from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
from haystack import component
from haystack.core.pipeline.pipeline import Pipeline
from haystack.dataclasses.chat_message import ChatMessage
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from haystack.dataclasses.document import Document
@ -194,6 +194,17 @@ class TestChatPromptBuilder:
assert builder.run(template, name="John", var1="Big") == expected_result
def test_run_with_meta(self):
"""
Test that the ChatPromptBuilder correctly handles meta data.
It should render the message and copy the meta data from the original message.
"""
m = ChatMessage(content="This is a {{ variable }}", role=ChatRole.USER, name=None, meta={"test": "test"})
builder = ChatPromptBuilder(template=[m])
res = builder.run(variable="test")
res_msg = ChatMessage(content="This is a test", role=ChatRole.USER, name=None, meta={"test": "test"})
assert res == {"prompt": [res_msg]}
def test_run_with_invalid_template(self):
builder = ChatPromptBuilder()