diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py
index a9f12e7f6..7a34a5aa3 100644
--- a/haystack/components/generators/chat/hugging_face_api.py
+++ b/haystack/components/generators/chat/hugging_face_api.py
@@ -18,6 +18,7 @@ from haystack.dataclasses import (
     ToolCall,
     select_streaming_callback,
 )
+from haystack.dataclasses.streaming_chunk import FinishReason
 from haystack.lazy_imports import LazyImport
 from haystack.tools import (
     Tool,
@@ -41,6 +42,7 @@ with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\""
     ChatCompletionOutput,
     ChatCompletionOutputToolCall,
     ChatCompletionStreamOutput,
+    ChatCompletionStreamOutputChoice,
     InferenceClient,
 )
 
@@ -110,6 +112,43 @@ def _convert_tools_to_hfapi_tools(
     return hf_tools
 
 
+def _map_hf_finish_reason_to_haystack(choice: "ChatCompletionStreamOutputChoice") -> Optional[FinishReason]:
+    """
+    Map HuggingFace finish reasons to Haystack FinishReason literals.
+
+    Uses the full choice object to detect tool calls and provide an accurate mapping.
+
+    HuggingFace finish reasons (see https://huggingface.github.io/text-generation-inference/ under FinishReason):
+    - "length": number of generated tokens == `max_new_tokens`
+    - "eos_token": the model generated its end of sequence token
+    - "stop_sequence": the model generated a text included in `stop_sequences`
+
+    Additionally detects tool calls from delta.tool_calls or delta.tool_call_id.
+
+    :param choice: The HuggingFace ChatCompletionStreamOutputChoice object.
+    :returns: The corresponding Haystack FinishReason or None.
+    """
+    if choice.finish_reason is None:
+        return None
+
+    # Check if this choice contains tool call information
+    has_tool_calls = choice.delta.tool_calls is not None or choice.delta.tool_call_id is not None
+
+    # If we detect tool calls, override the finish reason
+    if has_tool_calls:
+        return "tool_calls"
+
+    # Map HuggingFace finish reasons to Haystack standard ones
+    mapping: Dict[str, FinishReason] = {
+        "length": "length",  # Direct match
+        "eos_token": "stop",  # EOS token means natural stop
+        "stop_sequence": "stop",  # Stop sequence means natural stop
+    }
+
+    return mapping.get(choice.finish_reason, "stop")  # Default to "stop" for unknown reasons
+
+
 def _convert_chat_completion_stream_output_to_streaming_chunk(
     chunk: "ChatCompletionStreamOutput",
     previous_chunks: List[StreamingChunk],
@@ -133,6 +172,7 @@ def _convert_chat_completion_stream_output_to_streaming_chunk(
     # the argument is probably allowed for compatibility with OpenAI
     # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
     choice = chunk.choices[0]
+    mapped_finish_reason = _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None
     stream_chunk = StreamingChunk(
         content=choice.delta.content or "",
         meta={"model": chunk.model, "received_at": datetime.now().isoformat(), "finish_reason": choice.finish_reason},
@@ -141,6 +181,7 @@
         index=0 if choice.finish_reason is None else None,
         # start is True at the very beginning since first chunk contains role information + first part of the answer.
         start=len(previous_chunks) == 0,
+        finish_reason=mapped_finish_reason,
     )
 
     return stream_chunk
diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py
index dd3d7cee2..b9e0988df 100644
--- a/haystack/components/generators/chat/openai.py
+++ b/haystack/components/generators/chat/openai.py
@@ -18,6 +18,7 @@ from haystack.dataclasses import (
     AsyncStreamingCallbackT,
     ChatMessage,
     ComponentInfo,
+    FinishReason,
     StreamingCallbackT,
     StreamingChunk,
     SyncStreamingCallbackT,
@@ -517,8 +518,15 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
         generated the chunk, such as the component name and type.
     :returns:
-        A list of StreamingChunk objects representing the content of the chunk from the OpenAI API.
+        A StreamingChunk object representing the content of the chunk from the OpenAI API.
     """
+    finish_reason_mapping: Dict[str, FinishReason] = {
+        "stop": "stop",
+        "length": "length",
+        "content_filter": "content_filter",
+        "tool_calls": "tool_calls",
+        "function_call": "tool_calls",
+    }
     # On very first chunk so len(previous_chunks) == 0, the Choices field only provides role info (e.g. "assistant")
     # Choices is empty if include_usage is set to True where the usage information is returned.
     if len(chunk.choices) == 0:
@@ -527,6 +535,7 @@
             component_info=component_info,
             # Index is None since it's only set to an int when a content block is present
             index=None,
+            finish_reason=None,
             meta={
                 "model": chunk.model,
                 "received_at": datetime.now().isoformat(),
@@ -556,6 +565,7 @@
             index=tool_calls_deltas[0].index,
             tool_calls=tool_calls_deltas,
             start=tool_calls_deltas[0].tool_name is not None,
+            finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
             meta={
                 "model": chunk.model,
                 "index": choice.index,
@@ -584,6 +594,7 @@
         # The first chunk is always a start message chunk that only contains role information, so if we reach here
         # and previous_chunks is length 1 then this is the start of text content.
         start=len(previous_chunks) == 1,
+        finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
         meta={
             "model": chunk.model,
             "index": choice.index,
diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py
index 89fb659be..0da42871a 100644
--- a/haystack/components/generators/hugging_face_api.py
+++ b/haystack/components/generators/hugging_face_api.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union, cast
 from haystack import component, default_from_dict, default_to_dict
 from haystack.dataclasses import (
     ComponentInfo,
+    FinishReason,
     StreamingCallbackT,
     StreamingChunk,
     SyncStreamingCallbackT,
@@ -241,8 +242,21 @@ class HuggingFaceAPIGenerator:
             if first_chunk_time is None:
                 first_chunk_time = datetime.now().isoformat()
 
+            mapping: Dict[str, FinishReason] = {
+                "length": "length",  # Direct match
+                "eos_token": "stop",  # EOS token means natural stop
+                "stop_sequence": "stop",  # Stop sequence means natural stop
+            }
+            mapped_finish_reason = (
+                mapping.get(chunk_metadata["finish_reason"], "stop") if chunk_metadata.get("finish_reason") else None
+            )
             stream_chunk = StreamingChunk(
-                content=token.text, meta=chunk_metadata, component_info=component_info, index=0, start=len(chunks) == 0
+                content=token.text,
+                meta=chunk_metadata,
+                component_info=component_info,
+                index=0,
+                start=len(chunks) == 0,
+                finish_reason=mapped_finish_reason,
             )
             chunks.append(stream_chunk)
             streaming_callback(stream_chunk)
diff --git a/haystack/components/generators/utils.py b/haystack/components/generators/utils.py
index d8c5bd678..605d42bf0 100644
--- a/haystack/components/generators/utils.py
+++ b/haystack/components/generators/utils.py
@@ -65,7 +65,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
     # End of LLM assistant message so we add two new lines
     # This ensures spacing between multiple LLM messages (e.g. Agent) or multiple Tool Call Results
-    if chunk.meta.get("finish_reason") is not None:
+    if chunk.finish_reason is not None:
         print("\n\n", flush=True, end="")
@@ -121,9 +121,7 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
     )
 
     # finish_reason can appear in different places so we look for the last one
-    finish_reasons = [
-        chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
-    ]
+    finish_reasons = [chunk.finish_reason for chunk in chunks if chunk.finish_reason]
     finish_reason = finish_reasons[-1] if finish_reasons else None
 
     meta = {
diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py
index 7c7dd9fae..47f2f0eca 100644
--- a/haystack/components/tools/tool_invoker.py
+++ b/haystack/components/tools/tool_invoker.py
@@ -553,7 +553,11 @@ class ToolInvoker:
         # We stream one more chunk that contains a finish_reason if tool_messages were generated
         if len(tool_messages) > 0 and streaming_callback is not None:
-            streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))
+            streaming_callback(
+                StreamingChunk(
+                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
+                )
+            )
 
         return {"tool_messages": tool_messages, "state": state}
@@ -685,7 +689,11 @@ class ToolInvoker:
         # We stream one more chunk that contains a finish_reason if tool_messages were generated
         if len(tool_messages) > 0 and streaming_callback is not None:
-            await streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))
+            await streaming_callback(
+                StreamingChunk(
+                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
+                )
+            )
 
         return {"tool_messages": tool_messages, "state": state}
diff --git a/haystack/dataclasses/__init__.py b/haystack/dataclasses/__init__.py
index 1356df131..91071227f 100644
--- a/haystack/dataclasses/__init__.py
+++ b/haystack/dataclasses/__init__.py
@@ -17,6 +17,7 @@ _import_structure = {
     "streaming_chunk": [
         "AsyncStreamingCallbackT",
         "ComponentInfo",
+        "FinishReason",
         "StreamingCallbackT",
         "StreamingChunk",
         "SyncStreamingCallbackT",
@@ -40,6 +41,7 @@ if TYPE_CHECKING:
     from .state import State as State
     from .streaming_chunk import AsyncStreamingCallbackT as AsyncStreamingCallbackT
     from .streaming_chunk import ComponentInfo as ComponentInfo
+    from .streaming_chunk import FinishReason as FinishReason
     from .streaming_chunk import StreamingCallbackT as StreamingCallbackT
     from .streaming_chunk import StreamingChunk as StreamingChunk
     from .streaming_chunk import SyncStreamingCallbackT as SyncStreamingCallbackT
diff --git a/haystack/dataclasses/streaming_chunk.py b/haystack/dataclasses/streaming_chunk.py
index 41abc6c03..7cd30e466 100644
--- a/haystack/dataclasses/streaming_chunk.py
+++ b/haystack/dataclasses/streaming_chunk.py
@@ -9,6 +9,10 @@ from haystack.core.component import Component
 from haystack.dataclasses.chat_message import ToolCallResult
 from haystack.utils.asynchronous import is_callable_async_compatible
 
+# Type alias for standard finish_reason values following OpenAI's convention,
+# plus the Haystack-specific value "tool_call_results"
+FinishReason = Literal["stop", "length", "tool_calls", "content_filter", "tool_call_results"]
+
 
 @dataclass
 class ToolCallDelta:
@@ -77,6 +81,9 @@ class StreamingChunk:
         chunk.
     :param tool_call_result: An optional ToolCallResult object representing the result of a tool call.
     :param start: A boolean indicating whether this chunk marks the start of a content block.
+    :param finish_reason: An optional value indicating the reason the generation finished.
+        Standard values follow OpenAI's convention: "stop", "length", "tool_calls", "content_filter",
+        plus the Haystack-specific value "tool_call_results".
     """
 
     content: str
@@ -86,6 +93,7 @@
     tool_calls: Optional[List[ToolCallDelta]] = field(default=None)
     tool_call_result: Optional[ToolCallResult] = field(default=None)
     start: bool = field(default=False)
+    finish_reason: Optional[FinishReason] = field(default=None)
 
     def __post_init__(self):
         fields_set = sum(bool(x) for x in (self.content, self.tool_calls, self.tool_call_result))
diff --git a/releasenotes/notes/add-finish-reason-field-streaming-chunk-89828ec09c6e6385.yaml b/releasenotes/notes/add-finish-reason-field-streaming-chunk-89828ec09c6e6385.yaml
new file mode 100644
index 000000000..fdceffe5e
--- /dev/null
+++ b/releasenotes/notes/add-finish-reason-field-streaming-chunk-89828ec09c6e6385.yaml
@@ -0,0 +1,6 @@
+---
+features:
+  - |
+    Added a dedicated `finish_reason` field to the `StreamingChunk` class to improve type safety and enable more sophisticated streaming UI logic. The field uses a `FinishReason` type alias with the standard values "stop", "length", "tool_calls", and "content_filter", plus the Haystack-specific value "tool_call_results" (used by `ToolInvoker` to indicate tool execution completion).
+  - |
+    Updated the `ToolInvoker` component to use the new `finish_reason` field when streaming tool results. The component now sets `finish_reason="tool_call_results"` in the final streaming chunk to indicate that tool execution has completed, while maintaining backward compatibility by also setting the value in `meta["finish_reason"]`.
diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py
index 6698a743e..5a8be6c3c 100644
--- a/test/components/generators/chat/test_hugging_face_api.py
+++ b/test/components/generators/chat/test_hugging_face_api.py
@@ -711,6 +711,7 @@ class TestHuggingFaceAPIChatGenerator:
                     "model": "microsoft/Phi-3.5-mini-instruct",
                     "finish_reason": "stop",
                 },
+                finish_reason="stop",
             ),
             [0],
         ),
diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py
index fd4c227e1..dfbb2a39d 100644
--- a/test/components/generators/chat/test_openai.py
+++ b/test/components/generators/chat/test_openai.py
@@ -1143,6 +1143,7 @@ def streaming_chunks():
                 "received_at": ANY,
                 "usage": None,
             },
+            finish_reason="tool_calls",
         ),
         StreamingChunk(
             content="",
@@ -1174,7 +1175,7 @@ class TestChatCompletionChunkConversion:
                 chunk=openai_chunk, previous_chunks=previous_chunks
             )
             assert stream_chunk == haystack_chunk
-            previous_chunks.append(openai_chunk)
+            previous_chunks.append(stream_chunk)
 
     def test_handle_stream_response(self, chat_completion_chunks):
         openai_chunks = chat_completion_chunks
diff --git a/test/components/generators/test_utils.py b/test/components/generators/test_utils.py
index f4307afd5..1714a4880 100644
--- a/test/components/generators/test_utils.py
+++ b/test/components/generators/test_utils.py
@@ -266,6 +266,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
                 "received_at": "2025-02-19T16:02:55.948772",
             },
             component_info=ComponentInfo(name="test", type="test"),
+            finish_reason="tool_calls",
         ),
         StreamingChunk(
             content="",
diff --git a/test/components/tools/test_tool_invoker.py b/test/components/tools/test_tool_invoker.py
index a21f5ae31..9e14fcac5 100644
--- a/test/components/tools/test_tool_invoker.py
+++ b/test/components/tools/test_tool_invoker.py
@@ -203,6 +203,28 @@ class TestToolInvoker:
         assert tool_call_result.origin == tool_call
         assert not tool_call_result.error
 
+    def test_run_with_streaming_callback_finish_reason(self, invoker):
+        streaming_chunks = []
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            streaming_chunks.append(chunk)
+
+        tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
+        message = ChatMessage.from_assistant(tool_calls=[tool_call])
+
+        result = invoker.run(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result
+        assert len(result["tool_messages"]) == 1
+
+        # Check that we received streaming chunks
+        assert len(streaming_chunks) >= 2  # At least one for tool result and one for finish reason
+
+        # The last chunk should have finish_reason set to "tool_call_results"
+        final_chunk = streaming_chunks[-1]
+        assert final_chunk.finish_reason == "tool_call_results"
+        assert final_chunk.meta["finish_reason"] == "tool_call_results"
+        assert final_chunk.content == ""
+
     @pytest.mark.asyncio
     async def test_run_async_with_streaming_callback(self, thread_executor, weather_tool):
         streaming_callback_called = False
@@ -245,6 +267,36 @@ class TestToolInvoker:
         # check we called the streaming callback
         assert streaming_callback_called
 
+    @pytest.mark.asyncio
+    async def test_run_async_with_streaming_callback_finish_reason(self, thread_executor, weather_tool):
+        streaming_chunks = []
+
+        async def streaming_callback(chunk: StreamingChunk) -> None:
+            streaming_chunks.append(chunk)
+
+        tool_invoker = ToolInvoker(
+            tools=[weather_tool],
+            raise_on_failure=True,
+            convert_result_to_json_string=False,
+            async_executor=thread_executor,
+        )
+
+        tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
+        message = ChatMessage.from_assistant(tool_calls=[tool_call])
+
+        result = await tool_invoker.run_async(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result
+        assert len(result["tool_messages"]) == 1
+
+        # Check that we received streaming chunks
+        assert len(streaming_chunks) >= 2  # At least one for tool result and one for finish reason
+
+        # The last chunk should have finish_reason set to "tool_call_results"
+        final_chunk = streaming_chunks[-1]
+        assert final_chunk.finish_reason == "tool_call_results"
+        assert final_chunk.meta["finish_reason"] == "tool_call_results"
+        assert final_chunk.content == ""
+
     def test_run_with_toolset(self, tool_set):
         tool_invoker = ToolInvoker(tools=tool_set, raise_on_failure=True, convert_result_to_json_string=False)
         tool_call = ToolCall(tool_name="addition_tool", arguments={"num1": 5, "num2": 3})
diff --git a/test/dataclasses/test_streaming_chunk.py b/test/dataclasses/test_streaming_chunk.py
index aa7c13424..1d5363302 100644
--- a/test/dataclasses/test_streaming_chunk.py
+++ b/test/dataclasses/test_streaming_chunk.py
@@ -4,7 +4,7 @@
 import pytest
 
-from haystack.dataclasses import StreamingChunk, ComponentInfo, ToolCallDelta, ToolCallResult, ToolCall
+from haystack.dataclasses import StreamingChunk, ComponentInfo, ToolCallDelta, ToolCallResult, ToolCall, FinishReason
 from haystack import component
 from haystack import Pipeline
@@ -102,3 +102,42 @@ def test_tool_call_delta():
 def test_tool_call_delta_with_missing_fields():
     with pytest.raises(ValueError):
         _ = ToolCallDelta(id="123", index=0)
+
+
+def test_create_chunk_with_finish_reason():
+    """Test creating a chunk with the new finish_reason field."""
+    chunk = StreamingChunk(content="Test content", finish_reason="stop")
+
+    assert chunk.content == "Test content"
+    assert chunk.finish_reason == "stop"
+    assert chunk.meta == {}
+
+
+def test_create_chunk_with_finish_reason_and_meta():
+    """Test creating a chunk with both finish_reason field and meta."""
+    chunk = StreamingChunk(
+        content="Test content", finish_reason="stop", meta={"model": "gpt-4", "usage": {"tokens": 10}}
+    )
+
+    assert chunk.content == "Test content"
+    assert chunk.finish_reason == "stop"
+    assert chunk.meta["model"] == "gpt-4"
+    assert chunk.meta["usage"]["tokens"] == 10
+
+
+def test_finish_reason_standard_values():
+    """Test all standard finish_reason values including the new Haystack-specific one."""
+    standard_values = ["stop", "length", "tool_calls", "content_filter", "tool_call_results"]
+
+    for value in standard_values:
+        chunk = StreamingChunk(content="Test content", finish_reason=value)
+        assert chunk.finish_reason == value
+
+
+def test_finish_reason_tool_call_results():
+    """Test specifically the new tool_call_results finish reason."""
+    chunk = StreamingChunk(content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"})
+
+    assert chunk.finish_reason == "tool_call_results"
+    assert chunk.meta["finish_reason"] == "tool_call_results"
+    assert chunk.content == ""
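
Usage sketch (not part of the patch above): the snippet below illustrates the kind of streaming logic the typed `finish_reason` field is meant to enable. It assumes the values described in the release note; the `on_chunk` callback name and the model choice are illustrative only, and running it requires a valid `OPENAI_API_KEY`.

```python
# Hypothetical example (not part of this diff): a streaming callback that branches on the new
# StreamingChunk.finish_reason field instead of reading chunk.meta.get("finish_reason").
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk


def on_chunk(chunk: StreamingChunk) -> None:
    # Print content deltas as they arrive.
    if chunk.content:
        print(chunk.content, end="", flush=True)
    # The typed field makes end-of-stream handling explicit.
    if chunk.finish_reason == "tool_calls":
        print("\n[model requested tool calls]")
    elif chunk.finish_reason == "tool_call_results":
        print("\n[tool execution finished]")  # emitted by ToolInvoker after running tools
    elif chunk.finish_reason is not None:
        print(f"\n[generation finished: {chunk.finish_reason}]")


generator = OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=on_chunk)
result = generator.run(messages=[ChatMessage.from_user("What is Haystack?")])
```

Because the diff keeps `meta["finish_reason"]` populated alongside the new field, existing callbacks that read the meta dictionary continue to work unchanged.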