mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-29 08:26:19 +00:00
feat!: new ChatMessage (#8640)
* draft * del HF token in tests * adaptations * progress * fix type * import sorting * more control on deserialization * release note * improvements * support name field * fix chatpromptbuilder test * Update chat_message.py --------- Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
parent
a5b57f4b1f
commit
ea3602643a
@ -9,7 +9,7 @@ from jinja2 import meta
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
|
||||
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -197,10 +197,10 @@ class ChatPromptBuilder:
|
||||
if message.text is None:
|
||||
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")
|
||||
compiled_template = self._env.from_string(message.text)
|
||||
rendered_content = compiled_template.render(template_variables_combined)
|
||||
rendered_text = compiled_template.render(template_variables_combined)
|
||||
# deep copy the message to avoid modifying the original message
|
||||
rendered_message: ChatMessage = deepcopy(message)
|
||||
rendered_message.content = rendered_content
|
||||
rendered_message._content = [TextContent(text=rendered_text)]
|
||||
processed_messages.append(rendered_message)
|
||||
else:
|
||||
processed_messages.append(message)
|
||||
|
||||
@ -25,13 +25,8 @@ def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]:
|
||||
:returns: A dictionary with the following keys:
|
||||
- `role`
|
||||
- `content`
|
||||
- `name` (optional)
|
||||
"""
|
||||
formatted_msg = {"role": message.role.value, "content": message.content}
|
||||
if message.name:
|
||||
formatted_msg["name"] = message.name
|
||||
|
||||
return formatted_msg
|
||||
return {"role": message.role.value, "content": message.text or ""}
|
||||
|
||||
|
||||
@component
|
||||
|
||||
@ -13,16 +13,11 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, str]:
|
||||
|
||||
See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details.
|
||||
|
||||
:returns: A dictionary with the following key:
|
||||
:returns: A dictionary with the following keys:
|
||||
- `role`
|
||||
- `content`
|
||||
- `name` (optional)
|
||||
"""
|
||||
if message.text is None:
|
||||
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")
|
||||
|
||||
openai_msg = {"role": message.role.value, "content": message.text}
|
||||
if message.name:
|
||||
openai_msg["name"] = message.name
|
||||
|
||||
return openai_msg
|
||||
return {"role": message.role.value, "content": message.text}
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
from haystack.dataclasses.answer import Answer, ExtractedAnswer, GeneratedAnswer
|
||||
from haystack.dataclasses.byte_stream import ByteStream
|
||||
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
|
||||
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent, ToolCall, ToolCallResult
|
||||
from haystack.dataclasses.document import Document
|
||||
from haystack.dataclasses.sparse_embedding import SparseEmbedding
|
||||
from haystack.dataclasses.streaming_chunk import StreamingChunk
|
||||
@ -17,6 +17,9 @@ __all__ = [
|
||||
"ByteStream",
|
||||
"ChatMessage",
|
||||
"ChatRole",
|
||||
"ToolCall",
|
||||
"ToolCallResult",
|
||||
"TextContent",
|
||||
"StreamingChunk",
|
||||
"SparseEmbedding",
|
||||
]
|
||||
|
||||
@ -5,104 +5,318 @@
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
LEGACY_INIT_PARAMETERS = {"role", "content", "meta", "name"}
|
||||
|
||||
|
||||
class ChatRole(str, Enum):
|
||||
"""Enumeration representing the roles within a chat."""
|
||||
"""
|
||||
Enumeration representing the roles within a chat.
|
||||
"""
|
||||
|
||||
ASSISTANT = "assistant"
|
||||
#: The user role. A message from the user contains only text.
|
||||
USER = "user"
|
||||
|
||||
#: The system role. A message from the system contains only text.
|
||||
SYSTEM = "system"
|
||||
|
||||
#: The assistant role. A message from the assistant can contain text and Tool calls. It can also store metadata.
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
#: The tool role. A message from a tool contains the result of a Tool invocation.
|
||||
TOOL = "tool"
|
||||
|
||||
#: The function role. Deprecated in favor of `TOOL`.
|
||||
FUNCTION = "function"
|
||||
|
||||
@staticmethod
|
||||
def from_str(string: str) -> "ChatRole":
|
||||
"""
|
||||
Convert a string to a ChatRole enum.
|
||||
"""
|
||||
enum_map = {e.value: e for e in ChatRole}
|
||||
role = enum_map.get(string)
|
||||
if role is None:
|
||||
msg = f"Unknown chat role '{string}'. Supported roles are: {list(enum_map.keys())}"
|
||||
raise ValueError(msg)
|
||||
return role
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""
|
||||
Represents a Tool call prepared by the model, usually contained in an assistant message.
|
||||
|
||||
:param id: The ID of the Tool call.
|
||||
:param tool_name: The name of the Tool to call.
|
||||
:param arguments: The arguments to call the Tool with.
|
||||
"""
|
||||
|
||||
tool_name: str
|
||||
arguments: Dict[str, Any]
|
||||
id: Optional[str] = None # noqa: A003
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallResult:
|
||||
"""
|
||||
Represents the result of a Tool invocation.
|
||||
|
||||
:param result: The result of the Tool invocation.
|
||||
:param origin: The Tool call that produced this result.
|
||||
:param error: Whether the Tool invocation resulted in an error.
|
||||
"""
|
||||
|
||||
result: str
|
||||
origin: ToolCall
|
||||
error: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextContent:
|
||||
"""
|
||||
The textual content of a chat message.
|
||||
|
||||
:param text: The text content of the message.
|
||||
"""
|
||||
|
||||
text: str
|
||||
|
||||
|
||||
ChatMessageContentT = Union[TextContent, ToolCall, ToolCallResult]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""
|
||||
Represents a message in a LLM chat conversation.
|
||||
|
||||
:param content: The text content of the message.
|
||||
:param role: The role of the entity sending the message.
|
||||
:param name: The name of the function being called (only applicable for role FUNCTION).
|
||||
:param meta: Additional metadata associated with the message.
|
||||
Use the `from_assistant`, `from_user`, `from_system`, and `from_tool` class methods to create a ChatMessage.
|
||||
"""
|
||||
|
||||
content: str
|
||||
role: ChatRole
|
||||
name: Optional[str]
|
||||
meta: Dict[str, Any] = field(default_factory=dict, hash=False)
|
||||
_role: ChatRole
|
||||
_content: Sequence[ChatMessageContentT]
|
||||
_name: Optional[str] = None
|
||||
_meta: Dict[str, Any] = field(default_factory=dict, hash=False)
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""
|
||||
This method is reimplemented to make the changes to the `ChatMessage` dataclass more visible.
|
||||
"""
|
||||
|
||||
general_msg = (
|
||||
"Use the `from_assistant`, `from_user`, `from_system`, and `from_tool` class methods to create a "
|
||||
"ChatMessage. For more information about the new API and how to migrate, see the documentation:"
|
||||
" https://docs.haystack.deepset.ai/docs/data-classes#chatmessage"
|
||||
)
|
||||
|
||||
if any(param in kwargs for param in LEGACY_INIT_PARAMETERS):
|
||||
raise TypeError(
|
||||
"The `role`, `content`, `meta`, and `name` init parameters of `ChatMessage` have been removed. "
|
||||
f"{general_msg}"
|
||||
)
|
||||
|
||||
allowed_content_types = (TextContent, ToolCall, ToolCallResult)
|
||||
if len(args) > 1 and not isinstance(args[1], allowed_content_types):
|
||||
raise TypeError(
|
||||
"The `_content` parameter of `ChatMessage` must be one of the following types: "
|
||||
f"{', '.join(t.__name__ for t in allowed_content_types)}. "
|
||||
f"{general_msg}"
|
||||
)
|
||||
|
||||
return super(ChatMessage, cls).__new__(cls)
|
||||
|
||||
def __post_init__(self):
|
||||
if self._role == ChatRole.FUNCTION:
|
||||
msg = "The `FUNCTION` role has been deprecated in favor of `TOOL` and will be removed in 2.10.0. "
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
|
||||
def __getattribute__(self, name):
|
||||
"""
|
||||
This method is reimplemented to make the `content` attribute removal more visible.
|
||||
"""
|
||||
|
||||
if name == "content":
|
||||
msg = (
|
||||
"The `content` attribute of `ChatMessage` has been removed. "
|
||||
"Use the `text` property to access the textual value. "
|
||||
"For more information about the new API and how to migrate, see the documentation: "
|
||||
"https://docs.haystack.deepset.ai/docs/data-classes#chatmessage"
|
||||
)
|
||||
raise AttributeError(msg)
|
||||
return object.__getattribute__(self, name)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._content)
|
||||
|
||||
@property
|
||||
def role(self) -> ChatRole:
|
||||
"""
|
||||
Returns the role of the entity sending the message.
|
||||
"""
|
||||
return self._role
|
||||
|
||||
@property
|
||||
def meta(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns the metadata associated with the message.
|
||||
"""
|
||||
return self._meta
|
||||
|
||||
@property
|
||||
def name(self) -> Optional[str]:
|
||||
"""
|
||||
Returns the name associated with the message.
|
||||
"""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def texts(self) -> List[str]:
|
||||
"""
|
||||
Returns the list of all texts contained in the message.
|
||||
"""
|
||||
return [content.text for content in self._content if isinstance(content, TextContent)]
|
||||
|
||||
@property
|
||||
def text(self) -> Optional[str]:
|
||||
"""
|
||||
Returns the textual content of the message.
|
||||
Returns the first text contained in the message.
|
||||
"""
|
||||
# Currently, this property mirrors the `content` attribute. This will change in 2.9.0.
|
||||
# The current actual return type is str. We are using Optional[str] to be ready for 2.9.0,
|
||||
# when None will be a valid value for `text`.
|
||||
return object.__getattribute__(self, "content")
|
||||
if texts := self.texts:
|
||||
return texts[0]
|
||||
return None
|
||||
|
||||
def __getattribute__(self, name):
|
||||
# this method is reimplemented to warn about the deprecation of the `content` attribute
|
||||
if name == "content":
|
||||
msg = (
|
||||
"The `content` attribute of `ChatMessage` will be removed in Haystack 2.9.0. "
|
||||
"Use the `text` property to access the textual value."
|
||||
)
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return object.__getattribute__(self, name)
|
||||
@property
|
||||
def tool_calls(self) -> List[ToolCall]:
|
||||
"""
|
||||
Returns the list of all Tool calls contained in the message.
|
||||
"""
|
||||
return [content for content in self._content if isinstance(content, ToolCall)]
|
||||
|
||||
def is_from(self, role: ChatRole) -> bool:
|
||||
@property
|
||||
def tool_call(self) -> Optional[ToolCall]:
|
||||
"""
|
||||
Returns the first Tool call contained in the message.
|
||||
"""
|
||||
if tool_calls := self.tool_calls:
|
||||
return tool_calls[0]
|
||||
return None
|
||||
|
||||
@property
|
||||
def tool_call_results(self) -> List[ToolCallResult]:
|
||||
"""
|
||||
Returns the list of all Tool call results contained in the message.
|
||||
"""
|
||||
return [content for content in self._content if isinstance(content, ToolCallResult)]
|
||||
|
||||
@property
|
||||
def tool_call_result(self) -> Optional[ToolCallResult]:
|
||||
"""
|
||||
Returns the first Tool call result contained in the message.
|
||||
"""
|
||||
if tool_call_results := self.tool_call_results:
|
||||
return tool_call_results[0]
|
||||
return None
|
||||
|
||||
def is_from(self, role: Union[ChatRole, str]) -> bool:
|
||||
"""
|
||||
Check if the message is from a specific role.
|
||||
|
||||
:param role: The role to check against.
|
||||
:returns: True if the message is from the specified role, False otherwise.
|
||||
"""
|
||||
return self.role == role
|
||||
if isinstance(role, str):
|
||||
role = ChatRole.from_str(role)
|
||||
return self._role == role
|
||||
|
||||
@classmethod
|
||||
def from_assistant(cls, content: str, meta: Optional[Dict[str, Any]] = None) -> "ChatMessage":
|
||||
"""
|
||||
Create a message from the assistant.
|
||||
|
||||
:param content: The text content of the message.
|
||||
:param meta: Additional metadata associated with the message.
|
||||
:returns: A new ChatMessage instance.
|
||||
"""
|
||||
return cls(content, ChatRole.ASSISTANT, None, meta or {})
|
||||
|
||||
@classmethod
|
||||
def from_user(cls, content: str) -> "ChatMessage":
|
||||
def from_user(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage":
|
||||
"""
|
||||
Create a message from the user.
|
||||
|
||||
:param content: The text content of the message.
|
||||
:param text: The text content of the message.
|
||||
:param meta: Additional metadata associated with the message.
|
||||
:param name: An optional name for the participant. This field is only supported by OpenAI.
|
||||
:returns: A new ChatMessage instance.
|
||||
"""
|
||||
return cls(content, ChatRole.USER, None)
|
||||
return cls(_role=ChatRole.USER, _content=[TextContent(text=text)], _meta=meta or {}, _name=name)
|
||||
|
||||
@classmethod
|
||||
def from_system(cls, content: str) -> "ChatMessage":
|
||||
def from_system(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage":
|
||||
"""
|
||||
Create a message from the system.
|
||||
|
||||
:param content: The text content of the message.
|
||||
:param text: The text content of the message.
|
||||
:param meta: Additional metadata associated with the message.
|
||||
:param name: An optional name for the participant. This field is only supported by OpenAI.
|
||||
:returns: A new ChatMessage instance.
|
||||
"""
|
||||
return cls(content, ChatRole.SYSTEM, None)
|
||||
return cls(_role=ChatRole.SYSTEM, _content=[TextContent(text=text)], _meta=meta or {}, _name=name)
|
||||
|
||||
@classmethod
|
||||
def from_assistant(
|
||||
cls,
|
||||
text: Optional[str] = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
tool_calls: Optional[List[ToolCall]] = None,
|
||||
) -> "ChatMessage":
|
||||
"""
|
||||
Create a message from the assistant.
|
||||
|
||||
:param text: The text content of the message.
|
||||
:param meta: Additional metadata associated with the message.
|
||||
:param tool_calls: The Tool calls to include in the message.
|
||||
:param name: An optional name for the participant. This field is only supported by OpenAI.
|
||||
:returns: A new ChatMessage instance.
|
||||
"""
|
||||
content: List[ChatMessageContentT] = []
|
||||
if text is not None:
|
||||
content.append(TextContent(text=text))
|
||||
if tool_calls:
|
||||
content.extend(tool_calls)
|
||||
|
||||
return cls(_role=ChatRole.ASSISTANT, _content=content, _meta=meta or {}, _name=name)
|
||||
|
||||
@classmethod
|
||||
def from_tool(
|
||||
cls, tool_result: str, origin: ToolCall, error: bool = False, meta: Optional[Dict[str, Any]] = None
|
||||
) -> "ChatMessage":
|
||||
"""
|
||||
Create a message from a Tool.
|
||||
|
||||
:param tool_result: The result of the Tool invocation.
|
||||
:param origin: The Tool call that produced this result.
|
||||
:param error: Whether the Tool invocation resulted in an error.
|
||||
:param meta: Additional metadata associated with the message.
|
||||
:returns: A new ChatMessage instance.
|
||||
"""
|
||||
return cls(
|
||||
_role=ChatRole.TOOL,
|
||||
_content=[ToolCallResult(result=tool_result, origin=origin, error=error)],
|
||||
_meta=meta or {},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_function(cls, content: str, name: str) -> "ChatMessage":
|
||||
"""
|
||||
Create a message from a function call.
|
||||
Create a message from a function call. Deprecated in favor of `from_tool`.
|
||||
|
||||
:param content: The text content of the message.
|
||||
:param name: The name of the function being called.
|
||||
:returns: A new ChatMessage instance.
|
||||
"""
|
||||
return cls(content, ChatRole.FUNCTION, name)
|
||||
msg = (
|
||||
"The `from_function` method is deprecated and will be removed in version 2.10.0. "
|
||||
"Its behavior has changed: it now attempts to convert legacy function messages to tool messages. "
|
||||
"This conversion is not guaranteed to succeed in all scenarios. "
|
||||
"Please migrate to `ChatMessage.from_tool` and carefully verify the results if you "
|
||||
"continue to use this method."
|
||||
)
|
||||
warnings.warn(msg)
|
||||
|
||||
return cls.from_tool(content, ToolCall(id=None, tool_name=name, arguments={}), error=False)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -111,10 +325,23 @@ class ChatMessage:
|
||||
:returns:
|
||||
Serialized version of the object.
|
||||
"""
|
||||
data = asdict(self)
|
||||
data["role"] = self.role.value
|
||||
serialized: Dict[str, Any] = {}
|
||||
serialized["_role"] = self._role.value
|
||||
serialized["_meta"] = self._meta
|
||||
serialized["_name"] = self._name
|
||||
content: List[Dict[str, Any]] = []
|
||||
for part in self._content:
|
||||
if isinstance(part, TextContent):
|
||||
content.append({"text": part.text})
|
||||
elif isinstance(part, ToolCall):
|
||||
content.append({"tool_call": asdict(part)})
|
||||
elif isinstance(part, ToolCallResult):
|
||||
content.append({"tool_call_result": asdict(part)})
|
||||
else:
|
||||
raise TypeError(f"Unsupported type in ChatMessage content: `{type(part).__name__}` for `{part}`.")
|
||||
|
||||
return data
|
||||
serialized["_content"] = content
|
||||
return serialized
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage":
|
||||
@ -126,6 +353,31 @@ class ChatMessage:
|
||||
:returns:
|
||||
The created object.
|
||||
"""
|
||||
data["role"] = ChatRole(data["role"])
|
||||
if any(param in data for param in LEGACY_INIT_PARAMETERS):
|
||||
raise TypeError(
|
||||
"The `role`, `content`, `meta`, and `name` init parameters of `ChatMessage` have been removed. "
|
||||
"For more information about the new API and how to migrate, see the documentation: "
|
||||
"https://docs.haystack.deepset.ai/docs/data-classes#chatmessage"
|
||||
)
|
||||
|
||||
data["_role"] = ChatRole(data["_role"])
|
||||
|
||||
content: List[ChatMessageContentT] = []
|
||||
|
||||
for part in data["_content"]:
|
||||
if "text" in part:
|
||||
content.append(TextContent(text=part["text"]))
|
||||
elif "tool_call" in part:
|
||||
content.append(ToolCall(**part["tool_call"]))
|
||||
elif "tool_call_result" in part:
|
||||
result = part["tool_call_result"]["result"]
|
||||
origin = ToolCall(**part["tool_call_result"]["origin"])
|
||||
error = part["tool_call_result"]["error"]
|
||||
tcr = ToolCallResult(result=result, origin=origin, error=error)
|
||||
content.append(tcr)
|
||||
else:
|
||||
raise ValueError(f"Unsupported content in serialized ChatMessage: `{part}`")
|
||||
|
||||
data["_content"] = content
|
||||
|
||||
return cls(**data)
|
||||
|
||||
23
releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml
Normal file
23
releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml
Normal file
@ -0,0 +1,23 @@
|
||||
---
|
||||
highlights: >
|
||||
We are introducing a refactored ChatMessage dataclass. It is more flexible, future-proof, and compatible with
|
||||
different types of content: text, tool calls, tool calls results.
|
||||
For information about the new API and how to migrate, see the documentation:
|
||||
https://docs.haystack.deepset.ai/docs/data-classes#chatmessage
|
||||
upgrade:
|
||||
- |
|
||||
The refactoring of the ChatMessage dataclass includes some breaking changes, involving ChatMessage creation and
|
||||
accessing attributes. If you have a Pipeline containing a ChatPromptBuilder, serialized using Haystack<2.9.0,
|
||||
deserialization may break.
|
||||
For detailed information about the changes and how to migrate, see the documentation:
|
||||
https://docs.haystack.deepset.ai/docs/data-classes#chatmessage
|
||||
features:
|
||||
- |
|
||||
Changed the ChatMessage dataclass to support different types of content, including tool calls, and tool call
|
||||
results.
|
||||
deprecations:
|
||||
- |
|
||||
The function role and ChatMessage.from_function class method have been deprecated and will be removed in
|
||||
Haystack 2.10.0. ChatMessage.from_function also attempts to produce a valid tool message.
|
||||
For more information, see the documentation:
|
||||
https://docs.haystack.deepset.ai/docs/data-classes#chatmessage
|
||||
@ -13,8 +13,8 @@ class TestChatPromptBuilder:
|
||||
def test_init(self):
|
||||
builder = ChatPromptBuilder(
|
||||
template=[
|
||||
ChatMessage.from_user(content="This is a {{ variable }}"),
|
||||
ChatMessage.from_system(content="This is a {{ variable2 }}"),
|
||||
ChatMessage.from_user("This is a {{ variable }}"),
|
||||
ChatMessage.from_system("This is a {{ variable2 }}"),
|
||||
]
|
||||
)
|
||||
assert builder.required_variables == []
|
||||
@ -531,8 +531,13 @@ class TestChatPromptBuilderDynamic:
|
||||
"type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder",
|
||||
"init_parameters": {
|
||||
"template": [
|
||||
{"content": "text and {var}", "role": "user", "name": None, "meta": {}},
|
||||
{"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}},
|
||||
{"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}, "_name": None},
|
||||
{
|
||||
"_content": [{"text": "content {required_var}"}],
|
||||
"_role": "assistant",
|
||||
"_meta": {},
|
||||
"_name": None,
|
||||
},
|
||||
],
|
||||
"variables": ["var", "required_var"],
|
||||
"required_variables": ["required_var"],
|
||||
@ -545,8 +550,13 @@ class TestChatPromptBuilderDynamic:
|
||||
"type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder",
|
||||
"init_parameters": {
|
||||
"template": [
|
||||
{"content": "text and {var}", "role": "user", "name": None, "meta": {}},
|
||||
{"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}},
|
||||
{"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}, "_name": None},
|
||||
{
|
||||
"_content": [{"text": "content {required_var}"}],
|
||||
"_role": "assistant",
|
||||
"_meta": {},
|
||||
"_name": None,
|
||||
},
|
||||
],
|
||||
"variables": ["var", "required_var"],
|
||||
"required_variables": ["required_var"],
|
||||
|
||||
@ -68,13 +68,6 @@ def test_convert_message_to_hfapi_format():
|
||||
message = ChatMessage.from_user("I have a question")
|
||||
assert _convert_message_to_hfapi_format(message) == {"role": "user", "content": "I have a question"}
|
||||
|
||||
message = ChatMessage.from_function("Function call", "function_name")
|
||||
assert _convert_message_to_hfapi_format(message) == {
|
||||
"role": "function",
|
||||
"content": "Function call",
|
||||
"name": "function_name",
|
||||
}
|
||||
|
||||
|
||||
class TestHuggingFaceAPIGenerator:
|
||||
def test_init_invalid_api_type(self):
|
||||
|
||||
@ -14,10 +14,3 @@ def test_convert_message_to_openai_format():
|
||||
|
||||
message = ChatMessage.from_user("I have a question")
|
||||
assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"}
|
||||
|
||||
message = ChatMessage.from_function("Function call", "function_name")
|
||||
assert _convert_message_to_openai_format(message) == {
|
||||
"role": "function",
|
||||
"content": "Function call",
|
||||
"name": "function_name",
|
||||
}
|
||||
|
||||
@ -349,7 +349,7 @@ class TestRouter:
|
||||
]
|
||||
router = ConditionalRouter(routes, unsafe=True)
|
||||
streams = [1]
|
||||
message = ChatMessage.from_user(content="This is a message")
|
||||
message = ChatMessage.from_user("This is a message")
|
||||
res = router.run(streams=streams, message=message)
|
||||
assert res == {"message": message}
|
||||
|
||||
@ -370,7 +370,7 @@ class TestRouter:
|
||||
]
|
||||
router = ConditionalRouter(routes, validate_output_type=True)
|
||||
streams = [1]
|
||||
message = ChatMessage.from_user(content="This is a message")
|
||||
message = ChatMessage.from_user("This is a message")
|
||||
with pytest.raises(ValueError, match="Route 'message' type doesn't match expected type"):
|
||||
router.run(streams=streams, message=message)
|
||||
|
||||
@ -391,7 +391,7 @@ class TestRouter:
|
||||
]
|
||||
router = ConditionalRouter(routes, unsafe=True, validate_output_type=True)
|
||||
streams = [1]
|
||||
message = ChatMessage.from_user(content="This is a message")
|
||||
message = ChatMessage.from_user("This is a message")
|
||||
res = router.run(streams=streams, message=message)
|
||||
assert isinstance(res["message"], ChatMessage)
|
||||
|
||||
|
||||
@ -1657,7 +1657,7 @@ def that_is_a_simple_agent():
|
||||
class ToolExtractor:
|
||||
@component.output_types(output=List[str])
|
||||
def run(self, messages: List[ChatMessage]):
|
||||
prompt: str = messages[-1].content
|
||||
prompt: str = messages[-1].text
|
||||
lines = prompt.strip().split("\n")
|
||||
for line in reversed(lines):
|
||||
pattern = r"Action:\s*(\w+)\[(.*?)\]"
|
||||
@ -1678,14 +1678,14 @@ def that_is_a_simple_agent():
|
||||
|
||||
@component.output_types(output=List[ChatMessage])
|
||||
def run(self, replies: List[ChatMessage], current_prompt: List[ChatMessage]):
|
||||
content = current_prompt[-1].content + replies[-1].content + self._suffix
|
||||
content = current_prompt[-1].text + replies[-1].text + self._suffix
|
||||
return {"output": [ChatMessage.from_user(content)]}
|
||||
|
||||
@component
|
||||
class SearchOutputAdapter:
|
||||
@component.output_types(output=List[ChatMessage])
|
||||
def run(self, replies: List[ChatMessage]):
|
||||
content = f"Observation: {replies[-1].content}\n"
|
||||
content = f"Observation: {replies[-1].text}\n"
|
||||
return {"output": [ChatMessage.from_assistant(content)]}
|
||||
|
||||
pipeline.add_component("prompt_concatenator_after_action", PromptConcatenator())
|
||||
|
||||
@ -4,64 +4,240 @@
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from haystack.dataclasses import ChatMessage, ChatRole
|
||||
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent
|
||||
from haystack.components.generators.openai_utils import _convert_message_to_openai_format
|
||||
|
||||
|
||||
def test_tool_call_init():
|
||||
tc = ToolCall(id="123", tool_name="mytool", arguments={"a": 1})
|
||||
assert tc.id == "123"
|
||||
assert tc.tool_name == "mytool"
|
||||
assert tc.arguments == {"a": 1}
|
||||
|
||||
|
||||
def test_tool_call_result_init():
|
||||
tcr = ToolCallResult(result="result", origin=ToolCall(id="123", tool_name="mytool", arguments={"a": 1}), error=True)
|
||||
assert tcr.result == "result"
|
||||
assert tcr.origin == ToolCall(id="123", tool_name="mytool", arguments={"a": 1})
|
||||
assert tcr.error
|
||||
|
||||
|
||||
def test_text_content_init():
|
||||
tc = TextContent(text="Hello")
|
||||
assert tc.text == "Hello"
|
||||
|
||||
|
||||
def test_from_assistant_with_valid_content():
|
||||
content = "Hello, how can I assist you?"
|
||||
message = ChatMessage.from_assistant(content)
|
||||
assert message.content == content
|
||||
assert message.text == content
|
||||
text = "Hello, how can I assist you?"
|
||||
message = ChatMessage.from_assistant(text)
|
||||
|
||||
assert message.role == ChatRole.ASSISTANT
|
||||
assert message._content == [TextContent(text)]
|
||||
assert message.name is None
|
||||
|
||||
assert message.text == text
|
||||
assert message.texts == [text]
|
||||
|
||||
assert not message.tool_calls
|
||||
assert not message.tool_call
|
||||
assert not message.tool_call_results
|
||||
assert not message.tool_call_result
|
||||
|
||||
|
||||
def test_from_assistant_with_tool_calls():
|
||||
tool_calls = [
|
||||
ToolCall(id="123", tool_name="mytool", arguments={"a": 1}),
|
||||
ToolCall(id="456", tool_name="mytool2", arguments={"b": 2}),
|
||||
]
|
||||
|
||||
message = ChatMessage.from_assistant(tool_calls=tool_calls)
|
||||
|
||||
assert message.role == ChatRole.ASSISTANT
|
||||
assert message._content == tool_calls
|
||||
|
||||
assert message.tool_calls == tool_calls
|
||||
assert message.tool_call == tool_calls[0]
|
||||
|
||||
assert not message.texts
|
||||
assert not message.text
|
||||
assert not message.tool_call_results
|
||||
assert not message.tool_call_result
|
||||
|
||||
|
||||
def test_from_user_with_valid_content():
|
||||
content = "I have a question."
|
||||
message = ChatMessage.from_user(content)
|
||||
assert message.content == content
|
||||
assert message.text == content
|
||||
text = "I have a question."
|
||||
message = ChatMessage.from_user(text=text)
|
||||
|
||||
assert message.role == ChatRole.USER
|
||||
assert message._content == [TextContent(text)]
|
||||
assert message.name is None
|
||||
|
||||
assert message.text == text
|
||||
assert message.texts == [text]
|
||||
|
||||
assert not message.tool_calls
|
||||
assert not message.tool_call
|
||||
assert not message.tool_call_results
|
||||
assert not message.tool_call_result
|
||||
|
||||
|
||||
def test_from_user_with_name():
|
||||
text = "I have a question."
|
||||
message = ChatMessage.from_user(text=text, name="John")
|
||||
|
||||
assert message.name == "John"
|
||||
assert message.role == ChatRole.USER
|
||||
assert message._content == [TextContent(text)]
|
||||
|
||||
|
||||
def test_from_system_with_valid_content():
|
||||
content = "System message."
|
||||
message = ChatMessage.from_system(content)
|
||||
assert message.content == content
|
||||
assert message.text == content
|
||||
text = "I have a question."
|
||||
message = ChatMessage.from_system(text=text)
|
||||
|
||||
assert message.role == ChatRole.SYSTEM
|
||||
assert message._content == [TextContent(text)]
|
||||
|
||||
assert message.text == text
|
||||
assert message.texts == [text]
|
||||
|
||||
assert not message.tool_calls
|
||||
assert not message.tool_call
|
||||
assert not message.tool_call_results
|
||||
assert not message.tool_call_result
|
||||
|
||||
|
||||
def test_with_empty_content():
|
||||
message = ChatMessage.from_user("")
|
||||
assert message.content == ""
|
||||
assert message.text == ""
|
||||
assert message.role == ChatRole.USER
|
||||
def test_from_tool_with_valid_content():
|
||||
tool_result = "Tool result"
|
||||
origin = ToolCall(id="123", tool_name="mytool", arguments={"a": 1})
|
||||
message = ChatMessage.from_tool(tool_result, origin, error=False)
|
||||
|
||||
tcr = ToolCallResult(result=tool_result, origin=origin, error=False)
|
||||
|
||||
assert message._content == [tcr]
|
||||
assert message.role == ChatRole.TOOL
|
||||
|
||||
assert message.tool_call_result == tcr
|
||||
assert message.tool_call_results == [tcr]
|
||||
|
||||
assert not message.tool_calls
|
||||
assert not message.tool_call
|
||||
assert not message.texts
|
||||
assert not message.text
|
||||
|
||||
|
||||
def test_from_function_with_empty_name():
|
||||
content = "Function call"
|
||||
message = ChatMessage.from_function(content, "")
|
||||
assert message.content == content
|
||||
assert message.text == content
|
||||
assert message.name == ""
|
||||
assert message.role == ChatRole.FUNCTION
|
||||
def test_multiple_text_segments():
|
||||
texts = [TextContent(text="Hello"), TextContent(text="World")]
|
||||
message = ChatMessage(_role=ChatRole.USER, _content=texts)
|
||||
|
||||
assert message.texts == ["Hello", "World"]
|
||||
assert len(message) == 2
|
||||
|
||||
|
||||
def test_to_openai_format():
|
||||
message = ChatMessage.from_system("You are good assistant")
|
||||
assert _convert_message_to_openai_format(message) == {"role": "system", "content": "You are good assistant"}
|
||||
def test_mixed_content():
|
||||
content = [TextContent(text="Hello"), ToolCall(id="123", tool_name="mytool", arguments={"a": 1})]
|
||||
|
||||
message = ChatMessage.from_user("I have a question")
|
||||
assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"}
|
||||
message = ChatMessage(_role=ChatRole.ASSISTANT, _content=content)
|
||||
|
||||
message = ChatMessage.from_function("Function call", "function_name")
|
||||
assert _convert_message_to_openai_format(message) == {
|
||||
"role": "function",
|
||||
"content": "Function call",
|
||||
"name": "function_name",
|
||||
assert len(message) == 2
|
||||
assert message.texts == ["Hello"]
|
||||
assert message.text == "Hello"
|
||||
|
||||
assert message.tool_calls == [content[1]]
|
||||
assert message.tool_call == content[1]
|
||||
|
||||
|
||||
def test_from_function():
|
||||
# check warning is raised
|
||||
with pytest.warns():
|
||||
message = ChatMessage.from_function("Result of function invocation", "my_function")
|
||||
|
||||
assert message.role == ChatRole.TOOL
|
||||
assert message.tool_call_result == ToolCallResult(
|
||||
result="Result of function invocation",
|
||||
origin=ToolCall(id=None, tool_name="my_function", arguments={}),
|
||||
error=False,
|
||||
)
|
||||
|
||||
|
||||
def test_serde():
|
||||
# the following message is created just for testing purposes and does not make sense in a real use case
|
||||
|
||||
role = ChatRole.ASSISTANT
|
||||
|
||||
text_content = TextContent(text="Hello")
|
||||
tool_call = ToolCall(id="123", tool_name="mytool", arguments={"a": 1})
|
||||
tool_call_result = ToolCallResult(result="result", origin=tool_call, error=False)
|
||||
meta = {"some": "info"}
|
||||
|
||||
message = ChatMessage(_role=role, _content=[text_content, tool_call, tool_call_result], _meta=meta)
|
||||
|
||||
serialized_message = message.to_dict()
|
||||
assert serialized_message == {
|
||||
"_content": [
|
||||
{"text": "Hello"},
|
||||
{"tool_call": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}}},
|
||||
{
|
||||
"tool_call_result": {
|
||||
"result": "result",
|
||||
"error": False,
|
||||
"origin": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}},
|
||||
}
|
||||
},
|
||||
],
|
||||
"_role": "assistant",
|
||||
"_name": None,
|
||||
"_meta": {"some": "info"},
|
||||
}
|
||||
|
||||
deserialized_message = ChatMessage.from_dict(serialized_message)
|
||||
assert deserialized_message == message
|
||||
|
||||
|
||||
def test_to_dict_with_invalid_content_type():
|
||||
text_content = TextContent(text="Hello")
|
||||
invalid_content = "invalid"
|
||||
|
||||
message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[text_content, invalid_content])
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
message.to_dict()
|
||||
|
||||
|
||||
def test_from_dict_with_invalid_content_type():
|
||||
data = {"_role": "assistant", "_content": [{"text": "Hello"}, "invalid"]}
|
||||
with pytest.raises(ValueError):
|
||||
ChatMessage.from_dict(data)
|
||||
|
||||
data = {"_role": "assistant", "_content": [{"text": "Hello"}, {"invalid": "invalid"}]}
|
||||
with pytest.raises(ValueError):
|
||||
ChatMessage.from_dict(data)
|
||||
|
||||
|
||||
def test_from_dict_with_legacy_init_parameters():
|
||||
with pytest.raises(TypeError):
|
||||
ChatMessage.from_dict({"role": "user", "content": "This is a message"})
|
||||
|
||||
|
||||
def test_chat_message_content_attribute_removed():
|
||||
message = ChatMessage.from_user(text="This is a message")
|
||||
with pytest.raises(AttributeError):
|
||||
message.content
|
||||
|
||||
|
||||
def test_chat_message_init_parameters_removed():
|
||||
with pytest.raises(TypeError):
|
||||
ChatMessage(role="irrelevant", content="This is a message")
|
||||
|
||||
|
||||
def test_chat_message_init_content_parameter_type():
|
||||
with pytest.raises(TypeError):
|
||||
ChatMessage(ChatRole.USER, "This is a message")
|
||||
|
||||
|
||||
def test_chat_message_function_role_deprecated():
|
||||
with pytest.warns(DeprecationWarning):
|
||||
ChatMessage(ChatRole.FUNCTION, TextContent("This is a message"))
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_apply_chat_templating_on_chat_message():
|
||||
@ -93,40 +269,3 @@ def test_apply_custom_chat_templating_on_chat_message():
|
||||
formatted_messages, chat_template=anthropic_template, tokenize=False
|
||||
)
|
||||
assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:"
|
||||
|
||||
|
||||
def test_to_dict():
|
||||
content = "content"
|
||||
role = "user"
|
||||
meta = {"some": "some"}
|
||||
|
||||
message = ChatMessage.from_user(content)
|
||||
message.meta.update(meta)
|
||||
|
||||
assert message.text == content
|
||||
assert message.to_dict() == {"content": content, "role": role, "name": None, "meta": meta}
|
||||
|
||||
|
||||
def test_from_dict():
|
||||
assert ChatMessage.from_dict(data={"content": "text", "role": "user", "name": None}) == ChatMessage.from_user(
|
||||
"text"
|
||||
)
|
||||
|
||||
|
||||
def test_from_dict_with_meta():
|
||||
data = {"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}}
|
||||
assert ChatMessage.from_dict(data) == ChatMessage.from_assistant("text", meta={"something": "something"})
|
||||
|
||||
|
||||
def test_content_deprecation_warning(recwarn):
|
||||
message = ChatMessage.from_user("my message")
|
||||
|
||||
# accessing the content attribute triggers the deprecation warning
|
||||
_ = message.content
|
||||
assert len(recwarn) == 1
|
||||
wrn = recwarn.pop(DeprecationWarning)
|
||||
assert "`content` attribute" in wrn.message.args[0]
|
||||
|
||||
# accessing the text property does not trigger a warning
|
||||
assert message.text == "my message"
|
||||
assert len(recwarn) == 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user