diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py
index 932fc3345..09e7d9a1f 100644
--- a/haystack/components/generators/chat/openai.py
+++ b/haystack/components/generators/chat/openai.py
@@ -308,10 +308,6 @@ class OpenAIChatGenerator:
         }
 
     def _handle_stream_response(self, chat_completion: Stream, callback: StreamingCallbackT) -> List[ChatMessage]:
-        print("callback")
-        print(callback)
-        print("-" * 100)
-
         chunks: List[StreamingChunk] = []
         chunk = None
 
diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py
index 0a028f101..a0016ac22 100644
--- a/haystack/dataclasses/chat_message.py
+++ b/haystack/dataclasses/chat_message.py
@@ -426,3 +426,81 @@ class ChatMessage:
                 )
             openai_msg["tool_calls"] = openai_tool_calls
         return openai_msg
+
+    @staticmethod
+    def _validate_openai_message(message: Dict[str, Any]) -> None:
+        """
+        Validate that a message dictionary follows OpenAI's Chat API format.
+
+        :param message: The message dictionary to validate
+        :raises ValueError: If the message format is invalid
+        """
+        if "role" not in message:
+            raise ValueError("The `role` field is required in the message dictionary.")
+
+        role = message["role"]
+        content = message.get("content")
+        tool_calls = message.get("tool_calls")
+
+        if role not in ["assistant", "user", "system", "developer", "tool"]:
+            raise ValueError(f"Unsupported role: {role}")
+
+        if role == "assistant":
+            if not content and not tool_calls:
+                raise ValueError("For assistant messages, either `content` or `tool_calls` must be present.")
+            if tool_calls:
+                for tc in tool_calls:
+                    if "function" not in tc:
+                        raise ValueError("Tool calls must contain the `function` field")
+        elif not content:
+            raise ValueError(f"The `content` field is required for {role} messages.")
+
+    @classmethod
+    def from_openai_dict_format(cls, message: Dict[str, Any]) -> "ChatMessage":
+        """
+        Create a ChatMessage from a dictionary in the format expected by OpenAI's Chat API.
+
+        NOTE: While OpenAI's API requires `tool_call_id` in both tool calls and tool messages, this method
+        accepts messages without it to support shallow OpenAI-compatible APIs.
+        If you plan to use the resulting ChatMessage with OpenAI, you must include `tool_call_id` or you'll
+        encounter validation errors.
+
+        :param message:
+            The OpenAI dictionary to build the ChatMessage object.
+        :returns:
+            The created ChatMessage object.
+
+        :raises ValueError:
+            If the message dictionary is missing required fields.
+ """ + cls._validate_openai_message(message) + + role = message["role"] + content = message.get("content") + name = message.get("name") + tool_calls = message.get("tool_calls") + tool_call_id = message.get("tool_call_id") + + if role == "assistant": + haystack_tool_calls = None + if tool_calls: + haystack_tool_calls = [] + for tc in tool_calls: + haystack_tc = ToolCall( + id=tc.get("id"), + tool_name=tc["function"]["name"], + arguments=json.loads(tc["function"]["arguments"]), + ) + haystack_tool_calls.append(haystack_tc) + return cls.from_assistant(text=content, name=name, tool_calls=haystack_tool_calls) + + assert content is not None # ensured by _validate_openai_message, but we need to make mypy happy + + if role == "user": + return cls.from_user(text=content, name=name) + if role in ["system", "developer"]: + return cls.from_system(text=content, name=name) + + return cls.from_tool( + tool_result=content, origin=ToolCall(id=tool_call_id, tool_name="", arguments={}), error=False + ) diff --git a/releasenotes/notes/chatmsg-from-openai-dict-f15b50d38bdf9abb.yaml b/releasenotes/notes/chatmsg-from-openai-dict-f15b50d38bdf9abb.yaml new file mode 100644 index 000000000..c51ec7f0f --- /dev/null +++ b/releasenotes/notes/chatmsg-from-openai-dict-f15b50d38bdf9abb.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Add the `from_openai_dict_format` class method to the `ChatMessage` class. It allows you to create a `ChatMessage` + from a dictionary in the format expected by OpenAI's Chat API. diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index 2209af998..23a214ca2 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -288,6 +288,86 @@ def test_to_openai_dict_format_invalid(): message.to_openai_dict_format() +def test_from_openai_dict_format_user_message(): + openai_msg = {"role": "user", "content": "Hello, how are you?", "name": "John"} + message = ChatMessage.from_openai_dict_format(openai_msg) + assert message.role.value == "user" + assert message.text == "Hello, how are you?" 
+    assert message.name == "John"
+
+
+def test_from_openai_dict_format_system_message():
+    openai_msg = {"role": "system", "content": "You are a helpful assistant"}
+    message = ChatMessage.from_openai_dict_format(openai_msg)
+    assert message.role.value == "system"
+    assert message.text == "You are a helpful assistant"
+
+
+def test_from_openai_dict_format_assistant_message_with_content():
+    openai_msg = {"role": "assistant", "content": "I can help with that"}
+    message = ChatMessage.from_openai_dict_format(openai_msg)
+    assert message.role.value == "assistant"
+    assert message.text == "I can help with that"
+
+
+def test_from_openai_dict_format_assistant_message_with_tool_calls():
+    openai_msg = {
+        "role": "assistant",
+        "content": None,
+        "tool_calls": [{"id": "call_123", "function": {"name": "get_weather", "arguments": '{"location": "Berlin"}'}}],
+    }
+    message = ChatMessage.from_openai_dict_format(openai_msg)
+    assert message.role.value == "assistant"
+    assert message.text is None
+    assert len(message.tool_calls) == 1
+    tool_call = message.tool_calls[0]
+    assert tool_call.id == "call_123"
+    assert tool_call.tool_name == "get_weather"
+    assert tool_call.arguments == {"location": "Berlin"}
+
+
+def test_from_openai_dict_format_tool_message():
+    openai_msg = {"role": "tool", "content": "The weather is sunny", "tool_call_id": "call_123"}
+    message = ChatMessage.from_openai_dict_format(openai_msg)
+    assert message.role.value == "tool"
+    assert message.tool_call_result.result == "The weather is sunny"
+    assert message.tool_call_result.origin.id == "call_123"
+
+
+def test_from_openai_dict_format_tool_without_id():
+    openai_msg = {"role": "tool", "content": "The weather is sunny"}
+    message = ChatMessage.from_openai_dict_format(openai_msg)
+    assert message.role.value == "tool"
+    assert message.tool_call_result.result == "The weather is sunny"
+    assert message.tool_call_result.origin.id is None
+
+
+def test_from_openai_dict_format_missing_role():
+    with pytest.raises(ValueError):
+        ChatMessage.from_openai_dict_format({"content": "test"})
+
+
+def test_from_openai_dict_format_missing_content():
+    with pytest.raises(ValueError):
+        ChatMessage.from_openai_dict_format({"role": "user"})
+
+
+def test_from_openai_dict_format_invalid_tool_calls():
+    openai_msg = {"role": "assistant", "tool_calls": [{"invalid": "format"}]}
+    with pytest.raises(ValueError):
+        ChatMessage.from_openai_dict_format(openai_msg)
+
+
+def test_from_openai_dict_format_unsupported_role():
+    with pytest.raises(ValueError):
+        ChatMessage.from_openai_dict_format({"role": "invalid", "content": "test"})
+
+
+def test_from_openai_dict_format_assistant_missing_content_and_tool_calls():
+    with pytest.raises(ValueError):
+        ChatMessage.from_openai_dict_format({"role": "assistant", "irrelevant": "irrelevant"})
+
+
 @pytest.mark.integration
 def test_apply_chat_templating_on_chat_message():
     messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
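Usage sketch (illustrative, not part of the patch): the snippet below mirrors the tests above; the `from haystack.dataclasses import ChatMessage` import path is assumed from the module this diff touches.

    from haystack.dataclasses import ChatMessage

    # A plain user message in OpenAI's dict format becomes a Haystack ChatMessage.
    user_msg = ChatMessage.from_openai_dict_format({"role": "user", "content": "What's the weather in Berlin?"})
    assert user_msg.text == "What's the weather in Berlin?"

    # An assistant message carrying a tool call: the JSON-encoded `arguments` string is parsed into a dict.
    assistant_msg = ChatMessage.from_openai_dict_format(
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {"id": "call_123", "function": {"name": "get_weather", "arguments": '{"location": "Berlin"}'}}
            ],
        }
    )
    assert assistant_msg.tool_calls[0].tool_name == "get_weather"
    assert assistant_msg.tool_calls[0].arguments == {"location": "Berlin"}

    # Dictionaries that violate the validated format raise ValueError,
    # e.g. a user message without `content` or an assistant message with neither `content` nor `tool_calls`.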