feat!: new ChatMessage (#8640)

* draft

* del HF token in tests

* adaptations

* progress

* fix type

* import sorting

* more control on deserialization

* release note

* improvements

* support name field

* fix chatpromptbuilder test

* Update chat_message.py

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
Stefano Fiorucci 2024-12-17 17:02:04 +01:00 committed by GitHub
parent a5b57f4b1f
commit ea3602643a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 568 additions and 165 deletions

View File

@ -9,7 +9,7 @@ from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent
logger = logging.getLogger(__name__)
@ -197,10 +197,10 @@ class ChatPromptBuilder:
if message.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")
compiled_template = self._env.from_string(message.text)
rendered_content = compiled_template.render(template_variables_combined)
rendered_text = compiled_template.render(template_variables_combined)
# deep copy the message to avoid modifying the original message
rendered_message: ChatMessage = deepcopy(message)
rendered_message.content = rendered_content
rendered_message._content = [TextContent(text=rendered_text)]
processed_messages.append(rendered_message)
else:
processed_messages.append(message)

View File

@ -25,13 +25,8 @@ def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]:
:returns: A dictionary with the following keys:
- `role`
- `content`
- `name` (optional)
"""
formatted_msg = {"role": message.role.value, "content": message.content}
if message.name:
formatted_msg["name"] = message.name
return formatted_msg
return {"role": message.role.value, "content": message.text or ""}
@component

View File

@ -13,16 +13,11 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, str]:
See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details.
:returns: A dictionary with the following key:
:returns: A dictionary with the following keys:
- `role`
- `content`
- `name` (optional)
"""
if message.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")
openai_msg = {"role": message.role.value, "content": message.text}
if message.name:
openai_msg["name"] = message.name
return openai_msg
return {"role": message.role.value, "content": message.text}

View File

@ -4,7 +4,7 @@
from haystack.dataclasses.answer import Answer, ExtractedAnswer, GeneratedAnswer
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent, ToolCall, ToolCallResult
from haystack.dataclasses.document import Document
from haystack.dataclasses.sparse_embedding import SparseEmbedding
from haystack.dataclasses.streaming_chunk import StreamingChunk
@ -17,6 +17,9 @@ __all__ = [
"ByteStream",
"ChatMessage",
"ChatRole",
"ToolCall",
"ToolCallResult",
"TextContent",
"StreamingChunk",
"SparseEmbedding",
]

View File

@ -5,104 +5,318 @@
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Union

# Init parameters of the pre-2.9 ChatMessage; rejected explicitly so legacy
# callers get a clear migration error instead of a confusing TypeError.
LEGACY_INIT_PARAMETERS = {"role", "content", "meta", "name"}


class ChatRole(str, Enum):
    """
    Enumeration representing the roles within a chat.
    """

    #: The user role. A message from the user contains only text.
    USER = "user"

    #: The system role. A message from the system contains only text.
    SYSTEM = "system"

    #: The assistant role. A message from the assistant can contain text and Tool calls. It can also store metadata.
    ASSISTANT = "assistant"

    #: The tool role. A message from a tool contains the result of a Tool invocation.
    TOOL = "tool"

    #: The function role. Deprecated in favor of `TOOL`.
    FUNCTION = "function"

    @staticmethod
    def from_str(string: str) -> "ChatRole":
        """
        Convert a string to a ChatRole enum.

        :param string: The string to convert.
        :returns: The corresponding ChatRole.
        :raises ValueError: If the string does not match any known role.
        """
        enum_map = {e.value: e for e in ChatRole}
        role = enum_map.get(string)
        if role is None:
            msg = f"Unknown chat role '{string}'. Supported roles are: {list(enum_map.keys())}"
            raise ValueError(msg)
        return role


@dataclass
class ToolCall:
    """
    Represents a Tool call prepared by the model, usually contained in an assistant message.

    :param id: The ID of the Tool call.
    :param tool_name: The name of the Tool to call.
    :param arguments: The arguments to call the Tool with.
    """

    tool_name: str
    arguments: Dict[str, Any]
    id: Optional[str] = None  # noqa: A003


