diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py
index c84f68afa..13bbfbe55 100644
--- a/haystack/components/generators/chat/openai.py
+++ b/haystack/components/generators/chat/openai.py
@@ -285,13 +285,12 @@ class OpenAIChatGenerator:
         else:
             assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
             completions = [
-                self._convert_chat_completion_to_chat_message(chat_completion, choice)
-                for choice in chat_completion.choices
+                _convert_chat_completion_to_chat_message(chat_completion, choice) for choice in chat_completion.choices
             ]

         # before returning, do post-processing of the completions
         for message in completions:
-            self._check_finish_reason(message.meta)
+            _check_finish_reason(message.meta)

         return {"replies": completions}

@@ -362,13 +361,12 @@ class OpenAIChatGenerator:
         else:
             assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
             completions = [
-                self._convert_chat_completion_to_chat_message(chat_completion, choice)
-                for choice in chat_completion.choices
+                _convert_chat_completion_to_chat_message(chat_completion, choice) for choice in chat_completion.choices
            ]

         # before returning, do post-processing of the completions
         for message in completions:
-            self._check_finish_reason(message.meta)
+            _check_finish_reason(message.meta)

         return {"replies": completions}

@@ -419,197 +417,198 @@ class OpenAIChatGenerator:
         }

     def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreamingCallbackT) -> List[ChatMessage]:
+        component_info = ComponentInfo.from_component(self)
         chunks: List[StreamingChunk] = []
-        chunk = None
-        chunk_delta: StreamingChunk
-
         for chunk in chat_completion:  # pylint: disable=not-an-iterable
             assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
-            chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+            chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info)
             chunks.append(chunk_delta)
             callback(chunk_delta)
-        return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
+        return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]

     async def _handle_async_stream_response(
         self, chat_completion: AsyncStream, callback: AsyncStreamingCallbackT
     ) -> List[ChatMessage]:
+        component_info = ComponentInfo.from_component(self)
         chunks: List[StreamingChunk] = []
-        chunk = None
-        chunk_delta: StreamingChunk
-
         async for chunk in chat_completion:  # pylint: disable=not-an-iterable
             assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
-            chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+            chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info)
             chunks.append(chunk_delta)
             await callback(chunk_delta)
-        return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
+        return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]

-    def _check_finish_reason(self, meta: Dict[str, Any]) -> None:
-        if meta["finish_reason"] == "length":
+
+def _check_finish_reason(meta: Dict[str, Any]) -> None:
+    if meta["finish_reason"] == "length":
+        logger.warning(
+            "The completion for index {index} has been truncated before reaching a natural stopping point. "
+            "Increase the max_tokens parameter to allow for longer completions.",
+            index=meta["index"],
+            finish_reason=meta["finish_reason"],
+        )
+    if meta["finish_reason"] == "content_filter":
+        logger.warning(
+            "The completion for index {index} has been truncated due to the content filter.",
+            index=meta["index"],
+            finish_reason=meta["finish_reason"],
+        )
+
+
+def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> ChatMessage:
+    """
+    Connects the streaming chunks into a single ChatMessage.
+
+    :param chunks: The list of all `StreamingChunk` objects.
+
+    :returns: The ChatMessage.
+    """
+    text = "".join([chunk.content for chunk in chunks])
+    tool_calls = []
+
+    # Process tool calls if present in any chunk
+    tool_call_data: Dict[str, Dict[str, str]] = {}  # Track tool calls by index
+    for chunk_payload in chunks:
+        tool_calls_meta = chunk_payload.meta.get("tool_calls")
+        if tool_calls_meta is not None:
+            for delta in tool_calls_meta:
+                # We use the index of the tool call to track it across chunks since the ID is not always provided
+                if delta.index not in tool_call_data:
+                    tool_call_data[delta.index] = {"id": "", "name": "", "arguments": ""}
+
+                # Save the ID if present
+                if delta.id is not None:
+                    tool_call_data[delta.index]["id"] = delta.id
+
+                if delta.function is not None:
+                    if delta.function.name is not None:
+                        tool_call_data[delta.index]["name"] += delta.function.name
+                    if delta.function.arguments is not None:
+                        tool_call_data[delta.index]["arguments"] += delta.function.arguments
+
+    # Convert accumulated tool call data into ToolCall objects
+    for call_data in tool_call_data.values():
+        try:
+            arguments = json.loads(call_data["arguments"])
+            tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
+        except json.JSONDecodeError:
             logger.warning(
-                "The completion for index {index} has been truncated before reaching a natural stopping point. "
-                "Increase the max_tokens parameter to allow for longer completions.",
-                index=meta["index"],
-                finish_reason=meta["finish_reason"],
-            )
-        if meta["finish_reason"] == "content_filter":
-            logger.warning(
-                "The completion for index {index} has been truncated due to the content filter.",
-                index=meta["index"],
-                finish_reason=meta["finish_reason"],
+                "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
+                "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
+                "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
+                _id=call_data["id"],
+                _name=call_data["name"],
+                _arguments=call_data["arguments"],
             )

