mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-27 15:08:43 +00:00
feat: introduce class method to create ChatMessage from the OpenAI dictionary format (#8670)
* add ChatMessage.from_openai_dict_format * remove print * release note * improve docstring * separate validation logic * rm obvious comment
This commit is contained in:
parent
3ea128c962
commit
7b4d9ba86e
@ -308,10 +308,6 @@ class OpenAIChatGenerator:
|
||||
}
|
||||
|
||||
def _handle_stream_response(self, chat_completion: Stream, callback: StreamingCallbackT) -> List[ChatMessage]:
|
||||
print("callback")
|
||||
print(callback)
|
||||
print("-" * 100)
|
||||
|
||||
chunks: List[StreamingChunk] = []
|
||||
chunk = None
|
||||
|
||||
|
||||
@ -426,3 +426,81 @@ class ChatMessage:
|
||||
)
|
||||
openai_msg["tool_calls"] = openai_tool_calls
|
||||
return openai_msg
|
||||
|
||||
@staticmethod
|
||||
def _validate_openai_message(message: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate that a message dictionary follows OpenAI's Chat API format.
|
||||
|
||||
:param message: The message dictionary to validate
|
||||
:raises ValueError: If the message format is invalid
|
||||
"""
|
||||
if "role" not in message:
|
||||
raise ValueError("The `role` field is required in the message dictionary.")
|
||||
|
||||
role = message["role"]
|
||||
content = message.get("content")
|
||||
tool_calls = message.get("tool_calls")
|
||||
|
||||
if role not in ["assistant", "user", "system", "developer", "tool"]:
|
||||
raise ValueError(f"Unsupported role: {role}")
|
||||
|
||||
if role == "assistant":
|
||||
if not content and not tool_calls:
|
||||
raise ValueError("For assistant messages, either `content` or `tool_calls` must be present.")
|
||||
if tool_calls:
|
||||
for tc in tool_calls:
|
||||
if "function" not in tc:
|
||||
raise ValueError("Tool calls must contain the `function` field")
|
||||
elif not content:
|
||||
raise ValueError(f"The `content` field is required for {role} messages.")
|
||||
|
||||
@classmethod
|
||||
def from_openai_dict_format(cls, message: Dict[str, Any]) -> "ChatMessage":
|
||||
"""
|
||||
Create a ChatMessage from a dictionary in the format expected by OpenAI's Chat API.
|
||||
|
||||
NOTE: While OpenAI's API requires `tool_call_id` in both tool calls and tool messages, this method
|
||||
accepts messages without it to support shallow OpenAI-compatible APIs.
|
||||
If you plan to use the resulting ChatMessage with OpenAI, you must include `tool_call_id` or you'll
|
||||
encounter validation errors.
|
||||
|
||||
:param message:
|
||||
The OpenAI dictionary to build the ChatMessage object.
|
||||
:returns:
|
||||
The created ChatMessage object.
|
||||
|
||||
:raises ValueError:
|
||||
If the message dictionary is missing required fields.
|
||||
"""
|
||||
cls._validate_openai_message(message)
|
||||
|
||||
role = message["role"]
|
||||
content = message.get("content")
|
||||
name = message.get("name")
|
||||
tool_calls = message.get("tool_calls")
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
|
||||
if role == "assistant":
|
||||
haystack_tool_calls = None
|
||||
if tool_calls:
|
||||
haystack_tool_calls = []
|
||||
for tc in tool_calls:
|
||||
haystack_tc = ToolCall(
|
||||
id=tc.get("id"),
|
||||
tool_name=tc["function"]["name"],
|
||||
arguments=json.loads(tc["function"]["arguments"]),
|
||||
)
|
||||
haystack_tool_calls.append(haystack_tc)
|
||||
return cls.from_assistant(text=content, name=name, tool_calls=haystack_tool_calls)
|
||||
|
||||
assert content is not None # ensured by _validate_openai_message, but we need to make mypy happy
|
||||
|
||||
if role == "user":
|
||||
return cls.from_user(text=content, name=name)
|
||||
if role in ["system", "developer"]:
|
||||
return cls.from_system(text=content, name=name)
|
||||
|
||||
return cls.from_tool(
|
||||
tool_result=content, origin=ToolCall(id=tool_call_id, tool_name="", arguments={}), error=False
|
||||
)
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Add the `from_openai_dict_format` class method to the `ChatMessage` class. It allows you to create a `ChatMessage`
|
||||
from a dictionary in the format expected by OpenAI's Chat API.
|
||||
@ -288,6 +288,86 @@ def test_to_openai_dict_format_invalid():
|
||||
message.to_openai_dict_format()
|
||||
|
||||
|
||||
def test_from_openai_dict_format_user_message():
|
||||
openai_msg = {"role": "user", "content": "Hello, how are you?", "name": "John"}
|
||||
message = ChatMessage.from_openai_dict_format(openai_msg)
|
||||
assert message.role.value == "user"
|
||||
assert message.text == "Hello, how are you?"
|
||||
assert message.name == "John"
|
||||
|
||||
|
||||
def test_from_openai_dict_format_system_message():
|
||||
openai_msg = {"role": "system", "content": "You are a helpful assistant"}
|
||||
message = ChatMessage.from_openai_dict_format(openai_msg)
|
||||
assert message.role.value == "system"
|
||||
assert message.text == "You are a helpful assistant"
|
||||
|
||||
|
||||
def test_from_openai_dict_format_assistant_message_with_content():
|
||||
openai_msg = {"role": "assistant", "content": "I can help with that"}
|
||||
message = ChatMessage.from_openai_dict_format(openai_msg)
|
||||
assert message.role.value == "assistant"
|
||||
assert message.text == "I can help with that"
|
||||
|
||||
|
||||
def test_from_openai_dict_format_assistant_message_with_tool_calls():
|
||||
openai_msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_123", "function": {"name": "get_weather", "arguments": '{"location": "Berlin"}'}}],
|
||||
}
|
||||
message = ChatMessage.from_openai_dict_format(openai_msg)
|
||||
assert message.role.value == "assistant"
|
||||
assert message.text is None
|
||||
assert len(message.tool_calls) == 1
|
||||
tool_call = message.tool_calls[0]
|
||||
assert tool_call.id == "call_123"
|
||||
assert tool_call.tool_name == "get_weather"
|
||||
assert tool_call.arguments == {"location": "Berlin"}
|
||||
|
||||
|
||||
def test_from_openai_dict_format_tool_message():
|
||||
openai_msg = {"role": "tool", "content": "The weather is sunny", "tool_call_id": "call_123"}
|
||||
message = ChatMessage.from_openai_dict_format(openai_msg)
|
||||
assert message.role.value == "tool"
|
||||
assert message.tool_call_result.result == "The weather is sunny"
|
||||
assert message.tool_call_result.origin.id == "call_123"
|
||||
|
||||
|
||||
def test_from_openai_dict_format_tool_without_id():
|
||||
openai_msg = {"role": "tool", "content": "The weather is sunny"}
|
||||
message = ChatMessage.from_openai_dict_format(openai_msg)
|
||||
assert message.role.value == "tool"
|
||||
assert message.tool_call_result.result == "The weather is sunny"
|
||||
assert message.tool_call_result.origin.id is None
|
||||
|
||||
|
||||
def test_from_openai_dict_format_missing_role():
|
||||
with pytest.raises(ValueError):
|
||||
ChatMessage.from_openai_dict_format({"content": "test"})
|
||||
|
||||
|
||||
def test_from_openai_dict_format_missing_content():
|
||||
with pytest.raises(ValueError):
|
||||
ChatMessage.from_openai_dict_format({"role": "user"})
|
||||
|
||||
|
||||
def test_from_openai_dict_format_invalid_tool_calls():
|
||||
openai_msg = {"role": "assistant", "tool_calls": [{"invalid": "format"}]}
|
||||
with pytest.raises(ValueError):
|
||||
ChatMessage.from_openai_dict_format(openai_msg)
|
||||
|
||||
|
||||
def test_from_openai_dict_format_unsupported_role():
|
||||
with pytest.raises(ValueError):
|
||||
ChatMessage.from_openai_dict_format({"role": "invalid", "content": "test"})
|
||||
|
||||
|
||||
def test_from_openai_dict_format_assistant_missing_content_and_tool_calls():
|
||||
with pytest.raises(ValueError):
|
||||
ChatMessage.from_openai_dict_format({"role": "assistant", "irrelevant": "irrelevant"})
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_apply_chat_templating_on_chat_message():
|
||||
messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user