@dataclass
class ToolCallResult:
    """
    Represents the result of a Tool invocation.

    :param result: The result of the Tool invocation.
    :param origin: The Tool call that produced this result.
    :param error: Whether the Tool invocation resulted in an error.
    """

    result: str
    origin: ToolCall
    error: bool


@dataclass
class TextContent:
    """
    The textual content of a chat message.

    :param text: The text content of the message.
    """

    text: str


# A message part can be plain text, a Tool call, or a Tool call result.
ChatMessageContentT = Union[TextContent, ToolCall, ToolCallResult]


@dataclass
class ChatMessage:
    """
    Represents a message in a LLM chat conversation.

    Use the `from_assistant`, `from_user`, `from_system`, and `from_tool` class methods to create a ChatMessage.
    """

    _role: ChatRole
    _content: Sequence[ChatMessageContentT]
    _name: Optional[str] = None
    _meta: Dict[str, Any] = field(default_factory=dict, hash=False)

    def __new__(cls, *args, **kwargs):
        """
        This method is reimplemented to make the changes to the `ChatMessage` dataclass more visible.
        """
        general_msg = (
            "Use the `from_assistant`, `from_user`, `from_system`, and `from_tool` class methods to create a "
            "ChatMessage. For more information about the new API and how to migrate, see the documentation:"
            " https://docs.haystack.deepset.ai/docs/data-classes#chatmessage"
        )

        # Reject the pre-2.9 keyword parameters with a migration hint.
        if any(param in kwargs for param in LEGACY_INIT_PARAMETERS):
            raise TypeError(
                "The `role`, `content`, `meta`, and `name` init parameters of `ChatMessage` have been removed. "
                f"{general_msg}"
            )

        # Positional second argument must be a single content part; legacy code
        # passed a plain string here, which is caught and rejected.
        allowed_content_types = (TextContent, ToolCall, ToolCallResult)
        if len(args) > 1 and not isinstance(args[1], allowed_content_types):
            raise TypeError(
                "The `_content` parameter of `ChatMessage` must be one of the following types: "
                f"{', '.join(t.__name__ for t in allowed_content_types)}. "
                f"{general_msg}"
            )

        return super(ChatMessage, cls).__new__(cls)

    def __post_init__(self):
        # The FUNCTION role still works but is deprecated in favor of TOOL.
        if self._role == ChatRole.FUNCTION:
            msg = "The `FUNCTION` role has been deprecated in favor of `TOOL` and will be removed in 2.10.0. "
            warnings.warn(msg, DeprecationWarning)

    def __getattribute__(self, name):
        """
        This method is reimplemented to make the `content` attribute removal more visible.
        """
        if name == "content":
            msg = (
                "The `content` attribute of `ChatMessage` has been removed. "
                "Use the `text` property to access the textual value. "
                "For more information about the new API and how to migrate, see the documentation: "
                "https://docs.haystack.deepset.ai/docs/data-classes#chatmessage"
            )
            raise AttributeError(msg)
        return object.__getattribute__(self, name)

    def __len__(self):
        return len(self._content)

    @property
    def role(self) -> ChatRole:
        """
        Returns the role of the entity sending the message.
        """
        return self._role

    @property
    def meta(self) -> Dict[str, Any]:
        """
        Returns the metadata associated with the message.
        """
        return self._meta

    @property
    def name(self) -> Optional[str]:
        """
        Returns the name associated with the message.
        """
        return self._name

    @property
    def texts(self) -> List[str]:
        """
        Returns the list of all texts contained in the message.
        """
        return [content.text for content in self._content if isinstance(content, TextContent)]

    @property
    def text(self) -> Optional[str]:
        """
        Returns the first text contained in the message.
        """
        if texts := self.texts:
            return texts[0]
        return None

    @property
    def tool_calls(self) -> List[ToolCall]:
        """
        Returns the list of all Tool calls contained in the message.
        """
        return [content for content in self._content if isinstance(content, ToolCall)]

    @property
    def tool_call(self) -> Optional[ToolCall]:
        """
        Returns the first Tool call contained in the message.
        """
        if tool_calls := self.tool_calls:
            return tool_calls[0]
        return None

    @property
    def tool_call_results(self) -> List[ToolCallResult]:
        """
        Returns the list of all Tool call results contained in the message.
        """
        return [content for content in self._content if isinstance(content, ToolCallResult)]

    @property
    def tool_call_result(self) -> Optional[ToolCallResult]:
        """
        Returns the first Tool call result contained in the message.
        """
        if tool_call_results := self.tool_call_results:
            return tool_call_results[0]
        return None

    def is_from(self, role: Union[ChatRole, str]) -> bool:
        """
        Check if the message is from a specific role.

        :param role: The role to check against.
        :returns: True if the message is from the specified role, False otherwise.
        """
        if isinstance(role, str):
            role = ChatRole.from_str(role)
        return self._role == role

    @classmethod
    def from_user(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage":
        """
        Create a message from the user.

        :param text: The text content of the message.
        :param meta: Additional metadata associated with the message.
        :param name: An optional name for the participant. This field is only supported by OpenAI.
        :returns: A new ChatMessage instance.
        """
        return cls(_role=ChatRole.USER, _content=[TextContent(text=text)], _meta=meta or {}, _name=name)

    @classmethod
    def from_system(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage":
        """
        Create a message from the system.

        :param text: The text content of the message.
        :param meta: Additional metadata associated with the message.
        :param name: An optional name for the participant. This field is only supported by OpenAI.
        :returns: A new ChatMessage instance.
        """
        return cls(_role=ChatRole.SYSTEM, _content=[TextContent(text=text)], _meta=meta or {}, _name=name)

    @classmethod
    def from_assistant(
        cls,
        text: Optional[str] = None,
        meta: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        tool_calls: Optional[List[ToolCall]] = None,
    ) -> "ChatMessage":
        """
        Create a message from the assistant.

        :param text: The text content of the message.
        :param meta: Additional metadata associated with the message.
        :param tool_calls: The Tool calls to include in the message.
        :param name: An optional name for the participant. This field is only supported by OpenAI.
        :returns: A new ChatMessage instance.
        """
        content: List[ChatMessageContentT] = []
        if text is not None:
            content.append(TextContent(text=text))
        if tool_calls:
            content.extend(tool_calls)

        return cls(_role=ChatRole.ASSISTANT, _content=content, _meta=meta or {}, _name=name)

    @classmethod
    def from_tool(
        cls, tool_result: str, origin: ToolCall, error: bool = False, meta: Optional[Dict[str, Any]] = None
    ) -> "ChatMessage":
        """
        Create a message from a Tool.

        :param tool_result: The result of the Tool invocation.
        :param origin: The Tool call that produced this result.
        :param error: Whether the Tool invocation resulted in an error.
        :param meta: Additional metadata associated with the message.
        :returns: A new ChatMessage instance.
        """
        return cls(
            _role=ChatRole.TOOL,
            _content=[ToolCallResult(result=tool_result, origin=origin, error=error)],
            _meta=meta or {},
        )

    @classmethod
    def from_function(cls, content: str, name: str) -> "ChatMessage":
        """
        Create a message from a function call. Deprecated in favor of `from_tool`.

        :param content: The text content of the message.
        :param name: The name of the function being called.
        :returns: A new ChatMessage instance.
        """
        msg = (
            "The `from_function` method is deprecated and will be removed in version 2.10.0. "
            "Its behavior has changed: it now attempts to convert legacy function messages to tool messages. "
            "This conversion is not guaranteed to succeed in all scenarios. "
            "Please migrate to `ChatMessage.from_tool` and carefully verify the results if you "
            "continue to use this method."
        )
        warnings.warn(msg)
        # Best-effort conversion: the original call's arguments are unknown here.
        return cls.from_tool(content, ToolCall(id=None, tool_name=name, arguments={}), error=False)

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts ChatMessage into a dictionary.

        :returns:
            Serialized version of the object.
        """
        serialized: Dict[str, Any] = {}
        serialized["_role"] = self._role.value
        serialized["_meta"] = self._meta
        serialized["_name"] = self._name
        content: List[Dict[str, Any]] = []
        # Each content part is serialized under a key identifying its type, so
        # from_dict can reconstruct the right dataclass.
        for part in self._content:
            if isinstance(part, TextContent):
                content.append({"text": part.text})
            elif isinstance(part, ToolCall):
                content.append({"tool_call": asdict(part)})
            elif isinstance(part, ToolCallResult):
                content.append({"tool_call_result": asdict(part)})
            else:
                raise TypeError(f"Unsupported type in ChatMessage content: `{type(part).__name__}` for `{part}`.")

        serialized["_content"] = content
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage":
        """
        Creates a new ChatMessage object from a dictionary.

        :param data:
            The dictionary to build the ChatMessage object.
        :returns:
            The created object.
        :raises TypeError: If the dictionary uses the removed legacy parameters.
        :raises ValueError: If the dictionary contains an unsupported content part.
        """
        if any(param in data for param in LEGACY_INIT_PARAMETERS):
            raise TypeError(
                "The `role`, `content`, `meta`, and `name` init parameters of `ChatMessage` have been removed. "
                "For more information about the new API and how to migrate, see the documentation: "
                "https://docs.haystack.deepset.ai/docs/data-classes#chatmessage"
            )

        data["_role"] = ChatRole(data["_role"])

        content: List[ChatMessageContentT] = []
        for part in data["_content"]:
            if "text" in part:
                content.append(TextContent(text=part["text"]))
            elif "tool_call" in part:
                content.append(ToolCall(**part["tool_call"]))
            elif "tool_call_result" in part:
                result = part["tool_call_result"]["result"]
                origin = ToolCall(**part["tool_call_result"]["origin"])
                error = part["tool_call_result"]["error"]
                tcr = ToolCallResult(result=result, origin=origin, error=error)
                content.append(tcr)
            else:
                raise ValueError(f"Unsupported content in serialized ChatMessage: `{part}`")

        data["_content"] = content
        return cls(**data)

View File

@ -0,0 +1,23 @@
---
highlights: >
We are introducing a refactored ChatMessage dataclass. It is more flexible, future-proof, and compatible with
different types of content: text, tool calls, and tool call results.
For information about the new API and how to migrate, see the documentation:
https://docs.haystack.deepset.ai/docs/data-classes#chatmessage
upgrade:
- |
The refactoring of the ChatMessage dataclass includes some breaking changes, involving ChatMessage creation and
accessing attributes. If you have a Pipeline containing a ChatPromptBuilder, serialized using Haystack<2.9.0,
deserialization may break.
For detailed information about the changes and how to migrate, see the documentation:
https://docs.haystack.deepset.ai/docs/data-classes#chatmessage
features:
- |
Changed the ChatMessage dataclass to support different types of content, including tool calls, and tool call
results.
deprecations:
- |
The function role and the ChatMessage.from_function class method have been deprecated and will be removed in
Haystack 2.10.0. ChatMessage.from_function now attempts to convert legacy function messages to tool messages.
For more information, see the documentation:
https://docs.haystack.deepset.ai/docs/data-classes#chatmessage

View File

@ -13,8 +13,8 @@ class TestChatPromptBuilder:
def test_init(self):
builder = ChatPromptBuilder(
template=[
ChatMessage.from_user(content="This is a {{ variable }}"),
ChatMessage.from_system(content="This is a {{ variable2 }}"),
ChatMessage.from_user("This is a {{ variable }}"),
ChatMessage.from_system("This is a {{ variable2 }}"),
]
)
assert builder.required_variables == []
@ -531,8 +531,13 @@ class TestChatPromptBuilderDynamic:
"type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder",
"init_parameters": {
"template": [
{"content": "text and {var}", "role": "user", "name": None, "meta": {}},
{"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}},
{"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}, "_name": None},
{
"_content": [{"text": "content {required_var}"}],
"_role": "assistant",
"_meta": {},
"_name": None,
},
],
"variables": ["var", "required_var"],
"required_variables": ["required_var"],
@ -545,8 +550,13 @@ class TestChatPromptBuilderDynamic:
"type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder",
"init_parameters": {
"template": [
{"content": "text and {var}", "role": "user", "name": None, "meta": {}},
{"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}},
{"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}, "_name": None},
{
"_content": [{"text": "content {required_var}"}],
"_role": "assistant",
"_meta": {},
"_name": None,
},
],
"variables": ["var", "required_var"],
"required_variables": ["required_var"],

View File

@ -68,13 +68,6 @@ def test_convert_message_to_hfapi_format():
message = ChatMessage.from_user("I have a question")
assert _convert_message_to_hfapi_format(message) == {"role": "user", "content": "I have a question"}
message = ChatMessage.from_function("Function call", "function_name")
assert _convert_message_to_hfapi_format(message) == {
"role": "function",
"content": "Function call",
"name": "function_name",
}
class TestHuggingFaceAPIGenerator:
def test_init_invalid_api_type(self):

View File

@ -14,10 +14,3 @@ def test_convert_message_to_openai_format():
message = ChatMessage.from_user("I have a question")
assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"}
message = ChatMessage.from_function("Function call", "function_name")
assert _convert_message_to_openai_format(message) == {
"role": "function",
"content": "Function call",
"name": "function_name",
}

View File

@ -349,7 +349,7 @@ class TestRouter:
]
router = ConditionalRouter(routes, unsafe=True)
streams = [1]
message = ChatMessage.from_user(content="This is a message")
message = ChatMessage.from_user("This is a message")
res = router.run(streams=streams, message=message)
assert res == {"message": message}
@ -370,7 +370,7 @@ class TestRouter:
]
router = ConditionalRouter(routes, validate_output_type=True)
streams = [1]
message = ChatMessage.from_user(content="This is a message")
message = ChatMessage.from_user("This is a message")
with pytest.raises(ValueError, match="Route 'message' type doesn't match expected type"):
router.run(streams=streams, message=message)
@ -391,7 +391,7 @@ class TestRouter:
]
router = ConditionalRouter(routes, unsafe=True, validate_output_type=True)
streams = [1]
message = ChatMessage.from_user(content="This is a message")
message = ChatMessage.from_user("This is a message")
res = router.run(streams=streams, message=message)
assert isinstance(res["message"], ChatMessage)

View File

@ -1657,7 +1657,7 @@ def that_is_a_simple_agent():
class ToolExtractor:
@component.output_types(output=List[str])
def run(self, messages: List[ChatMessage]):
prompt: str = messages[-1].content
prompt: str = messages[-1].text
lines = prompt.strip().split("\n")
for line in reversed(lines):
pattern = r"Action:\s*(\w+)\[(.*?)\]"
@ -1678,14 +1678,14 @@ def that_is_a_simple_agent():
@component.output_types(output=List[ChatMessage])
def run(self, replies: List[ChatMessage], current_prompt: List[ChatMessage]):
content = current_prompt[-1].content + replies[-1].content + self._suffix
content = current_prompt[-1].text + replies[-1].text + self._suffix
return {"output": [ChatMessage.from_user(content)]}
@component
class SearchOutputAdapter:
@component.output_types(output=List[ChatMessage])
def run(self, replies: List[ChatMessage]):
content = f"Observation: {replies[-1].content}\n"
content = f"Observation: {replies[-1].text}\n"
return {"output": [ChatMessage.from_assistant(content)]}
pipeline.add_component("prompt_concatenator_after_action", PromptConcatenator())

View File

@ -4,64 +4,240 @@
import pytest
from transformers import AutoTokenizer
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent
from haystack.components.generators.openai_utils import _convert_message_to_openai_format
def test_tool_call_init():
tc = ToolCall(id="123", tool_name="mytool", arguments={"a": 1})
assert tc.id == "123"
assert tc.tool_name == "mytool"
assert tc.arguments == {"a": 1}
def test_tool_call_result_init():
tcr = ToolCallResult(result="result", origin=ToolCall(id="123", tool_name="mytool", arguments={"a": 1}), error=True)
assert tcr.result == "result"
assert tcr.origin == ToolCall(id="123", tool_name="mytool", arguments={"a": 1})
assert tcr.error
def test_text_content_init():
tc = TextContent(text="Hello")
assert tc.text == "Hello"
def test_from_assistant_with_valid_content():
content = "Hello, how can I assist you?"
message = ChatMessage.from_assistant(content)
assert message.content == content
assert message.text == content
text = "Hello, how can I assist you?"
message = ChatMessage.from_assistant(text)
assert message.role == ChatRole.ASSISTANT
assert message._content == [TextContent(text)]
assert message.name is None
assert message.text == text
assert message.texts == [text]
assert not message.tool_calls
assert not message.tool_call
assert not message.tool_call_results
assert not message.tool_call_result
def test_from_assistant_with_tool_calls():
tool_calls = [
ToolCall(id="123", tool_name="mytool", arguments={"a": 1}),
ToolCall(id="456", tool_name="mytool2", arguments={"b": 2}),
]
message = ChatMessage.from_assistant(tool_calls=tool_calls)
assert message.role == ChatRole.ASSISTANT
assert message._content == tool_calls
assert message.tool_calls == tool_calls
assert message.tool_call == tool_calls[0]
assert not message.texts
assert not message.text
assert not message.tool_call_results
assert not message.tool_call_result
def test_from_user_with_valid_content():
content = "I have a question."
message = ChatMessage.from_user(content)
assert message.content == content
assert message.text == content
text = "I have a question."
message = ChatMessage.from_user(text=text)
assert message.role == ChatRole.USER
assert message._content == [TextContent(text)]
assert message.name is None
assert message.text == text
assert message.texts == [text]
assert not message.tool_calls
assert not message.tool_call
assert not message.tool_call_results
assert not message.tool_call_result
def test_from_user_with_name():
text = "I have a question."
message = ChatMessage.from_user(text=text, name="John")
assert message.name == "John"
assert message.role == ChatRole.USER
assert message._content == [TextContent(text)]
def test_from_system_with_valid_content():
content = "System message."
message = ChatMessage.from_system(content)
assert message.content == content
assert message.text == content
text = "I have a question."
message = ChatMessage.from_system(text=text)
assert message.role == ChatRole.SYSTEM
assert message._content == [TextContent(text)]
assert message.text == text
assert message.texts == [text]
assert not message.tool_calls
assert not message.tool_call
assert not message.tool_call_results
assert not message.tool_call_result
def test_with_empty_content():
message = ChatMessage.from_user("")
assert message.content == ""
assert message.text == ""
assert message.role == ChatRole.USER
def test_from_tool_with_valid_content():
tool_result = "Tool result"
origin = ToolCall(id="123", tool_name="mytool", arguments={"a": 1})
message = ChatMessage.from_tool(tool_result, origin, error=False)
tcr = ToolCallResult(result=tool_result, origin=origin, error=False)
assert message._content == [tcr]
assert message.role == ChatRole.TOOL
assert message.tool_call_result == tcr
assert message.tool_call_results == [tcr]
assert not message.tool_calls
assert not message.tool_call
assert not message.texts
assert not message.text
def test_from_function_with_empty_name():
content = "Function call"
message = ChatMessage.from_function(content, "")
assert message.content == content
assert message.text == content
assert message.name == ""
assert message.role == ChatRole.FUNCTION
def test_multiple_text_segments():
texts = [TextContent(text="Hello"), TextContent(text="World")]
message = ChatMessage(_role=ChatRole.USER, _content=texts)
assert message.texts == ["Hello", "World"]
assert len(message) == 2
def test_to_openai_format():
message = ChatMessage.from_system("You are good assistant")
assert _convert_message_to_openai_format(message) == {"role": "system", "content": "You are good assistant"}
def test_mixed_content():
content = [TextContent(text="Hello"), ToolCall(id="123", tool_name="mytool", arguments={"a": 1})]
message = ChatMessage.from_user("I have a question")
assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"}
message = ChatMessage(_role=ChatRole.ASSISTANT, _content=content)
message = ChatMessage.from_function("Function call", "function_name")
assert _convert_message_to_openai_format(message) == {
"role": "function",
"content": "Function call",
"name": "function_name",
assert len(message) == 2
assert message.texts == ["Hello"]
assert message.text == "Hello"
assert message.tool_calls == [content[1]]
assert message.tool_call == content[1]
def test_from_function():
# check warning is raised
with pytest.warns():
message = ChatMessage.from_function("Result of function invocation", "my_function")
assert message.role == ChatRole.TOOL
assert message.tool_call_result == ToolCallResult(
result="Result of function invocation",
origin=ToolCall(id=None, tool_name="my_function", arguments={}),
error=False,
)
def test_serde():
# the following message is created just for testing purposes and does not make sense in a real use case
role = ChatRole.ASSISTANT
text_content = TextContent(text="Hello")
tool_call = ToolCall(id="123", tool_name="mytool", arguments={"a": 1})
tool_call_result = ToolCallResult(result="result", origin=tool_call, error=False)
meta = {"some": "info"}
message = ChatMessage(_role=role, _content=[text_content, tool_call, tool_call_result], _meta=meta)
serialized_message = message.to_dict()
assert serialized_message == {
"_content": [
{"text": "Hello"},
{"tool_call": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}}},
{
"tool_call_result": {
"result": "result",
"error": False,
"origin": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}},
}
},
],
"_role": "assistant",
"_name": None,
"_meta": {"some": "info"},
}
deserialized_message = ChatMessage.from_dict(serialized_message)
assert deserialized_message == message
def test_to_dict_with_invalid_content_type():
text_content = TextContent(text="Hello")
invalid_content = "invalid"
message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[text_content, invalid_content])
with pytest.raises(TypeError):
message.to_dict()
def test_from_dict_with_invalid_content_type():
data = {"_role": "assistant", "_content": [{"text": "Hello"}, "invalid"]}
with pytest.raises(ValueError):
ChatMessage.from_dict(data)
data = {"_role": "assistant", "_content": [{"text": "Hello"}, {"invalid": "invalid"}]}
with pytest.raises(ValueError):
ChatMessage.from_dict(data)
def test_from_dict_with_legacy_init_parameters():
with pytest.raises(TypeError):
ChatMessage.from_dict({"role": "user", "content": "This is a message"})
def test_chat_message_content_attribute_removed():
message = ChatMessage.from_user(text="This is a message")
with pytest.raises(AttributeError):
message.content
def test_chat_message_init_parameters_removed():
with pytest.raises(TypeError):
ChatMessage(role="irrelevant", content="This is a message")
def test_chat_message_init_content_parameter_type():
with pytest.raises(TypeError):
ChatMessage(ChatRole.USER, "This is a message")
def test_chat_message_function_role_deprecated():
with pytest.warns(DeprecationWarning):
ChatMessage(ChatRole.FUNCTION, TextContent("This is a message"))
@pytest.mark.integration
def test_apply_chat_templating_on_chat_message():
@ -93,40 +269,3 @@ def test_apply_custom_chat_templating_on_chat_message():
formatted_messages, chat_template=anthropic_template, tokenize=False
)
assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:"
def test_to_dict():
content = "content"
role = "user"
meta = {"some": "some"}
message = ChatMessage.from_user(content)
message.meta.update(meta)
assert message.text == content
assert message.to_dict() == {"content": content, "role": role, "name": None, "meta": meta}
def test_from_dict():
assert ChatMessage.from_dict(data={"content": "text", "role": "user", "name": None}) == ChatMessage.from_user(
"text"
)
def test_from_dict_with_meta():
data = {"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}}
assert ChatMessage.from_dict(data) == ChatMessage.from_assistant("text", meta={"something": "something"})
def test_content_deprecation_warning(recwarn):
message = ChatMessage.from_user("my message")
# accessing the content attribute triggers the deprecation warning
_ = message.content
assert len(recwarn) == 1
wrn = recwarn.pop(DeprecationWarning)
assert "`content` attribute" in wrn.message.args[0]
# accessing the text property does not trigger a warning
assert message.text == "my message"
assert len(recwarn) == 0