Mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-08 04:56:45 +00:00)
refactor: Refactor openai generator (#9445)
* Refactor openai generator and chat generator to reuse same util methods
* Start fixing tests
* More fixes
* Fix mypy
* Fix
parent 64def6d41b
commit db3d95b12a
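The commit promotes the private response-conversion helpers of OpenAIChatGenerator (_check_finish_reason, _convert_chat_completion_to_chat_message, _convert_chat_completion_chunk_to_streaming_chunk, _convert_streaming_chunks_to_chat_message, _serialize_usage) to module-level functions in haystack.components.generators.chat.openai, and OpenAIGenerator now imports them instead of keeping its own copies. A minimal sketch of calling one of the shared helpers directly, assuming the module path and signatures shown in the diff below; the ChatCompletion fixture values are illustrative and not taken from this commit:

from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice

from haystack.components.generators.chat.openai import _convert_chat_completion_to_chat_message

# Build a small non-streaming response fixture (illustrative values only).
completion = ChatCompletion(
    id="chatcmpl-123",
    model="gpt-4o-mini",
    object="chat.completion",
    created=1700000000,
    choices=[
        Choice(
            index=0,
            finish_reason="stop",
            message=ChatCompletionMessage(role="assistant", content="Hello!"),
        )
    ],
)

# The same helper now backs both OpenAIGenerator and OpenAIChatGenerator.
message = _convert_chat_completion_to_chat_message(completion, completion.choices[0])
print(message.text, message.meta["finish_reason"])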
@@ -285,13 +285,12 @@ class OpenAIChatGenerator:
         else:
             assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
             completions = [
-                self._convert_chat_completion_to_chat_message(chat_completion, choice)
-                for choice in chat_completion.choices
+                _convert_chat_completion_to_chat_message(chat_completion, choice) for choice in chat_completion.choices
             ]

         # before returning, do post-processing of the completions
         for message in completions:
-            self._check_finish_reason(message.meta)
+            _check_finish_reason(message.meta)

         return {"replies": completions}
@@ -362,13 +361,12 @@ class OpenAIChatGenerator:
         else:
             assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
             completions = [
-                self._convert_chat_completion_to_chat_message(chat_completion, choice)
-                for choice in chat_completion.choices
+                _convert_chat_completion_to_chat_message(chat_completion, choice) for choice in chat_completion.choices
            ]

         # before returning, do post-processing of the completions
         for message in completions:
-            self._check_finish_reason(message.meta)
+            _check_finish_reason(message.meta)

         return {"replies": completions}
@@ -419,197 +417,198 @@ class OpenAIChatGenerator:
         }

     def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreamingCallbackT) -> List[ChatMessage]:
+        component_info = ComponentInfo.from_component(self)
         chunks: List[StreamingChunk] = []
         chunk = None
         chunk_delta: StreamingChunk

         for chunk in chat_completion:  # pylint: disable=not-an-iterable
             assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
-            chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+            chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info)
             chunks.append(chunk_delta)
             callback(chunk_delta)
-        return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
+        return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]

     async def _handle_async_stream_response(
         self, chat_completion: AsyncStream, callback: AsyncStreamingCallbackT
     ) -> List[ChatMessage]:
+        component_info = ComponentInfo.from_component(self)
         chunks: List[StreamingChunk] = []
         chunk = None
         chunk_delta: StreamingChunk

         async for chunk in chat_completion:  # pylint: disable=not-an-iterable
             assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
-            chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+            chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info)
             chunks.append(chunk_delta)
             await callback(chunk_delta)
-        return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
+        return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]

-    def _check_finish_reason(self, meta: Dict[str, Any]) -> None:
-        if meta["finish_reason"] == "length":
-            logger.warning(
-                "The completion for index {index} has been truncated before reaching a natural stopping point. "
-                "Increase the max_tokens parameter to allow for longer completions.",
-                index=meta["index"],
-                finish_reason=meta["finish_reason"],
-            )
-        if meta["finish_reason"] == "content_filter":
-            logger.warning(
-                "The completion for index {index} has been truncated due to the content filter.",
-                index=meta["index"],
-                finish_reason=meta["finish_reason"],
-            )
-
-    def _convert_streaming_chunks_to_chat_message(
-        self, last_chunk: ChatCompletionChunk, chunks: List[StreamingChunk]
-    ) -> ChatMessage:
-        """
-        Connects the streaming chunks into a single ChatMessage.
-
-        :param last_chunk: The last chunk returned by the OpenAI API.
-        :param chunks: The list of all `StreamingChunk` objects.
-
-        :returns: The ChatMessage.
-        """
-        text = "".join([chunk.content for chunk in chunks])
-        tool_calls = []
-
-        # Process tool calls if present in any chunk
-        tool_call_data: Dict[str, Dict[str, str]] = {}  # Track tool calls by index
-        for chunk_payload in chunks:
-            tool_calls_meta = chunk_payload.meta.get("tool_calls")
-            if tool_calls_meta is not None:
-                for delta in tool_calls_meta:
-                    # We use the index of the tool call to track it across chunks since the ID is not always provided
-                    if delta.index not in tool_call_data:
-                        tool_call_data[delta.index] = {"id": "", "name": "", "arguments": ""}
-
-                    # Save the ID if present
-                    if delta.id is not None:
-                        tool_call_data[delta.index]["id"] = delta.id
-
-                    if delta.function is not None:
-                        if delta.function.name is not None:
-                            tool_call_data[delta.index]["name"] += delta.function.name
-                        if delta.function.arguments is not None:
-                            tool_call_data[delta.index]["arguments"] += delta.function.arguments
-
-        # Convert accumulated tool call data into ToolCall objects
-        for call_data in tool_call_data.values():
-            try:
-                arguments = json.loads(call_data["arguments"])
-                tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
-            except json.JSONDecodeError:
-                logger.warning(
-                    "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
-                    "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
-                    "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
-                    _id=call_data["id"],
-                    _name=call_data["name"],
-                    _arguments=call_data["arguments"],
-                )
-
-        # finish_reason can appear in different places so we look for the last one
-        finish_reasons = [
-            chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
-        ]
-        finish_reason = finish_reasons[-1] if finish_reasons else None
-
-        meta = {
-            "model": last_chunk.model,
-            "index": 0,
-            "finish_reason": finish_reason,
-            "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
-            "usage": self._serialize_usage(last_chunk.usage),  # last chunk has the final usage data if available
-        }
-
-        return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
-
-    def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, choice: Choice) -> ChatMessage:
-        """
-        Converts the non-streaming response from the OpenAI API to a ChatMessage.
-
-        :param completion: The completion returned by the OpenAI API.
-        :param choice: The choice returned by the OpenAI API.
-        :return: The ChatMessage.
-        """
-        message: ChatCompletionMessage = choice.message
-        text = message.content
-        tool_calls = []
-        if openai_tool_calls := message.tool_calls:
-            for openai_tc in openai_tool_calls:
-                arguments_str = openai_tc.function.arguments
-                try:
-                    arguments = json.loads(arguments_str)
-                    tool_calls.append(ToolCall(id=openai_tc.id, tool_name=openai_tc.function.name, arguments=arguments))
-                except json.JSONDecodeError:
-                    logger.warning(
-                        "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
-                        "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
-                        "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
-                        _id=openai_tc.id,
-                        _name=openai_tc.function.name,
-                        _arguments=arguments_str,
-                    )
-
-        chat_message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)
-        chat_message._meta.update(
-            {
-                "model": completion.model,
-                "index": choice.index,
-                "finish_reason": choice.finish_reason,
-                "usage": self._serialize_usage(completion.usage),
-            }
-        )
-        return chat_message
-
-    def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk:
-        """
-        Converts the streaming response chunk from the OpenAI API to a StreamingChunk.
-
-        :param chunk: The chunk returned by the OpenAI API.
-
-        :returns:
-            The StreamingChunk.
-        """
-        # get the component name and type
-        component_info = ComponentInfo.from_component(self)
-
-        # we stream the content of the chunk if it's not a tool or function call
-        # if there are no choices, return an empty chunk
-        if len(chunk.choices) == 0:
-            return StreamingChunk(
-                content="",
-                meta={"model": chunk.model, "received_at": datetime.now().isoformat()},
-                component_info=component_info,
-            )
-
-        choice: ChunkChoice = chunk.choices[0]
-        content = choice.delta.content or ""
-        chunk_message = StreamingChunk(content, component_info=component_info)
-
-        # but save the tool calls and function call in the meta if they are present
-        # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method
-        chunk_message.meta.update(
-            {
-                "model": chunk.model,
-                "index": choice.index,
-                "tool_calls": choice.delta.tool_calls,
-                "finish_reason": choice.finish_reason,
-                "received_at": datetime.now().isoformat(),
-            }
-        )
-        return chunk_message
-
-    def _serialize_usage(self, usage):
-        """Convert OpenAI usage object to serializable dict recursively"""
-        if hasattr(usage, "model_dump"):
-            return usage.model_dump()
-        elif hasattr(usage, "__dict__"):
-            return {k: self._serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
-        elif isinstance(usage, dict):
-            return {k: self._serialize_usage(v) for k, v in usage.items()}
-        elif isinstance(usage, list):
-            return [self._serialize_usage(item) for item in usage]
-        else:
-            return usage
+
+def _check_finish_reason(meta: Dict[str, Any]) -> None:
+    if meta["finish_reason"] == "length":
+        logger.warning(
+            "The completion for index {index} has been truncated before reaching a natural stopping point. "
+            "Increase the max_tokens parameter to allow for longer completions.",
+            index=meta["index"],
+            finish_reason=meta["finish_reason"],
+        )
+    if meta["finish_reason"] == "content_filter":
+        logger.warning(
+            "The completion for index {index} has been truncated due to the content filter.",
+            index=meta["index"],
+            finish_reason=meta["finish_reason"],
+        )
+
+
+def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> ChatMessage:
+    """
+    Connects the streaming chunks into a single ChatMessage.
+
+    :param chunks: The list of all `StreamingChunk` objects.
+
+    :returns: The ChatMessage.
+    """
+    text = "".join([chunk.content for chunk in chunks])
+    tool_calls = []
+
+    # Process tool calls if present in any chunk
+    tool_call_data: Dict[str, Dict[str, str]] = {}  # Track tool calls by index
+    for chunk_payload in chunks:
+        tool_calls_meta = chunk_payload.meta.get("tool_calls")
+        if tool_calls_meta is not None:
+            for delta in tool_calls_meta:
+                # We use the index of the tool call to track it across chunks since the ID is not always provided
+                if delta.index not in tool_call_data:
+                    tool_call_data[delta.index] = {"id": "", "name": "", "arguments": ""}
+
+                # Save the ID if present
+                if delta.id is not None:
+                    tool_call_data[delta.index]["id"] = delta.id
+
+                if delta.function is not None:
+                    if delta.function.name is not None:
+                        tool_call_data[delta.index]["name"] += delta.function.name
+                    if delta.function.arguments is not None:
+                        tool_call_data[delta.index]["arguments"] += delta.function.arguments
+
+    # Convert accumulated tool call data into ToolCall objects
+    for call_data in tool_call_data.values():
+        try:
+            arguments = json.loads(call_data["arguments"])
+            tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
+        except json.JSONDecodeError:
+            logger.warning(
+                "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
+                "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
+                "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
+                _id=call_data["id"],
+                _name=call_data["name"],
+                _arguments=call_data["arguments"],
+            )
+
+    # finish_reason can appear in different places so we look for the last one
+    finish_reasons = [
+        chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
+    ]
+    finish_reason = finish_reasons[-1] if finish_reasons else None
+
+    meta = {
+        "model": chunks[-1].meta.get("model"),
+        "index": 0,
+        "finish_reason": finish_reason,
+        "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
+        "usage": chunks[-1].meta.get("usage"),  # last chunk has the final usage data if available
+    }
+
+    return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
+
+
+def _convert_chat_completion_to_chat_message(completion: ChatCompletion, choice: Choice) -> ChatMessage:
+    """
+    Converts the non-streaming response from the OpenAI API to a ChatMessage.
+
+    :param completion: The completion returned by the OpenAI API.
+    :param choice: The choice returned by the OpenAI API.
+    :return: The ChatMessage.
+    """
+    message: ChatCompletionMessage = choice.message
+    text = message.content
+    tool_calls = []
+    if openai_tool_calls := message.tool_calls:
+        for openai_tc in openai_tool_calls:
+            arguments_str = openai_tc.function.arguments
+            try:
+                arguments = json.loads(arguments_str)
+                tool_calls.append(ToolCall(id=openai_tc.id, tool_name=openai_tc.function.name, arguments=arguments))
+            except json.JSONDecodeError:
+                logger.warning(
+                    "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
+                    "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
+                    "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
+                    _id=openai_tc.id,
+                    _name=openai_tc.function.name,
+                    _arguments=arguments_str,
+                )
+
+    chat_message = ChatMessage.from_assistant(
+        text=text,
+        tool_calls=tool_calls,
+        meta={
+            "model": completion.model,
+            "index": choice.index,
+            "finish_reason": choice.finish_reason,
+            "usage": _serialize_usage(completion.usage),
+        },
+    )
+    return chat_message
+
+
+def _convert_chat_completion_chunk_to_streaming_chunk(
+    chunk: ChatCompletionChunk, component_info: Optional[ComponentInfo] = None
+) -> StreamingChunk:
+    """
+    Converts the streaming response chunk from the OpenAI API to a StreamingChunk.
+
+    :param chunk: The chunk returned by the OpenAI API.
+
+    :returns:
+        The StreamingChunk.
+    """
+    # Choices is empty on the very first chunk which provides role information (e.g. "assistant").
+    # It is also empty if include_usage is set to True where the usage information is returned.
+    if len(chunk.choices) == 0:
+        return StreamingChunk(
+            content="",
+            meta={
+                "model": chunk.model,
+                "received_at": datetime.now().isoformat(),
+                "usage": _serialize_usage(chunk.usage),
+            },
+            component_info=component_info,
+        )
+
+    choice: ChunkChoice = chunk.choices[0]
+    content = choice.delta.content or ""
+
+    chunk_message = StreamingChunk(
+        content=content,
+        meta={
+            "model": chunk.model,
+            "index": choice.index,
+            "tool_calls": choice.delta.tool_calls,
+            "finish_reason": choice.finish_reason,
+            "received_at": datetime.now().isoformat(),
+            "usage": _serialize_usage(chunk.usage),
+        },
+        component_info=component_info,
+    )
+    return chunk_message
+
+
+def _serialize_usage(usage):
+    """Convert OpenAI usage object to serializable dict recursively"""
+    if hasattr(usage, "model_dump"):
+        return usage.model_dump()
+    elif hasattr(usage, "__dict__"):
+        return {k: _serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
+    elif isinstance(usage, dict):
+        return {k: _serialize_usage(v) for k, v in usage.items()}
+    elif isinstance(usage, list):
+        return [_serialize_usage(item) for item in usage]
+    else:
+        return usage
@@ -3,14 +3,25 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
-from datetime import datetime
 from typing import Any, Dict, List, Optional, Union

 from openai import OpenAI, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk

 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk, select_streaming_callback
+from haystack.components.generators.chat.openai import (
+    _check_finish_reason,
+    _convert_chat_completion_chunk_to_streaming_chunk,
+    _convert_chat_completion_to_chat_message,
+    _convert_streaming_chunks_to_chat_message,
+)
+from haystack.dataclasses import (
+    ChatMessage,
+    ComponentInfo,
+    StreamingCallbackT,
+    StreamingChunk,
+    select_streaming_callback,
+)
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.http_client import init_http_client
@@ -230,129 +241,26 @@ class OpenAIGenerator:
             num_responses = generation_kwargs.pop("n", 1)
             if num_responses > 1:
                 raise ValueError("Cannot stream multiple responses, please set n=1.")

+            component_info = ComponentInfo.from_component(self)
             chunks: List[StreamingChunk] = []
-            last_chunk: Optional[ChatCompletionChunk] = None
-
             for chunk in completion:
                 if isinstance(chunk, ChatCompletionChunk):
-                    last_chunk = chunk
-
-                if chunk.choices:
-                    chunk_delta: StreamingChunk = self._build_chunk(chunk)
-                    chunks.append(chunk_delta)
-                    streaming_callback(chunk_delta)
-
-            assert last_chunk is not None
-
-            completions = [self._create_message_from_chunks(last_chunk, chunks)]
+                    chunk_delta: StreamingChunk = _convert_chat_completion_chunk_to_streaming_chunk(
+                        chunk=chunk,  # type: ignore
+                        component_info=component_info,
+                    )
+                    chunks.append(chunk_delta)
+                    streaming_callback(chunk_delta)
+            completions = [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
         elif isinstance(completion, ChatCompletion):
-            completions = [self._build_message(completion, choice) for choice in completion.choices]
+            completions = [
+                _convert_chat_completion_to_chat_message(completion=completion, choice=choice)
+                for choice in completion.choices
+            ]

         # before returning, do post-processing of the completions
         for response in completions:
-            self._check_finish_reason(response)
+            _check_finish_reason(response.meta)

         return {"replies": [message.text for message in completions], "meta": [message.meta for message in completions]}

-    def _serialize_usage(self, usage):
-        """Convert OpenAI usage object to serializable dict recursively"""
-        if hasattr(usage, "model_dump"):
-            return usage.model_dump()
-        elif hasattr(usage, "__dict__"):
-            return {k: self._serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
-        elif isinstance(usage, dict):
-            return {k: self._serialize_usage(v) for k, v in usage.items()}
-        elif isinstance(usage, list):
-            return [self._serialize_usage(item) for item in usage]
-        else:
-            return usage
-
-    def _create_message_from_chunks(
-        self, completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
-    ) -> ChatMessage:
-        """
-        Creates a single ChatMessage from the streamed chunks. Some data is retrieved from the completion chunk.
-        """
-        complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in streamed_chunks]))
-        finish_reason = streamed_chunks[-1].meta["finish_reason"]
-        complete_response.meta.update(
-            {
-                "model": completion_chunk.model,
-                "index": 0,
-                "finish_reason": finish_reason,
-                "completion_start_time": streamed_chunks[0].meta.get("received_at"),  # first chunk received
-                "usage": self._serialize_usage(completion_chunk.usage),
-            }
-        )
-        return complete_response
-
-    def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
-        """
-        Converts the response from the OpenAI API to a ChatMessage.
-
-        :param completion:
-            The completion returned by the OpenAI API.
-        :param choice:
-            The choice returned by the OpenAI API.
-        :returns:
-            The ChatMessage.
-        """
-        # function or tools calls are not going to happen in non-chat generation
-        # as users can not send ChatMessage with function or tools calls
-        chat_message = ChatMessage.from_assistant(choice.message.content or "")
-        chat_message.meta.update(
-            {
-                "model": completion.model,
-                "index": choice.index,
-                "finish_reason": choice.finish_reason,
-                "usage": self._serialize_usage(completion.usage),
-            }
-        )
-        return chat_message
-
-    @staticmethod
-    def _build_chunk(chunk: Any) -> StreamingChunk:
-        """
-        Converts the response from the OpenAI API to a StreamingChunk.
-
-        :param chunk:
-            The chunk returned by the OpenAI API.
-        :returns:
-            The StreamingChunk.
-        """
-        choice = chunk.choices[0]
-        content = choice.delta.content or ""
-        chunk_message = StreamingChunk(content)
-        chunk_message.meta.update(
-            {
-                "model": chunk.model,
-                "index": choice.index,
-                "finish_reason": choice.finish_reason,
-                "received_at": datetime.now().isoformat(),
-            }
-        )
-        return chunk_message
-
-    @staticmethod
-    def _check_finish_reason(message: ChatMessage) -> None:
-        """
-        Check the `finish_reason` returned with the OpenAI completions.
-
-        If the `finish_reason` is `length`, log a warning to the user.
-
-        :param message:
-            The message returned by the LLM.
-        """
-        if message.meta["finish_reason"] == "length":
-            logger.warning(
-                "The completion for index {index} has been truncated before reaching a natural stopping point. "
-                "Increase the max_tokens parameter to allow for longer completions.",
-                index=message.meta["index"],
-                finish_reason=message.meta["finish_reason"],
-            )
-        if message.meta["finish_reason"] == "content_filter":
-            logger.warning(
-                "The completion for index {index} has been truncated due to the content filter.",
-                index=message.meta["index"],
-                finish_reason=message.meta["finish_reason"],
-            )
@@ -24,7 +24,12 @@ from haystack.dataclasses import StreamingChunk, ComponentInfo
 from haystack.utils.auth import Secret
 from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.tools import ComponentTool, Tool
-from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.components.generators.chat.openai import (
+    OpenAIChatGenerator,
+    _check_finish_reason,
+    _convert_streaming_chunks_to_chat_message,
+    _convert_chat_completion_chunk_to_streaming_chunk,
+)
 from haystack.tools.toolset import Toolset
@@ -429,7 +434,6 @@ class TestOpenAIChatGenerator:

     def test_check_abnormal_completions(self, caplog):
         caplog.set_level(logging.INFO)
-        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
         messages = [
             ChatMessage.from_assistant(
                 "", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
@@ -438,7 +442,7 @@ class TestOpenAIChatGenerator:
         ]

         for m in messages:
-            component._check_finish_reason(m.meta)
+            _check_finish_reason(m.meta)

         # check truncation warning
         message_template = (
@@ -595,7 +599,6 @@ class TestOpenAIChatGenerator:
         assert message.meta["usage"]["completion_tokens"] == 47

     def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
-        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
         chunk = chat_completion_chunk.ChatCompletionChunk(
             id="chatcmpl-B2g1XYv1WzALulC5c8uLtJgvEB48I",
             choices=[
@@ -878,7 +881,7 @@ class TestOpenAIChatGenerator:
         ]

         # Convert chunks to a chat message
-        result = component._convert_streaming_chunks_to_chat_message(chunk, chunks)
+        result = _convert_streaming_chunks_to_chat_message(chunks=chunks)

         assert not result.texts
         assert not result.text
@@ -899,7 +902,6 @@ class TestOpenAIChatGenerator:
         assert result.meta["completion_start_time"] == "2025-02-19T16:02:55.910076"

     def test_convert_usage_chunk_to_streaming_chunk(self):
-        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
         chunk = ChatCompletionChunk(
             id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw",
             choices=[],
@@ -918,7 +920,7 @@ class TestOpenAIChatGenerator:
                 prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
             ),
         )
-        result = component._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+        result = _convert_chat_completion_chunk_to_streaming_chunk(chunk)
         assert result.content == ""
         assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
         assert result.meta["received_at"] is not None
@@ -204,35 +204,6 @@ class TestOpenAIGenerator:
         assert len(response["replies"]) == 1
         assert [isinstance(reply, str) for reply in response["replies"]]

-    def test_check_abnormal_completions(self, caplog):
-        caplog.set_level(logging.INFO)
-        component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))
-
-        # underlying implementation uses ChatMessage objects so we have to use them here
-        messages: List[ChatMessage] = []
-        for i, _ in enumerate(range(4)):
-            message = ChatMessage.from_assistant("Hello")
-            metadata = {"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
-            message.meta.update(metadata)
-            messages.append(message)
-
-        for m in messages:
-            component._check_finish_reason(m)
-
-        # check truncation warning
-        message_template = (
-            "The completion for index {index} has been truncated before reaching a natural stopping point. "
-            "Increase the max_tokens parameter to allow for longer completions."
-        )
-
-        for index in [1, 3]:
-            assert caplog.records[index].message == message_template.format(index=index)
-
-        # check content filter warning
-        message_template = "The completion for index {index} has been truncated due to the content filter."
-        for index in [0, 2]:
-            assert caplog.records[index].message == message_template.format(index=index)
-
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
         reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",