diff --git a/haystack/components/builders/chat_prompt_builder.py b/haystack/components/builders/chat_prompt_builder.py index 44c857b40..a1c47a508 100644 --- a/haystack/components/builders/chat_prompt_builder.py +++ b/haystack/components/builders/chat_prompt_builder.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Set from jinja2 import Template, meta -from haystack import component, logging +from haystack import component, default_from_dict, default_to_dict, logging from haystack.dataclasses.chat_message import ChatMessage, ChatRole logger = logging.getLogger(__name__) @@ -223,3 +223,36 @@ class ChatPromptBuilder: f"Missing required input variables in ChatPromptBuilder: {missing_vars_str}. " f"Required variables: {self.required_variables}. Provided variables: {provided_variables}." ) + + def to_dict(self) -> Dict[str, Any]: + """ + Returns a dictionary representation of the component. + + :returns: + Serialized dictionary representation of the component. + """ + if self.template is not None: + template = [m.to_dict() for m in self.template] + else: + template = None + + return default_to_dict( + self, template=template, variables=self._variables, required_variables=self._required_variables + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ChatPromptBuilder": + """ + Deserialize this component from a dictionary. + + :param data: + The dictionary to deserialize and create the component. + + :returns: + The deserialized component. + """ + init_parameters = data["init_parameters"] + template = init_parameters.get("template", []) + init_parameters["template"] = [ChatMessage.from_dict(d) for d in template] + + return default_from_dict(cls, data) diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index 336f5a654..cc41d5f4c 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from enum import Enum from typing import Any, Dict, Optional @@ -99,3 +99,29 @@ class ChatMessage: :returns: A new ChatMessage instance. """ return cls(content, ChatRole.FUNCTION, name) + + def to_dict(self) -> Dict[str, Any]: + """ + Converts ChatMessage into a dictionary. + + :returns: + Serialized version of the object. + """ + data = asdict(self) + data["role"] = self.role.value + + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": + """ + Creates a new ChatMessage object from a dictionary. + + :param data: + The dictionary to build the ChatMessage object. + :returns: + The created object. + """ + data["role"] = ChatRole(data["role"]) + + return cls(**data) diff --git a/releasenotes/notes/fix-chat-prompt-builder-serialization-345605337ca2584e.yaml b/releasenotes/notes/fix-chat-prompt-builder-serialization-345605337ca2584e.yaml new file mode 100644 index 000000000..b9df38134 --- /dev/null +++ b/releasenotes/notes/fix-chat-prompt-builder-serialization-345605337ca2584e.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Solve serialization bug on 'ChatPromptBuilder' by creating 'to_dict' and 'from_dict' methods on 'ChatMessage' and 'ChatPromptBuilder'. diff --git a/test/components/builders/test_chat_prompt_builder.py b/test/components/builders/test_chat_prompt_builder.py index 406e40b73..8bc46edf1 100644 --- a/test/components/builders/test_chat_prompt_builder.py +++ b/test/components/builders/test_chat_prompt_builder.py @@ -494,3 +494,44 @@ class TestChatPromptBuilderDynamic: ] } } + + def test_to_dict(self): + component = ChatPromptBuilder( + template=[ChatMessage.from_user("text and {var}"), ChatMessage.from_assistant("content {required_var}")], + variables=["var", "required_var"], + required_variables=["required_var"], + ) + + assert component.to_dict() == { + "type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder", + "init_parameters": { + "template": [ + {"content": "text and {var}", "role": "user", "name": None, "meta": {}}, + {"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}}, + ], + "variables": ["var", "required_var"], + "required_variables": ["required_var"], + }, + } + + def test_from_dict(self): + component = ChatPromptBuilder.from_dict( + data={ + "type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder", + "init_parameters": { + "template": [ + {"content": "text and {var}", "role": "user", "name": None, "meta": {}}, + {"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}}, + ], + "variables": ["var", "required_var"], + "required_variables": ["required_var"], + }, + } + ) + + assert component.template == [ + ChatMessage.from_user("text and {var}"), + ChatMessage.from_assistant("content {required_var}"), + ] + assert component._variables == ["var", "required_var"] + assert component._required_variables == ["required_var"] diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index ba3b40533..ab4606d7b 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -81,3 +81,22 @@ def test_apply_custom_chat_templating_on_chat_message(): formatted_messages, chat_template=anthropic_template, tokenize=False ) assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:" + + +def test_to_dict(): + message = ChatMessage.from_user("content") + message.meta["some"] = "some" + + assert message.to_dict() == {"content": "content", "role": "user", "name": None, "meta": {"some": "some"}} + + +def test_from_dict(): + assert ChatMessage.from_dict(data={"content": "text", "role": "user", "name": None}) == ChatMessage( + content="text", role=ChatRole("user"), name=None, meta={} + ) + + +def test_from_dict_with_meta(): + assert ChatMessage.from_dict( + data={"content": "text", "role": "user", "name": None, "meta": {"something": "something"}} + ) == ChatMessage(content="text", role=ChatRole("user"), name=None, meta={"something": "something"})