-    def _convert_streaming_chunks_to_chat_message(
-        self, last_chunk: ChatCompletionChunk, chunks: List[StreamingChunk]
-    ) -> ChatMessage:
-        """
-        Connects the streaming chunks into a single ChatMessage.
+    # finish_reason can appear in different places so we look for the last one
+    finish_reasons = [
+        chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
+    ]
+    finish_reason = finish_reasons[-1] if finish_reasons else None

-        :param last_chunk: The last chunk returned by the OpenAI API.
-        :param chunks: The list of all `StreamingChunk` objects.
+    meta = {
+        "model": chunks[-1].meta.get("model"),
+        "index": 0,
+        "finish_reason": finish_reason,
+        "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
+        "usage": chunks[-1].meta.get("usage"),  # last chunk has the final usage data if available
+    }

-        :returns: The ChatMessage.
-        """
-        text = "".join([chunk.content for chunk in chunks])
-        tool_calls = []
+    return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)

-        # Process tool calls if present in any chunk
-        tool_call_data: Dict[str, Dict[str, str]] = {}  # Track tool calls by index
-        for chunk_payload in chunks:
-            tool_calls_meta = chunk_payload.meta.get("tool_calls")
-            if tool_calls_meta is not None:
-                for delta in tool_calls_meta:
-                    # We use the index of the tool call to track it across chunks since the ID is not always provided
-                    if delta.index not in tool_call_data:
-                        tool_call_data[delta.index] = {"id": "", "name": "", "arguments": ""}

-                    # Save the ID if present
-                    if delta.id is not None:
-                        tool_call_data[delta.index]["id"] = delta.id
+def _convert_chat_completion_to_chat_message(completion: ChatCompletion, choice: Choice) -> ChatMessage:
+    """
+    Converts the non-streaming response from the OpenAI API to a ChatMessage.

-                    if delta.function is not None:
-                        if delta.function.name is not None:
-                            tool_call_data[delta.index]["name"] += delta.function.name
-                        if delta.function.arguments is not None:
-                            tool_call_data[delta.index]["arguments"] += delta.function.arguments
-
-        # Convert accumulated tool call data into ToolCall objects
-        for call_data in tool_call_data.values():
+    :param completion: The completion returned by the OpenAI API.
+    :param choice: The choice returned by the OpenAI API.
+    :return: The ChatMessage.
+    """
+    message: ChatCompletionMessage = choice.message
+    text = message.content
+    tool_calls = []
+    if openai_tool_calls := message.tool_calls:
+        for openai_tc in openai_tool_calls:
+            arguments_str = openai_tc.function.arguments
             try:
