From 81c0cefa41d6d16a1811ff67d11b58160d86e90d Mon Sep 17 00:00:00 2001
From: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
Date: Tue, 27 May 2025 15:55:06 +0200
Subject: [PATCH] refactor: Refactor hf api chat generator (#9449)

* Refactor HFAPI Chat Generator

* Add component info to generators

* Fix type hint

* Add reno

* Fix unit tests

* Remove incorrect dev comment

* Move _convert_streaming_chunks_to_chat_message to utils file

---
 .../generators/chat/hugging_face_api.py       | 134 +++-----
 .../generators/chat/hugging_face_local.py     |   5 +-
 haystack/components/generators/chat/openai.py |  64 +---
 .../components/generators/hugging_face_api.py |   5 +-
 .../generators/hugging_face_local.py          |   9 +-
 haystack/components/generators/openai.py      |   2 +-
 haystack/components/generators/utils.py       |  71 +++-
 ...reaming-chunk-hf-api-7eba8fdf6e4fa411.yaml |   5 +
 .../generators/chat/test_hugging_face_api.py  |  76 +++++
 .../components/generators/chat/test_openai.py | 306 +-----------------
 test/components/generators/test_utils.py      | 291 +++++++++++++++++
 11 files changed, 504 insertions(+), 464 deletions(-)
 create mode 100644 releasenotes/notes/update-streaming-chunk-hf-api-7eba8fdf6e4fa411.yaml
 create mode 100644 test/components/generators/test_utils.py

diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py
index 5df577985..28dc674b3 100644
--- a/haystack/components/generators/chat/hugging_face_api.py
+++ b/haystack/components/generators/chat/hugging_face_api.py
@@ -7,6 +7,7 @@ from datetime import datetime
 from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
+from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
 from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.dataclasses.streaming_chunk import StreamingCallbackT
 from haystack.lazy_imports import LazyImport
@@ -101,6 +102,35 @@ def _convert_tools_to_hfapi_tools(
     return hf_tools
 
 
+def _convert_chat_completion_stream_output_to_streaming_chunk(
+    chunk: "ChatCompletionStreamOutput", component_info: Optional[ComponentInfo] = None
+) -> StreamingChunk:
+    """
+    Converts the Hugging Face API ChatCompletionStreamOutput to a StreamingChunk.
+    """
+    # `choices` is empty when `include_usage` is set to True; that final chunk carries the usage information instead.
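+    # Illustrative example (values made up): a final chunk with
+    #   choices=[] and usage=(prompt_tokens=21, completion_tokens=2)
+    # converts to
+    #   StreamingChunk(content="", meta={"model": ..., "received_at": ..., "usage": {"prompt_tokens": 21, "completion_tokens": 2}})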
+ if len(chunk.choices) == 0: + usage = None + if chunk.usage: + usage = {"prompt_tokens": chunk.usage.prompt_tokens, "completion_tokens": chunk.usage.completion_tokens} + return StreamingChunk( + content="", + meta={"model": chunk.model, "received_at": datetime.now().isoformat(), "usage": usage}, + component_info=component_info, + ) + + # n is unused, so the API always returns only one choice + # the argument is probably allowed for compatibility with OpenAI + # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n + choice = chunk.choices[0] + stream_chunk = StreamingChunk( + content=choice.delta.content or "", + meta={"model": chunk.model, "received_at": datetime.now().isoformat(), "finish_reason": choice.finish_reason}, + component_info=component_info, + ) + return stream_chunk + + @component class HuggingFaceAPIChatGenerator: """ @@ -403,55 +433,19 @@ class HuggingFaceAPIChatGenerator: **generation_kwargs, ) - generated_text = "" - first_chunk_time = None - finish_reason = None - usage = None - meta: Dict[str, Any] = {} - - # get the component name and type component_info = ComponentInfo.from_component(self) - - # Set up streaming handler + streaming_chunks = [] for chunk in api_output: - # The chunk with usage returns an empty array for choices - if len(chunk.choices) > 0: - # n is unused, so the API always returns only one choice - # the argument is probably allowed for compatibility with OpenAI - # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n - choice = chunk.choices[0] + streaming_chunk = _convert_chat_completion_stream_output_to_streaming_chunk( + chunk=chunk, component_info=component_info + ) + streaming_chunks.append(streaming_chunk) + streaming_callback(streaming_chunk) - text = choice.delta.content or "" - generated_text += text + message = _convert_streaming_chunks_to_chat_message(chunks=streaming_chunks) + if message.meta.get("usage") is None: + message.meta["usage"] = {"prompt_tokens": 0, "completion_tokens": 0} - if choice.finish_reason: - finish_reason = choice.finish_reason - - stream_chunk = StreamingChunk(content=text, meta=meta, component_info=component_info) - streaming_callback(stream_chunk) - - if chunk.usage: - usage = chunk.usage - - if first_chunk_time is None: - first_chunk_time = datetime.now().isoformat() - - if usage: - usage_dict = {"prompt_tokens": usage.prompt_tokens, "completion_tokens": usage.completion_tokens} - else: - usage_dict = {"prompt_tokens": 0, "completion_tokens": 0} - - meta.update( - { - "model": self._client.model, - "index": 0, - "finish_reason": finish_reason, - "usage": usage_dict, - "completion_start_time": first_chunk_time, - } - ) - - message = ChatMessage.from_assistant(text=generated_text, meta=meta) return {"replies": [message]} def _run_non_streaming( @@ -503,51 +497,19 @@ class HuggingFaceAPIChatGenerator: **generation_kwargs, ) - generated_text = "" - first_chunk_time = None - finish_reason = None - usage = None - meta: Dict[str, Any] = {} - - # get the component name and type component_info = ComponentInfo.from_component(self) - + streaming_chunks = [] async for chunk in api_output: - # The chunk with usage returns an empty array for choices - if len(chunk.choices) > 0: - # n is unused, so the API always returns only one choice - # the argument is probably allowed for compatibility with OpenAI - # see 
https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n - choice = chunk.choices[0] + stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk( + chunk=chunk, component_info=component_info + ) + streaming_chunks.append(stream_chunk) + await streaming_callback(stream_chunk) # type: ignore - text = choice.delta.content or "" - generated_text += text + message = _convert_streaming_chunks_to_chat_message(chunks=streaming_chunks) + if message.meta.get("usage") is None: + message.meta["usage"] = {"prompt_tokens": 0, "completion_tokens": 0} - stream_chunk = StreamingChunk(content=text, meta=meta, component_info=component_info) - await streaming_callback(stream_chunk) # type: ignore - - if chunk.usage: - usage = chunk.usage - - if first_chunk_time is None: - first_chunk_time = datetime.now().isoformat() - - if usage: - usage_dict = {"prompt_tokens": usage.prompt_tokens, "completion_tokens": usage.completion_tokens} - else: - usage_dict = {"prompt_tokens": 0, "completion_tokens": 0} - - meta.update( - { - "model": self._async_client.model, - "index": 0, - "finish_reason": finish_reason, - "usage": usage_dict, - "completion_start_time": first_chunk_time, - } - ) - - message = ChatMessage.from_assistant(text=generated_text, meta=meta) return {"replies": [message]} async def _run_non_streaming_async( diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 71a6989b8..52a990d79 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -389,7 +389,10 @@ class HuggingFaceLocalChatGenerator: component_info = ComponentInfo.from_component(self) # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming generation_kwargs["streamer"] = HFTokenStreamingHandler( - tokenizer, streaming_callback, stop_words, component_info + tokenizer=tokenizer, + stream_handler=streaming_callback, + stop_words=stop_words, + component_info=component_info, ) # convert messages to HF format diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 13bbfbe55..9b8a37484 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -13,6 +13,7 @@ from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice from haystack import component, default_from_dict, default_to_dict, logging +from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message from haystack.dataclasses import ( AsyncStreamingCallbackT, ChatMessage, @@ -455,69 +456,6 @@ def _check_finish_reason(meta: Dict[str, Any]) -> None: ) -def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> ChatMessage: - """ - Connects the streaming chunks into a single ChatMessage. - - :param chunks: The list of all `StreamingChunk` objects. - - :returns: The ChatMessage. 
- """ - text = "".join([chunk.content for chunk in chunks]) - tool_calls = [] - - # Process tool calls if present in any chunk - tool_call_data: Dict[str, Dict[str, str]] = {} # Track tool calls by index - for chunk_payload in chunks: - tool_calls_meta = chunk_payload.meta.get("tool_calls") - if tool_calls_meta is not None: - for delta in tool_calls_meta: - # We use the index of the tool call to track it across chunks since the ID is not always provided - if delta.index not in tool_call_data: - tool_call_data[delta.index] = {"id": "", "name": "", "arguments": ""} - - # Save the ID if present - if delta.id is not None: - tool_call_data[delta.index]["id"] = delta.id - - if delta.function is not None: - if delta.function.name is not None: - tool_call_data[delta.index]["name"] += delta.function.name - if delta.function.arguments is not None: - tool_call_data[delta.index]["arguments"] += delta.function.arguments - - # Convert accumulated tool call data into ToolCall objects - for call_data in tool_call_data.values(): - try: - arguments = json.loads(call_data["arguments"]) - tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments)) - except json.JSONDecodeError: - logger.warning( - "OpenAI returned a malformed JSON string for tool call arguments. This tool call " - "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. " - "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}", - _id=call_data["id"], - _name=call_data["name"], - _arguments=call_data["arguments"], - ) - - # finish_reason can appear in different places so we look for the last one - finish_reasons = [ - chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None - ] - finish_reason = finish_reasons[-1] if finish_reasons else None - - meta = { - "model": chunks[-1].meta.get("model"), - "index": 0, - "finish_reason": finish_reason, - "completion_start_time": chunks[0].meta.get("received_at"), # first chunk received - "usage": chunks[-1].meta.get("usage"), # last chunk has the final usage data if available - } - - return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta) - - def _convert_chat_completion_to_chat_message(completion: ChatCompletion, choice: Choice) -> ChatMessage: """ Converts the non-streaming response from the OpenAI API to a ChatMessage. 
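
The helper deleted above is not gone: it moves verbatim to haystack/components/generators/utils.py (see the hunk for that file below). As a minimal sketch of how the shared helper folds streamed chunks into a final message after this refactor; the imports match this patch, while the chunk contents and meta values are made up for illustration:

    from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
    from haystack.dataclasses import StreamingChunk

    # Two illustrative chunks, as a streaming callback would have received them.
    chunks = [
        StreamingChunk(content="Hello", meta={"model": "m", "received_at": "t0", "finish_reason": None}),
        StreamingChunk(content=" world", meta={"model": "m", "received_at": "t1", "finish_reason": "stop"}),
    ]

    message = _convert_streaming_chunks_to_chat_message(chunks=chunks)
    assert message.text == "Hello world"
    assert message.meta["finish_reason"] == "stop"  # the last finish_reason seen wins
    assert message.meta["completion_start_time"] == "t0"  # taken from the first chunk's received_at
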
diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py index a02ac6f9a..f30d37ce2 100644 --- a/haystack/components/generators/hugging_face_api.py +++ b/haystack/components/generators/hugging_face_api.py @@ -7,7 +7,7 @@ from datetime import datetime from typing import Any, Dict, Iterable, List, Optional, Union, cast from haystack import component, default_from_dict, default_to_dict -from haystack.dataclasses import StreamingCallbackT, StreamingChunk, select_streaming_callback +from haystack.dataclasses import ComponentInfo, StreamingCallbackT, StreamingChunk, select_streaming_callback from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model @@ -220,6 +220,7 @@ class HuggingFaceAPIGenerator: chunks: List[StreamingChunk] = [] first_chunk_time = None + component_info = ComponentInfo.from_component(self) for chunk in hf_output: token: TextGenerationStreamOutputToken = chunk.token if token.special: @@ -229,7 +230,7 @@ class HuggingFaceAPIGenerator: if first_chunk_time is None: first_chunk_time = datetime.now().isoformat() - stream_chunk = StreamingChunk(token.text, chunk_metadata) + stream_chunk = StreamingChunk(content=token.text, meta=chunk_metadata, component_info=component_info) chunks.append(stream_chunk) streaming_callback(stream_chunk) diff --git a/haystack/components/generators/hugging_face_local.py b/haystack/components/generators/hugging_face_local.py index 5a37a9422..dc948b0f3 100644 --- a/haystack/components/generators/hugging_face_local.py +++ b/haystack/components/generators/hugging_face_local.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, cast from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import StreamingCallbackT, select_streaming_callback +from haystack.dataclasses import ComponentInfo, StreamingCallbackT, select_streaming_callback from haystack.lazy_imports import LazyImport from haystack.utils import ( ComponentDevice, @@ -256,9 +256,10 @@ class HuggingFaceLocalGenerator: updated_generation_kwargs["num_return_sequences"] = 1 # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming updated_generation_kwargs["streamer"] = HFTokenStreamingHandler( - self.pipeline.tokenizer, # type: ignore - streaming_callback, - self.stop_words, # type: ignore + tokenizer=self.pipeline.tokenizer, # type: ignore + stream_handler=streaming_callback, + stop_words=self.stop_words, # type: ignore + component_info=ComponentInfo.from_component(self), ) output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs) # type: ignore diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index 5e1c35ad3..e83ce6b0a 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -13,8 +13,8 @@ from haystack.components.generators.chat.openai import ( _check_finish_reason, _convert_chat_completion_chunk_to_streaming_chunk, _convert_chat_completion_to_chat_message, - _convert_streaming_chunks_to_chat_message, ) +from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message from haystack.dataclasses import ( ChatMessage, ComponentInfo, diff --git a/haystack/components/generators/utils.py 
b/haystack/components/generators/utils.py
index 2cca078bb..33fd3cb5b 100644
--- a/haystack/components/generators/utils.py
+++ b/haystack/components/generators/utils.py
@@ -2,11 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict
+import json
+from typing import Any, Dict, List
 
 from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 
-from haystack.dataclasses import StreamingChunk
+from haystack import logging
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+
+logger = logging.getLogger(__name__)
 
 
 def print_streaming_chunk(chunk: StreamingChunk) -> None:
@@ -53,3 +57,66 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
     # This ensures spacing between multiple LLM messages (e.g. Agent)
     if chunk.meta.get("finish_reason") is not None:
         print("\n\n", flush=True, end="")
+
+
+def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> ChatMessage:
+    """
+    Connects the streaming chunks into a single ChatMessage.
+
+    :param chunks: The list of all `StreamingChunk` objects.
+
+    :returns: The ChatMessage.
+    """
+    text = "".join([chunk.content for chunk in chunks])
+    tool_calls = []
+
+    # Process tool calls if present in any chunk
+    tool_call_data: Dict[str, Dict[str, str]] = {}  # Track tool calls by index
+    for chunk_payload in chunks:
+        tool_calls_meta = chunk_payload.meta.get("tool_calls")
+        if tool_calls_meta is not None:
+            for delta in tool_calls_meta:
+                # We use the index of the tool call to track it across chunks since the ID is not always provided
+                if delta.index not in tool_call_data:
+                    tool_call_data[delta.index] = {"id": "", "name": "", "arguments": ""}
+
+                # Save the ID if present
+                if delta.id is not None:
+                    tool_call_data[delta.index]["id"] = delta.id
+
+                if delta.function is not None:
+                    if delta.function.name is not None:
+                        tool_call_data[delta.index]["name"] += delta.function.name
+                    if delta.function.arguments is not None:
+                        tool_call_data[delta.index]["arguments"] += delta.function.arguments
+
+    # Convert accumulated tool call data into ToolCall objects
+    for call_data in tool_call_data.values():
+        try:
+            arguments = json.loads(call_data["arguments"])
+            tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
+        except json.JSONDecodeError:
+            logger.warning(
+                "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
+                "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
+                "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
+                _id=call_data["id"],
+                _name=call_data["name"],
+                _arguments=call_data["arguments"],
+            )
+
+    # finish_reason can appear in different places so we look for the last one
+    finish_reasons = [
+        chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
+    ]
+    finish_reason = finish_reasons[-1] if finish_reasons else None
+
+    meta = {
+        "model": chunks[-1].meta.get("model"),
+        "index": 0,
+        "finish_reason": finish_reason,
+        "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
+        "usage": chunks[-1].meta.get("usage"),  # last chunk has the final usage data if available
+    }
+
+    return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
diff --git a/releasenotes/notes/update-streaming-chunk-hf-api-7eba8fdf6e4fa411.yaml b/releasenotes/notes/update-streaming-chunk-hf-api-7eba8fdf6e4fa411.yaml
new file mode 100644
index 000000000..e17d2dc17
--- /dev/null
+++ b/releasenotes/notes/update-streaming-chunk-hf-api-7eba8fdf6e4fa411.yaml
@@ -0,0 +1,5 @@
+---
+enhancements:
+  - |
+    - Refactors the HuggingFaceAPIChatGenerator to use the util method `_convert_streaming_chunks_to_chat_message`. This keeps the conversion of StreamingChunks into a final ChatMessage consistent across generators.
+    - We also add ComponentInfo to the StreamingChunks created in `HuggingFaceAPIGenerator` and `HuggingFaceLocalGenerator`, so we can tell which component a stream is coming from.
diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py
index 0ed1b3eb0..bd8eda1e1 100644
--- a/test/components/generators/chat/test_hugging_face_api.py
+++ b/test/components/generators/chat/test_hugging_face_api.py
@@ -24,6 +24,7 @@ from huggingface_hub import (
     ChatCompletionStreamOutputChoice,
     ChatCompletionStreamOutputDelta,
     ChatCompletionInputStreamOptions,
+    ChatCompletionStreamOutputUsage,
 )
 from huggingface_hub.errors import RepositoryNotFoundError
 
@@ -31,6 +32,7 @@ from haystack.components.generators.chat.hugging_face_api.py import (
     HuggingFaceAPIChatGenerator,
     _convert_hfapi_tool_calls,
     _convert_tools_to_hfapi_tools,
+    _convert_chat_completion_stream_output_to_streaming_chunk,
 )
 from haystack.tools import Tool
 
@@ -661,6 +663,80 @@ class TestHuggingFaceAPIChatGenerator:
         tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
         assert len(tool_calls) == 0
 
+    @pytest.mark.parametrize(
+        "hf_stream_output, expected_stream_chunk",
+        [
+            (
+                ChatCompletionStreamOutput(
+                    choices=[
+                        ChatCompletionStreamOutputChoice(
+                            delta=ChatCompletionStreamOutputDelta(role="assistant", content=" Paris"), index=0
+                        )
+                    ],
+                    created=1748339326,
+                    id="",
+                    model="microsoft/Phi-3.5-mini-instruct",
+                    system_fingerprint="3.2.1-sha-4d28897",
+                ),
+                StreamingChunk(
+                    content=" Paris",
+                    meta={
+                        "received_at": "2025-05-27T12:14:28.228852",
+                        "model": "microsoft/Phi-3.5-mini-instruct",
+                        "finish_reason": None,
+                    },
+                ),
+            ),
+            (
+                ChatCompletionStreamOutput(
+                    choices=[
+                        ChatCompletionStreamOutputChoice(
+                            delta=ChatCompletionStreamOutputDelta(role="assistant", content=""),
+                            index=0,
+                            finish_reason="stop",
+                        )
+                    ],
+                    created=1748339326,
+                    id="",
+                    model="microsoft/Phi-3.5-mini-instruct",
+                    system_fingerprint="3.2.1-sha-4d28897",
+                ),
+                StreamingChunk(
+                    content="",
+                    meta={
+                        "received_at": "2025-05-27T12:14:28.228852",
+                        "model": "microsoft/Phi-3.5-mini-instruct",
+                        "finish_reason": "stop",
+                    },
+                ),
+            ),
+            (
ChatCompletionStreamOutput( + choices=[], + created=1748339326, + id="", + model="microsoft/Phi-3.5-mini-instruct", + system_fingerprint="3.2.1-sha-4d28897", + usage=ChatCompletionStreamOutputUsage(completion_tokens=2, prompt_tokens=21, total_tokens=23), + ), + StreamingChunk( + content="", + meta={ + "received_at": "2025-05-27T12:14:28.228852", + "model": "microsoft/Phi-3.5-mini-instruct", + "usage": {"completion_tokens": 2, "prompt_tokens": 21}, + }, + ), + ), + ], + ) + def test_convert_chat_completion_stream_output_to_streaming_chunk(self, hf_stream_output, expected_stream_chunk): + converted_stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(chunk=hf_stream_output) + # Remove timestamp from comparison since it's always the current time + converted_stream_chunk.meta.pop("received_at", None) + expected_stream_chunk.meta.pop("received_at", None) + assert converted_stream_chunk == expected_stream_chunk + @pytest.mark.integration @pytest.mark.slow @pytest.mark.skipif( diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 73b869f9b..6639f8674 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -20,14 +20,13 @@ from openai.types.chat import chat_completion_chunk from haystack import component from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import StreamingChunk, ComponentInfo +from haystack.dataclasses import StreamingChunk from haystack.utils.auth import Secret from haystack.dataclasses import ChatMessage, ToolCall from haystack.tools import ComponentTool, Tool from haystack.components.generators.chat.openai import ( OpenAIChatGenerator, _check_finish_reason, - _convert_streaming_chunks_to_chat_message, _convert_chat_completion_chunk_to_streaming_chunk, ) from haystack.tools.toolset import Toolset @@ -598,309 +597,6 @@ class TestOpenAIChatGenerator: assert message.meta["finish_reason"] == "tool_calls" assert message.meta["usage"]["completion_tokens"] == 47 - def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): - chunk = chat_completion_chunk.ChatCompletionChunk( - id="chatcmpl-B2g1XYv1WzALulC5c8uLtJgvEB48I", - choices=[ - chat_completion_chunk.Choice( - delta=chat_completion_chunk.ChoiceDelta( - content=None, function_call=None, refusal=None, role=None, tool_calls=None - ), - finish_reason="tool_calls", - index=0, - logprobs=None, - ) - ], - created=1739977895, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_00428b782a", - usage=None, - ) - chunks = [ - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": None, - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.910076", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id="call_ZOj5l67zhZOx6jqjg7ATQwb6", - function=chat_completion_chunk.ChoiceDeltaToolCallFunction( - arguments="", name="rag_pipeline_tool" - ), - type="function", - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.913919", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - 
chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='{"qu', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.914439", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='ery":', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.924146", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments=' "Wher', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.924420", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="e do", name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.944398", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="es Ma", name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.944958", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="rk liv", name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.945507", - }, - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='e?"}', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.946018", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id="call_STxsYY69wVOvxWqopAt3uWTB", - function=chat_completion_chunk.ChoiceDeltaToolCallFunction( - arguments="", name="get_weather" - ), - type="function", - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.946578", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - 
function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='{"ci', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.946981", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='ty": ', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.947411", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='"Berli', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.947643", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='n"}', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.947939", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": None, - "finish_reason": "tool_calls", - "received_at": "2025-02-19T16:02:55.948772", - }, - component_info=ComponentInfo(name="test", type="test"), - ), - ] - - # Convert chunks to a chat message - result = _convert_streaming_chunks_to_chat_message(chunks=chunks) - - assert not result.texts - assert not result.text - - # Verify both tool calls were found and processed - assert len(result.tool_calls) == 2 - assert result.tool_calls[0].id == "call_ZOj5l67zhZOx6jqjg7ATQwb6" - assert result.tool_calls[0].tool_name == "rag_pipeline_tool" - assert result.tool_calls[0].arguments == {"query": "Where does Mark live?"} - assert result.tool_calls[1].id == "call_STxsYY69wVOvxWqopAt3uWTB" - assert result.tool_calls[1].tool_name == "get_weather" - assert result.tool_calls[1].arguments == {"city": "Berlin"} - - # Verify meta information - assert result.meta["model"] == "gpt-4o-mini-2024-07-18" - assert result.meta["finish_reason"] == "tool_calls" - assert result.meta["index"] == 0 - assert result.meta["completion_start_time"] == "2025-02-19T16:02:55.910076" - def test_convert_usage_chunk_to_streaming_chunk(self): chunk = ChatCompletionChunk( id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw", diff --git a/test/components/generators/test_utils.py b/test/components/generators/test_utils.py new file mode 100644 index 000000000..0208c7702 --- /dev/null +++ b/test/components/generators/test_utils.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from openai.types.chat import chat_completion_chunk + +from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message +from haystack.dataclasses import ComponentInfo, StreamingChunk + + +def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): + chunks = [ + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + 
"index": 0, + "tool_calls": None, + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.910076", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=0, + id="call_ZOj5l67zhZOx6jqjg7ATQwb6", + function=chat_completion_chunk.ChoiceDeltaToolCallFunction( + arguments="", name="rag_pipeline_tool" + ), + type="function", + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.913919", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=0, + id=None, + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='{"qu', name=None), + type=None, + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.914439", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=0, + id=None, + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='ery":', name=None), + type=None, + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.924146", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=0, + id=None, + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments=' "Wher', name=None), + type=None, + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.924420", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=0, + id=None, + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="e do", name=None), + type=None, + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.944398", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=0, + id=None, + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="es Ma", name=None), + type=None, + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.944958", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=0, + id=None, + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="rk liv", name=None), + type=None, + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.945507", + }, + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=0, + id=None, + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='e?"}', name=None), + type=None, + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.946018", + }, + 
component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=1, + id="call_STxsYY69wVOvxWqopAt3uWTB", + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="", name="get_weather"), + type="function", + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.946578", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=1, + id=None, + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='{"ci', name=None), + type=None, + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.946981", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=1, + id=None, + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='ty": ', name=None), + type=None, + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.947411", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=1, + id=None, + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='"Berli', name=None), + type=None, + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.947643", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + chat_completion_chunk.ChoiceDeltaToolCall( + index=1, + id=None, + function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='n"}', name=None), + type=None, + ) + ], + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.947939", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": None, + "finish_reason": "tool_calls", + "received_at": "2025-02-19T16:02:55.948772", + }, + component_info=ComponentInfo(name="test", type="test"), + ), + ] + + # Convert chunks to a chat message + result = _convert_streaming_chunks_to_chat_message(chunks=chunks) + + assert not result.texts + assert not result.text + + # Verify both tool calls were found and processed + assert len(result.tool_calls) == 2 + assert result.tool_calls[0].id == "call_ZOj5l67zhZOx6jqjg7ATQwb6" + assert result.tool_calls[0].tool_name == "rag_pipeline_tool" + assert result.tool_calls[0].arguments == {"query": "Where does Mark live?"} + assert result.tool_calls[1].id == "call_STxsYY69wVOvxWqopAt3uWTB" + assert result.tool_calls[1].tool_name == "get_weather" + assert result.tool_calls[1].arguments == {"city": "Berlin"} + + # Verify meta information + assert result.meta["model"] == "gpt-4o-mini-2024-07-18" + assert result.meta["finish_reason"] == "tool_calls" + assert result.meta["index"] == 0 + assert result.meta["completion_start_time"] == "2025-02-19T16:02:55.910076"
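
As a closing sanity check, the new Hugging Face conversion helper can also be driven directly, outside the test suite. A minimal sketch reusing the same huggingface_hub dataclasses as the parametrized test in this patch; the model name and token counts are illustrative:

    from huggingface_hub import (
        ChatCompletionStreamOutput,
        ChatCompletionStreamOutputChoice,
        ChatCompletionStreamOutputDelta,
        ChatCompletionStreamOutputUsage,
    )

    from haystack.components.generators.chat.hugging_face_api import (
        _convert_chat_completion_stream_output_to_streaming_chunk,
    )

    # A regular content delta becomes a text StreamingChunk.
    content_chunk = ChatCompletionStreamOutput(
        choices=[
            ChatCompletionStreamOutputChoice(
                delta=ChatCompletionStreamOutputDelta(role="assistant", content=" Paris"), index=0
            )
        ],
        created=1748339326,
        id="",
        model="microsoft/Phi-3.5-mini-instruct",
        system_fingerprint="3.2.1-sha-4d28897",
    )
    stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(chunk=content_chunk)
    assert stream_chunk.content == " Paris"
    assert stream_chunk.meta["finish_reason"] is None

    # The usage-only chunk (empty `choices`) yields an empty-content chunk whose
    # meta carries the prompt and completion token counts.
    usage_chunk = ChatCompletionStreamOutput(
        choices=[],
        created=1748339326,
        id="",
        model="microsoft/Phi-3.5-mini-instruct",
        system_fingerprint="3.2.1-sha-4d28897",
        usage=ChatCompletionStreamOutputUsage(completion_tokens=2, prompt_tokens=21, total_tokens=23),
    )
    stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(chunk=usage_chunk)
    assert stream_chunk.content == ""
    assert stream_chunk.meta["usage"] == {"prompt_tokens": 21, "completion_tokens": 2}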