From dcc7e63dc994885d7655d6e320f89a3ad336147e Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 23 Oct 2023 16:08:05 +0200 Subject: [PATCH] feat: Add ChatMessage class to Haystack 2.0 (#6144) * Add ChatMessage and ChatRole --- haystack/preview/dataclasses/__init__.py | 4 +- haystack/preview/dataclasses/chat_message.py | 79 +++++++++++++++++++ .../add-chat-message-c456e4603529ae85.yaml | 5 ++ test/preview/dataclasses/test_chat_message.py | 47 +++++++++++ 4 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 haystack/preview/dataclasses/chat_message.py create mode 100644 releasenotes/notes/add-chat-message-c456e4603529ae85.yaml create mode 100644 test/preview/dataclasses/test_chat_message.py diff --git a/haystack/preview/dataclasses/__init__.py b/haystack/preview/dataclasses/__init__.py index 6873ac0cc..8cf010500 100644 --- a/haystack/preview/dataclasses/__init__.py +++ b/haystack/preview/dataclasses/__init__.py @@ -1,5 +1,7 @@ from haystack.preview.dataclasses.document import Document from haystack.preview.dataclasses.answer import ExtractedAnswer, GeneratedAnswer, Answer from haystack.preview.dataclasses.byte_stream import ByteStream +from haystack.preview.dataclasses.chat_message import ChatMessage +from haystack.preview.dataclasses.chat_message import ChatRole -__all__ = ["Document", "ExtractedAnswer", "GeneratedAnswer", "Answer", "ByteStream"] +__all__ = ["Document", "ExtractedAnswer", "GeneratedAnswer", "Answer", "ByteStream", "ChatMessage", "ChatRole"] diff --git a/haystack/preview/dataclasses/chat_message.py b/haystack/preview/dataclasses/chat_message.py new file mode 100644 index 000000000..cff420c63 --- /dev/null +++ b/haystack/preview/dataclasses/chat_message.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, Any, Optional + + +class ChatRole(str, Enum): + """Enumeration representing the roles within a chat.""" + + ASSISTANT = "assistant" + USER = "user" + SYSTEM = "system" + FUNCTION = "function" + + +@dataclass +class ChatMessage: + """ + Represents a message in a LLM chat conversation. + + :param content: The text content of the message. + :param role: The role of the entity sending the message. + :param name: The name of the function being called (only applicable for role FUNCTION). + :param metadata: Additional metadata associated with the message. + """ + + content: str + role: ChatRole + name: Optional[str] + metadata: Dict[str, Any] = field(default_factory=dict, hash=False) + + def is_from(self, role: ChatRole) -> bool: + """ + Check if the message is from a specific role. + + :param role: The role to check against. + :return: True if the message is from the specified role, False otherwise. + """ + return self.role == role + + @classmethod + def from_assistant(cls, content: str) -> "ChatMessage": + """ + Create a message from the assistant. + + :param content: The text content of the message. + :return: A new ChatMessage instance. + """ + return cls(content, ChatRole.ASSISTANT, None) + + @classmethod + def from_user(cls, content: str) -> "ChatMessage": + """ + Create a message from the user. + + :param content: The text content of the message. + :return: A new ChatMessage instance. + """ + return cls(content, ChatRole.USER, None) + + @classmethod + def from_system(cls, content: str) -> "ChatMessage": + """ + Create a message from the system. + + :param content: The text content of the message. + :return: A new ChatMessage instance. + """ + return cls(content, ChatRole.SYSTEM, None) + + @classmethod + def from_function(cls, content: str, name: str) -> "ChatMessage": + """ + Create a message from a function call. + + :param content: The text content of the message. + :param name: The name of the function being called. + :return: A new ChatMessage instance. + """ + return cls(content, ChatRole.FUNCTION, name) diff --git a/releasenotes/notes/add-chat-message-c456e4603529ae85.yaml b/releasenotes/notes/add-chat-message-c456e4603529ae85.yaml new file mode 100644 index 000000000..3d1571285 --- /dev/null +++ b/releasenotes/notes/add-chat-message-c456e4603529ae85.yaml @@ -0,0 +1,5 @@ +--- +preview: + - | + Introduce ChatMessage data class to facilitate structured handling and processing of message content + within LLM chat interactions. diff --git a/test/preview/dataclasses/test_chat_message.py b/test/preview/dataclasses/test_chat_message.py new file mode 100644 index 000000000..285d38453 --- /dev/null +++ b/test/preview/dataclasses/test_chat_message.py @@ -0,0 +1,47 @@ +import pytest + +from haystack.preview.dataclasses import ChatMessage, ChatRole + + +@pytest.mark.unit +def test_from_assistant_with_valid_content(): + content = "Hello, how can I assist you?" + message = ChatMessage.from_assistant(content) + assert message.content == content + assert message.role == ChatRole.ASSISTANT + + +@pytest.mark.unit +def test_from_user_with_valid_content(): + content = "I have a question." + message = ChatMessage.from_user(content) + assert message.content == content + assert message.role == ChatRole.USER + + +@pytest.mark.unit +def test_from_system_with_valid_content(): + content = "System message." + message = ChatMessage.from_system(content) + assert message.content == content + assert message.role == ChatRole.SYSTEM + + +@pytest.mark.unit +def test_with_empty_content(): + message = ChatMessage("", ChatRole.USER, None) + assert message.content == "" + + +@pytest.mark.unit +def test_with_invalid_role(): + with pytest.raises(TypeError): + ChatMessage("Invalid role", "invalid_role") + + +@pytest.mark.unit +def test_from_function_with_empty_name(): + content = "Function call" + message = ChatMessage.from_function(content, "") + assert message.content == content + assert message.name == ""