-                arguments = json.loads(call_data["arguments"])
-                tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
+                arguments = json.loads(arguments_str)
+                tool_calls.append(ToolCall(id=openai_tc.id, tool_name=openai_tc.function.name, arguments=arguments))
             except json.JSONDecodeError:
                 logger.warning(
                     "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
                     "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
                     "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
-                    _id=call_data["id"],
-                    _name=call_data["name"],
-                    _arguments=call_data["arguments"],
+                    _id=openai_tc.id,
+                    _name=openai_tc.function.name,
+                    _arguments=arguments_str,
                 )

-        # finish_reason can appear in different places so we look for the last one
-        finish_reasons = [
-            chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
-        ]
-        finish_reason = finish_reasons[-1] if finish_reasons else None
+    chat_message = ChatMessage.from_assistant(
+        text=text,
+        tool_calls=tool_calls,
+        meta={
+            "model": completion.model,
+            "index": choice.index,
+            "finish_reason": choice.finish_reason,
+            "usage": _serialize_usage(completion.usage),
+        },
+    )
+    return chat_message

-        meta = {
-            "model": last_chunk.model,
-            "index": 0,
-            "finish_reason": finish_reason,
-            "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
-            "usage": self._serialize_usage(last_chunk.usage),  # last chunk has the final usage data if available
-        }
-        return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)

+def _convert_chat_completion_chunk_to_streaming_chunk(
+    chunk: ChatCompletionChunk, component_info: Optional[ComponentInfo] = None
+) -> StreamingChunk:
+    """
+    Converts the streaming response chunk from the OpenAI API to a StreamingChunk.

-    def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, choice: Choice) -> ChatMessage:
-        """
-        Converts the non-streaming response from the OpenAI API to a ChatMessage.
+    :param chunk: The chunk returned by the OpenAI API.

-        :param completion: The completion returned by the OpenAI API.
-        :param choice: The choice returned by the OpenAI API.
-        :return: The ChatMessage.
-        """
-        message: ChatCompletionMessage = choice.message
-        text = message.content
-        tool_calls = []
-        if openai_tool_calls := message.tool_calls:
-            for openai_tc in openai_tool_calls:
-                arguments_str = openai_tc.function.arguments
-                try:
-                    arguments = json.loads(arguments_str)
-                    tool_calls.append(ToolCall(id=openai_tc.id, tool_name=openai_tc.function.name, arguments=arguments))
-                except json.JSONDecodeError:
-                    logger.warning(
-                        "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
-                        "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
-                        "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
-                        _id=openai_tc.id,
-                        _name=openai_tc.function.name,
-                        _arguments=arguments_str,
-                    )
-
-        chat_message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)
-        chat_message._meta.update(
-            {
-                "model": completion.model,
-                "index": choice.index,
-                "finish_reason": choice.finish_reason,
-                "usage": self._serialize_usage(completion.usage),
-            }
-        )
-        return chat_message
-
-    def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk:
-        """
-        Converts the streaming response chunk from the OpenAI API to a StreamingChunk.
-
-        :param chunk: The chunk returned by the OpenAI API.
-
-        :returns:
-            The StreamingChunk.
-        """
-
-        # get the component name and type
-        component_info = ComponentInfo.from_component(self)
-
-        # we stream the content of the chunk if it's not a tool or function call
-        # if there are no choices, return an empty chunk
-        if len(chunk.choices) == 0:
-            return StreamingChunk(
-                content="",
-                meta={"model": chunk.model, "received_at": datetime.now().isoformat()},
-                component_info=component_info,
-            )
-
-        choice: ChunkChoice = chunk.choices[0]
-        content = choice.delta.content or ""
-        chunk_message = StreamingChunk(content, component_info=component_info)
-
-        # but save the tool calls and function call in the meta if they are present
-        # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method
-        chunk_message.meta.update(
-            {
+    :returns:
+        The StreamingChunk.
+    """
+    # Choices is empty on the very first chunk which provides role information (e.g. "assistant").
+    # It is also empty if include_usage is set to True where the usage information is returned.
+    if len(chunk.choices) == 0:
+        return StreamingChunk(
+            content="",
+            meta={
                 "model": chunk.model,
-                "index": choice.index,
-                "tool_calls": choice.delta.tool_calls,
-                "finish_reason": choice.finish_reason,
                 "received_at": datetime.now().isoformat(),
-            }
+                "usage": _serialize_usage(chunk.usage),
+            },
+            component_info=component_info,
        )
