refactor: Refactor openai generator (#9445)

* Refactor openai generator and chat generator to reuse the same util methods

* Start fixing tests

* More fixes

* Fix mypy

* Fix
Sebastian Husch Lee 2025-05-27 12:44:17 +02:00 committed by GitHub
parent 64def6d41b
commit db3d95b12a
4 changed files with 196 additions and 316 deletions
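
In practice, the refactor moves the response post-processing helpers out of the OpenAIChatGenerator class and turns them into module-level functions that both generators (and the tests) call directly. A rough illustration, not part of the diff, using the import path and meta keys shown in the hunks below:

from haystack.components.generators.chat.openai import _check_finish_reason

# No generator instance is needed anymore; the helper operates on a plain meta dict.
meta = {"finish_reason": "length", "index": 0}
_check_finish_reason(meta)  # logs a truncation warning, mirroring the previous method behavior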

View File

@ -285,13 +285,12 @@ class OpenAIChatGenerator:
        else:
            assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
            completions = [
-               self._convert_chat_completion_to_chat_message(chat_completion, choice)
-               for choice in chat_completion.choices
+               _convert_chat_completion_to_chat_message(chat_completion, choice) for choice in chat_completion.choices
            ]

        # before returning, do post-processing of the completions
        for message in completions:
-           self._check_finish_reason(message.meta)
+           _check_finish_reason(message.meta)

        return {"replies": completions}
@ -362,13 +361,12 @@ class OpenAIChatGenerator:
        else:
            assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
            completions = [
-               self._convert_chat_completion_to_chat_message(chat_completion, choice)
-               for choice in chat_completion.choices
+               _convert_chat_completion_to_chat_message(chat_completion, choice) for choice in chat_completion.choices
            ]

        # before returning, do post-processing of the completions
        for message in completions:
-           self._check_finish_reason(message.meta)
+           _check_finish_reason(message.meta)

        return {"replies": completions}
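
The module-level _convert_chat_completion_to_chat_message used above can be exercised on its own. A minimal sketch, not part of the diff; the ChatCompletion payload is invented for illustration:

from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice

from haystack.components.generators.chat.openai import _convert_chat_completion_to_chat_message

# A hand-built non-streaming response, standing in for what the OpenAI client returns.
completion = ChatCompletion(
    id="chatcmpl-123",
    created=1700000000,
    model="gpt-4o-mini",
    object="chat.completion",
    choices=[Choice(index=0, finish_reason="stop", message=ChatCompletionMessage(role="assistant", content="Hi!"))],
)
message = _convert_chat_completion_to_chat_message(completion, completion.choices[0])
assert message.text == "Hi!"
assert message.meta["finish_reason"] == "stop"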
@ -419,197 +417,198 @@ class OpenAIChatGenerator:
}
    def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreamingCallbackT) -> List[ChatMessage]:
+       component_info = ComponentInfo.from_component(self)
        chunks: List[StreamingChunk] = []
-       chunk = None
-       chunk_delta: StreamingChunk
        for chunk in chat_completion:  # pylint: disable=not-an-iterable
            assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
-           chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+           chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info)
            chunks.append(chunk_delta)
            callback(chunk_delta)
-       return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
+       return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]

    async def _handle_async_stream_response(
        self, chat_completion: AsyncStream, callback: AsyncStreamingCallbackT
    ) -> List[ChatMessage]:
+       component_info = ComponentInfo.from_component(self)
        chunks: List[StreamingChunk] = []
-       chunk = None
-       chunk_delta: StreamingChunk
        async for chunk in chat_completion:  # pylint: disable=not-an-iterable
            assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
-           chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+           chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info)
            chunks.append(chunk_delta)
            await callback(chunk_delta)
-       return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
+       return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]

-   def _check_finish_reason(self, meta: Dict[str, Any]) -> None:
+def _check_finish_reason(meta: Dict[str, Any]) -> None:
+    if meta["finish_reason"] == "length":
+        logger.warning(
+            "The completion for index {index} has been truncated before reaching a natural stopping point. "
+            "Increase the max_tokens parameter to allow for longer completions.",
+            index=meta["index"],
+            finish_reason=meta["finish_reason"],
+        )
+    if meta["finish_reason"] == "content_filter":
+        logger.warning(
+            "The completion for index {index} has been truncated due to the content filter.",
+            index=meta["index"],
+            finish_reason=meta["finish_reason"],
+        )

-   def _convert_streaming_chunks_to_chat_message(
-       self, last_chunk: ChatCompletionChunk, chunks: List[StreamingChunk]
-   ) -> ChatMessage:
+def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> ChatMessage:
+    """
+    Connects the streaming chunks into a single ChatMessage.
+
+    :param chunks: The list of all `StreamingChunk` objects.
+    :returns: The ChatMessage.
+    """
+    text = "".join([chunk.content for chunk in chunks])
+    tool_calls = []
+
+    # Process tool calls if present in any chunk
+    tool_call_data: Dict[str, Dict[str, str]] = {}  # Track tool calls by index
+    for chunk_payload in chunks:
+        tool_calls_meta = chunk_payload.meta.get("tool_calls")
+        if tool_calls_meta is not None:
+            for delta in tool_calls_meta:
+                # We use the index of the tool call to track it across chunks since the ID is not always provided
+                if delta.index not in tool_call_data:
+                    tool_call_data[delta.index] = {"id": "", "name": "", "arguments": ""}
+
+                # Save the ID if present
+                if delta.id is not None:
+                    tool_call_data[delta.index]["id"] = delta.id
+
+                if delta.function is not None:
+                    if delta.function.name is not None:
+                        tool_call_data[delta.index]["name"] += delta.function.name
+                    if delta.function.arguments is not None:
+                        tool_call_data[delta.index]["arguments"] += delta.function.arguments
+
+    # Convert accumulated tool call data into ToolCall objects
+    for call_data in tool_call_data.values():
+        try:
+            arguments = json.loads(call_data["arguments"])
+            tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
+        except json.JSONDecodeError:
+            logger.warning(
+                "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
+                "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
+                "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
+                _id=call_data["id"],
+                _name=call_data["name"],
+                _arguments=call_data["arguments"],
+            )
+
+    # finish_reason can appear in different places so we look for the last one
+    finish_reasons = [
+        chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
+    ]
+    finish_reason = finish_reasons[-1] if finish_reasons else None
+
+    meta = {
+        "model": chunks[-1].meta.get("model"),
+        "index": 0,
+        "finish_reason": finish_reason,
+        "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
+        "usage": chunks[-1].meta.get("usage"),  # last chunk has the final usage data if available
+    }
+
+    return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)

-   def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, choice: Choice) -> ChatMessage:
+def _convert_chat_completion_to_chat_message(completion: ChatCompletion, choice: Choice) -> ChatMessage:
+    """
+    Converts the non-streaming response from the OpenAI API to a ChatMessage.
+
+    :param completion: The completion returned by the OpenAI API.
+    :param choice: The choice returned by the OpenAI API.
+    :return: The ChatMessage.
+    """
+    message: ChatCompletionMessage = choice.message
+    text = message.content
+    tool_calls = []
+    if openai_tool_calls := message.tool_calls:
+        for openai_tc in openai_tool_calls:
+            arguments_str = openai_tc.function.arguments
+            try:
+                arguments = json.loads(arguments_str)
+                tool_calls.append(ToolCall(id=openai_tc.id, tool_name=openai_tc.function.name, arguments=arguments))
+            except json.JSONDecodeError:
+                logger.warning(
+                    "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
+                    "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
+                    "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
+                    _id=openai_tc.id,
+                    _name=openai_tc.function.name,
+                    _arguments=arguments_str,
+                )
+
+    chat_message = ChatMessage.from_assistant(
+        text=text,
+        tool_calls=tool_calls,
+        meta={
+            "model": completion.model,
+            "index": choice.index,
+            "finish_reason": choice.finish_reason,
+            "usage": _serialize_usage(completion.usage),
+        },
+    )
+    return chat_message

-   def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk:
+def _convert_chat_completion_chunk_to_streaming_chunk(
+    chunk: ChatCompletionChunk, component_info: Optional[ComponentInfo] = None
+) -> StreamingChunk:
+    """
+    Converts the streaming response chunk from the OpenAI API to a StreamingChunk.
+
+    :param chunk: The chunk returned by the OpenAI API.
+    :returns:
+        The StreamingChunk.
+    """
+    # Choices is empty on the very first chunk which provides role information (e.g. "assistant").
+    # It is also empty if include_usage is set to True where the usage information is returned.
+    if len(chunk.choices) == 0:
+        return StreamingChunk(
+            content="",
+            meta={
+                "model": chunk.model,
+                "received_at": datetime.now().isoformat(),
+                "usage": _serialize_usage(chunk.usage),
+            },
+            component_info=component_info,
+        )
+
+    choice: ChunkChoice = chunk.choices[0]
+    content = choice.delta.content or ""
+    chunk_message = StreamingChunk(
+        content=content,
+        meta={
+            "model": chunk.model,
+            "index": choice.index,
+            "tool_calls": choice.delta.tool_calls,
+            "finish_reason": choice.finish_reason,
+            "received_at": datetime.now().isoformat(),
+            "usage": _serialize_usage(chunk.usage),
+        },
+        component_info=component_info,
+    )
+    return chunk_message

-   def _serialize_usage(self, usage):
+def _serialize_usage(usage):
+    """Convert OpenAI usage object to serializable dict recursively"""
+    if hasattr(usage, "model_dump"):
+        return usage.model_dump()
+    elif hasattr(usage, "__dict__"):
+        return {k: _serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
+    elif isinstance(usage, dict):
+        return {k: _serialize_usage(v) for k, v in usage.items()}
+    elif isinstance(usage, list):
+        return [_serialize_usage(item) for item in usage]
+    else:
+        return usage
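
The _serialize_usage helper above recursively converts OpenAI usage objects into plain, JSON-serializable dicts. A small hedged sketch of the intended behavior (the stand-in class is hypothetical; the import path is taken from the diff):

from haystack.components.generators.chat.openai import _serialize_usage

class FakeUsage:  # hypothetical stand-in for an OpenAI usage object without model_dump()
    def __init__(self):
        self.prompt_tokens = 10
        self.completion_tokens = 5

# Objects are unwrapped via __dict__, primitives pass through, and None stays None.
assert _serialize_usage(FakeUsage()) == {"prompt_tokens": 10, "completion_tokens": 5}
assert _serialize_usage(None) is None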

View File

@ -3,14 +3,25 @@
# SPDX-License-Identifier: Apache-2.0
import os
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk, select_streaming_callback
+from haystack.components.generators.chat.openai import (
+    _check_finish_reason,
+    _convert_chat_completion_chunk_to_streaming_chunk,
+    _convert_chat_completion_to_chat_message,
+    _convert_streaming_chunks_to_chat_message,
+)
+from haystack.dataclasses import (
+    ChatMessage,
+    ComponentInfo,
+    StreamingCallbackT,
+    StreamingChunk,
+    select_streaming_callback,
+)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client
@ -230,129 +241,26 @@ class OpenAIGenerator:
num_responses = generation_kwargs.pop("n", 1)
if num_responses > 1:
raise ValueError("Cannot stream multiple responses, please set n=1.")
+           component_info = ComponentInfo.from_component(self)
            chunks: List[StreamingChunk] = []
-           last_chunk: Optional[ChatCompletionChunk] = None
            for chunk in completion:
-               if isinstance(chunk, ChatCompletionChunk):
-                   last_chunk = chunk
-                   if chunk.choices:
-                       chunk_delta: StreamingChunk = self._build_chunk(chunk)
-                       chunks.append(chunk_delta)
-                       streaming_callback(chunk_delta)
-           assert last_chunk is not None
-           completions = [self._create_message_from_chunks(last_chunk, chunks)]
+               chunk_delta: StreamingChunk = _convert_chat_completion_chunk_to_streaming_chunk(
+                   chunk=chunk,  # type: ignore
+                   component_info=component_info,
+               )
+               chunks.append(chunk_delta)
+               streaming_callback(chunk_delta)
+           completions = [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
        elif isinstance(completion, ChatCompletion):
-           completions = [self._build_message(completion, choice) for choice in completion.choices]
+           completions = [
+               _convert_chat_completion_to_chat_message(completion=completion, choice=choice)
+               for choice in completion.choices
+           ]

        # before returning, do post-processing of the completions
        for response in completions:
-           self._check_finish_reason(response)
+           _check_finish_reason(response.meta)
return {"replies": [message.text for message in completions], "meta": [message.meta for message in completions]}
-    def _serialize_usage(self, usage):
-        """Convert OpenAI usage object to serializable dict recursively"""
-        if hasattr(usage, "model_dump"):
-            return usage.model_dump()
-        elif hasattr(usage, "__dict__"):
-            return {k: self._serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
-        elif isinstance(usage, dict):
-            return {k: self._serialize_usage(v) for k, v in usage.items()}
-        elif isinstance(usage, list):
-            return [self._serialize_usage(item) for item in usage]
-        else:
-            return usage
-
-    def _create_message_from_chunks(
-        self, completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
-    ) -> ChatMessage:
-        """
-        Creates a single ChatMessage from the streamed chunks. Some data is retrieved from the completion chunk.
-        """
-        complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in streamed_chunks]))
-        finish_reason = streamed_chunks[-1].meta["finish_reason"]
-        complete_response.meta.update(
-            {
-                "model": completion_chunk.model,
-                "index": 0,
-                "finish_reason": finish_reason,
-                "completion_start_time": streamed_chunks[0].meta.get("received_at"),  # first chunk received
-                "usage": self._serialize_usage(completion_chunk.usage),
-            }
-        )
-        return complete_response
-
-    def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
-        """
-        Converts the response from the OpenAI API to a ChatMessage.
-
-        :param completion:
-            The completion returned by the OpenAI API.
-        :param choice:
-            The choice returned by the OpenAI API.
-        :returns:
-            The ChatMessage.
-        """
-        # function or tools calls are not going to happen in non-chat generation
-        # as users can not send ChatMessage with function or tools calls
-        chat_message = ChatMessage.from_assistant(choice.message.content or "")
-        chat_message.meta.update(
-            {
-                "model": completion.model,
-                "index": choice.index,
-                "finish_reason": choice.finish_reason,
-                "usage": self._serialize_usage(completion.usage),
-            }
-        )
-        return chat_message
-
-    @staticmethod
-    def _build_chunk(chunk: Any) -> StreamingChunk:
-        """
-        Converts the response from the OpenAI API to a StreamingChunk.
-
-        :param chunk:
-            The chunk returned by the OpenAI API.
-        :returns:
-            The StreamingChunk.
-        """
-        choice = chunk.choices[0]
-        content = choice.delta.content or ""
-        chunk_message = StreamingChunk(content)
-        chunk_message.meta.update(
-            {
-                "model": chunk.model,
-                "index": choice.index,
-                "finish_reason": choice.finish_reason,
-                "received_at": datetime.now().isoformat(),
-            }
-        )
-        return chunk_message
-
-    @staticmethod
-    def _check_finish_reason(message: ChatMessage) -> None:
-        """
-        Check the `finish_reason` returned with the OpenAI completions.
-
-        If the `finish_reason` is `length`, log a warning to the user.
-
-        :param message:
-            The message returned by the LLM.
-        """
-        if message.meta["finish_reason"] == "length":
-            logger.warning(
-                "The completion for index {index} has been truncated before reaching a natural stopping point. "
-                "Increase the max_tokens parameter to allow for longer completions.",
-                index=message.meta["index"],
-                finish_reason=message.meta["finish_reason"],
-            )
-        if message.meta["finish_reason"] == "content_filter":
-            logger.warning(
-                "The completion for index {index} has been truncated due to the content filter.",
-                index=message.meta["index"],
-                finish_reason=message.meta["finish_reason"],
-            )
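
With the private helpers above removed, OpenAIGenerator's streaming path relies on the shared chunk conversion from the chat module. A hedged sketch of how streamed chunks are folded into a single message (chunk contents and meta values are invented for illustration; the import paths follow the diff):

from haystack.dataclasses import StreamingChunk
from haystack.components.generators.chat.openai import _convert_streaming_chunks_to_chat_message

# Two hand-made chunks standing in for what the chunk converter would emit during streaming.
chunks = [
    StreamingChunk(content="Hello, ", meta={"model": "gpt-4o-mini", "received_at": "2025-05-27T12:00:00"}),
    StreamingChunk(content="world!", meta={"model": "gpt-4o-mini", "finish_reason": "stop", "usage": None}),
]
message = _convert_streaming_chunks_to_chat_message(chunks=chunks)
assert message.text == "Hello, world!"
assert message.meta["finish_reason"] == "stop"
assert message.meta["completion_start_time"] == "2025-05-27T12:00:00"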

View File

@ -24,7 +24,12 @@ from haystack.dataclasses import StreamingChunk, ComponentInfo
from haystack.utils.auth import Secret
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import ComponentTool, Tool
-from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.components.generators.chat.openai import (
+    OpenAIChatGenerator,
+    _check_finish_reason,
+    _convert_streaming_chunks_to_chat_message,
+    _convert_chat_completion_chunk_to_streaming_chunk,
+)
from haystack.tools.toolset import Toolset
@ -429,7 +434,6 @@ class TestOpenAIChatGenerator:
    def test_check_abnormal_completions(self, caplog):
        caplog.set_level(logging.INFO)
-       component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
        messages = [
            ChatMessage.from_assistant(
                "", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
@ -438,7 +442,7 @@ class TestOpenAIChatGenerator:
        ]
        for m in messages:
-           component._check_finish_reason(m.meta)
+           _check_finish_reason(m.meta)

        # check truncation warning
        message_template = (
@ -595,7 +599,6 @@ class TestOpenAIChatGenerator:
assert message.meta["usage"]["completion_tokens"] == 47
    def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
-       component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
        chunk = chat_completion_chunk.ChatCompletionChunk(
            id="chatcmpl-B2g1XYv1WzALulC5c8uLtJgvEB48I",
            choices=[
@ -878,7 +881,7 @@ class TestOpenAIChatGenerator:
]
# Convert chunks to a chat message
-       result = component._convert_streaming_chunks_to_chat_message(chunk, chunks)
+       result = _convert_streaming_chunks_to_chat_message(chunks=chunks)
assert not result.texts
assert not result.text
@ -899,7 +902,6 @@ class TestOpenAIChatGenerator:
assert result.meta["completion_start_time"] == "2025-02-19T16:02:55.910076"
def test_convert_usage_chunk_to_streaming_chunk(self):
-       component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
chunk = ChatCompletionChunk(
id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw",
choices=[],
@ -918,7 +920,7 @@ class TestOpenAIChatGenerator:
prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
),
)
-       result = component._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+       result = _convert_chat_completion_chunk_to_streaming_chunk(chunk)
assert result.content == ""
assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
assert result.meta["received_at"] is not None

View File

@ -204,35 +204,6 @@ class TestOpenAIGenerator:
assert len(response["replies"]) == 1
assert [isinstance(reply, str) for reply in response["replies"]]
-    def test_check_abnormal_completions(self, caplog):
-        caplog.set_level(logging.INFO)
-        component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))
-
-        # underlying implementation uses ChatMessage objects so we have to use them here
-        messages: List[ChatMessage] = []
-        for i, _ in enumerate(range(4)):
-            message = ChatMessage.from_assistant("Hello")
-            metadata = {"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
-            message.meta.update(metadata)
-            messages.append(message)
-
-        for m in messages:
-            component._check_finish_reason(m)
-
-        # check truncation warning
-        message_template = (
-            "The completion for index {index} has been truncated before reaching a natural stopping point. "
-            "Increase the max_tokens parameter to allow for longer completions."
-        )
-        for index in [1, 3]:
-            assert caplog.records[index].message == message_template.format(index=index)
-
-        # check content filter warning
-        message_template = "The completion for index {index} has been truncated due to the content filter."
-        for index in [0, 2]:
-            assert caplog.records[index].message == message_template.format(index=index)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",