-        return chunk_message

-    def _serialize_usage(self, usage):
-        """Convert OpenAI usage object to serializable dict recursively"""
-        if hasattr(usage, "model_dump"):
-            return usage.model_dump()
-        elif hasattr(usage, "__dict__"):
-            return {k: self._serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
-        elif isinstance(usage, dict):
-            return {k: self._serialize_usage(v) for k, v in usage.items()}
-        elif isinstance(usage, list):
-            return [self._serialize_usage(item) for item in usage]
-        else:
-            return usage
+    choice: ChunkChoice = chunk.choices[0]
+    content = choice.delta.content or ""
+
+    chunk_message = StreamingChunk(
+        content=content,
+        meta={
+            "model": chunk.model,
+            "index": choice.index,
+            "tool_calls": choice.delta.tool_calls,
+            "finish_reason": choice.finish_reason,
+            "received_at": datetime.now().isoformat(),
+            "usage": _serialize_usage(chunk.usage),
+        },
+        component_info=component_info,
+    )
+    return chunk_message
+
+
+def _serialize_usage(usage):
+    """Convert OpenAI usage object to serializable dict recursively"""
+    if hasattr(usage, "model_dump"):
+        return usage.model_dump()
+    elif hasattr(usage, "__dict__"):
+        return {k: _serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
+    elif isinstance(usage, dict):
+        return {k: _serialize_usage(v) for k, v in usage.items()}
+    elif isinstance(usage, list):
+        return [_serialize_usage(item) for item in usage]
+    else:
+        return usage
diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py
index 02ba20c8e..5e1c35ad3 100644
--- a/haystack/components/generators/openai.py
+++ b/haystack/components/generators/openai.py
@@ -3,14 +3,25 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
-from datetime import datetime
 from typing import Any, Dict, List, Optional, Union

 from openai import OpenAI, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk

 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk, select_streaming_callback
+from haystack.components.generators.chat.openai import (
+    _check_finish_reason,
+    _convert_chat_completion_chunk_to_streaming_chunk,
+    _convert_chat_completion_to_chat_message,
+    _convert_streaming_chunks_to_chat_message,
+)
+from haystack.dataclasses import (
+    ChatMessage,
+    ComponentInfo,
+    StreamingCallbackT,
+    StreamingChunk,
+    select_streaming_callback,
+)
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.http_client import init_http_client

@@ -230,129 +241,26 @@ class OpenAIGenerator:
             num_responses = generation_kwargs.pop("n", 1)
             if num_responses > 1:
                 raise ValueError("Cannot stream multiple responses, please set n=1.")
+
+            component_info = ComponentInfo.from_component(self)
             chunks: List[StreamingChunk] = []
-            last_chunk: Optional[ChatCompletionChunk] = None
-
             for chunk in completion:
-                if isinstance(chunk, ChatCompletionChunk):
-                    last_chunk = chunk
+                chunk_delta: StreamingChunk = _convert_chat_completion_chunk_to_streaming_chunk(
+                    chunk=chunk,  # type: ignore
+                    component_info=component_info,
+                )
+                chunks.append(chunk_delta)
+                streaming_callback(chunk_delta)

-                    if chunk.choices:
-                        chunk_delta: StreamingChunk = self._build_chunk(chunk)
-                        chunks.append(chunk_delta)
-                        streaming_callback(chunk_delta)
-
-            assert last_chunk is not None
-
-            completions = [self._create_message_from_chunks(last_chunk, chunks)]
+            completions = [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
         elif isinstance(completion, ChatCompletion):
-            completions = [self._build_message(completion, choice) for choice in completion.choices]
+            completions = [
+                _convert_chat_completion_to_chat_message(completion=completion, choice=choice)
+                for choice in completion.choices
+            ]

         # before returning, do post-processing of the completions
         for response in completions:
-            self._check_finish_reason(response)
+            _check_finish_reason(response.meta)

         return {"replies": [message.text for message in completions], "meta": [message.meta for message in completions]}
-
-    def _serialize_usage(self, usage):
-        """Convert OpenAI usage object to serializable dict recursively"""
-        if hasattr(usage, "model_dump"):
-            return usage.model_dump()
-        elif hasattr(usage, "__dict__"):
-            return {k: self._serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
-        elif isinstance(usage, dict):
-            return {k: self._serialize_usage(v) for k, v in usage.items()}
-        elif isinstance(usage, list):
-            return [self._serialize_usage(item) for item in usage]
-        else:
-            return usage
-
-    def _create_message_from_chunks(
-        self, completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
-    ) -> ChatMessage:
-        """
-        Creates a single ChatMessage from the streamed chunks. Some data is retrieved from the completion chunk.
-        """
-        complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in streamed_chunks]))
-        finish_reason = streamed_chunks[-1].meta["finish_reason"]
-        complete_response.meta.update(
-            {
-                "model": completion_chunk.model,
-                "index": 0,
-                "finish_reason": finish_reason,
-                "completion_start_time": streamed_chunks[0].meta.get("received_at"),  # first chunk received
-                "usage": self._serialize_usage(completion_chunk.usage),
-            }
-        )
-        return complete_response
-
-    def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
-        """
-        Converts the response from the OpenAI API to a ChatMessage.
-
-        :param completion:
-            The completion returned by the OpenAI API.
-        :param choice:
-            The choice returned by the OpenAI API.
-        :returns:
-            The ChatMessage.
-        """
-        # function or tools calls are not going to happen in non-chat generation
-        # as users can not send ChatMessage with function or tools calls
-        chat_message = ChatMessage.from_assistant(choice.message.content or "")
-        chat_message.meta.update(
-            {
-                "model": completion.model,
-                "index": choice.index,
-                "finish_reason": choice.finish_reason,
-                "usage": self._serialize_usage(completion.usage),
-            }
-        )
-        return chat_message
-
-    @staticmethod
-    def _build_chunk(chunk: Any) -> StreamingChunk:
-        """
-        Converts the response from the OpenAI API to a StreamingChunk.
-
-        :param chunk:
-            The chunk returned by the OpenAI API.
-        :returns:
-            The StreamingChunk.
-        """
-        choice = chunk.choices[0]
-        content = choice.delta.content or ""
-        chunk_message = StreamingChunk(content)
-        chunk_message.meta.update(
-            {
-                "model": chunk.model,
-                "index": choice.index,
-                "finish_reason": choice.finish_reason,
-                "received_at": datetime.now().isoformat(),
-            }
-        )
-        return chunk_message
-
-    @staticmethod
-    def _check_finish_reason(message: ChatMessage) -> None:
-        """
-        Check the `finish_reason` returned with the OpenAI completions.
-
-        If the `finish_reason` is `length`, log a warning to the user.
-
-        :param message:
-            The message returned by the LLM.
-        """
-        if message.meta["finish_reason"] == "length":
-            logger.warning(
-                "The completion for index {index} has been truncated before reaching a natural stopping point. "
-                "Increase the max_tokens parameter to allow for longer completions.",
-                index=message.meta["index"],
-                finish_reason=message.meta["finish_reason"],
-            )
-        if message.meta["finish_reason"] == "content_filter":
-            logger.warning(
-                "The completion for index {index} has been truncated due to the content filter.",
-                index=message.meta["index"],
-                finish_reason=message.meta["finish_reason"],
-            )
diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py
index dcd4aa265..73b869f9b 100644
--- a/test/components/generators/chat/test_openai.py
+++ b/test/components/generators/chat/test_openai.py
@@ -24,7 +24,12 @@ from haystack.dataclasses import StreamingChunk, ComponentInfo
 from haystack.utils.auth import Secret
 from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.tools import ComponentTool, Tool
-from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.components.generators.chat.openai import (
+    OpenAIChatGenerator,
+    _check_finish_reason,
+    _convert_streaming_chunks_to_chat_message,
+    _convert_chat_completion_chunk_to_streaming_chunk,
+)
 from haystack.tools.toolset import Toolset


@@ -429,7 +434,6 @@ class TestOpenAIChatGenerator:
     def test_check_abnormal_completions(self, caplog):
         caplog.set_level(logging.INFO)
-        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
         messages = [
             ChatMessage.from_assistant(
                 "", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
             )
@@ -438,7 +442,7 @@
         ]

         for m in messages:
-            component._check_finish_reason(m.meta)
+            _check_finish_reason(m.meta)

         # check truncation warning
         message_template = (
@@ -595,7 +599,6 @@ class TestOpenAIChatGenerator:
         assert message.meta["usage"]["completion_tokens"] == 47

     def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
-        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
         chunk = chat_completion_chunk.ChatCompletionChunk(
             id="chatcmpl-B2g1XYv1WzALulC5c8uLtJgvEB48I",
             choices=[
@@ -878,7 +881,7 @@ class TestOpenAIChatGenerator:
         ]

         # Convert chunks to a chat message
-        result = component._convert_streaming_chunks_to_chat_message(chunk, chunks)
+        result = _convert_streaming_chunks_to_chat_message(chunks=chunks)

         assert not result.texts
         assert not result.text
@@ -899,7 +902,6 @@ class TestOpenAIChatGenerator:
         assert result.meta["completion_start_time"] == "2025-02-19T16:02:55.910076"

     def test_convert_usage_chunk_to_streaming_chunk(self):
-        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
         chunk = ChatCompletionChunk(
             id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw",
             choices=[],
@@ -918,7 +920,7 @@ class TestOpenAIChatGenerator:
                 prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
             ),
         )
-        result = component._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+        result = _convert_chat_completion_chunk_to_streaming_chunk(chunk)
         assert result.content == ""
         assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
         assert result.meta["received_at"] is not None
diff --git a/test/components/generators/test_openai.py b/test/components/generators/test_openai.py
index f33b21507..3e3448729 100644
--- a/test/components/generators/test_openai.py
+++ b/test/components/generators/test_openai.py
@@ -204,35 +204,6 @@ class TestOpenAIGenerator:
         assert len(response["replies"]) == 1
         assert [isinstance(reply, str) for reply in response["replies"]]

-    def test_check_abnormal_completions(self, caplog):
-        caplog.set_level(logging.INFO)
-        component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))
-
-        # underlying implementation uses ChatMessage objects so we have to use them here
-        messages: List[ChatMessage] = []
-        for i, _ in enumerate(range(4)):
-            message = ChatMessage.from_assistant("Hello")
-            metadata = {"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
-            message.meta.update(metadata)
-            messages.append(message)
-
-        for m in messages:
-            component._check_finish_reason(m)
-
-        # check truncation warning
-        message_template = (
-            "The completion for index {index} has been truncated before reaching a natural stopping point. "
-            "Increase the max_tokens parameter to allow for longer completions."
-        )
-
-        for index in [1, 3]:
-            assert caplog.records[index].message == message_template.format(index=index)
-
-        # check content filter warning
-        message_template = "The completion for index {index} has been truncated due to the content filter."
-        for index in [0, 2]:
-            assert caplog.records[index].message == message_template.format(index=index)
-
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
         reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",