feat: Add finish_reason field to StreamingChunk (#9536)

* Initial commit

* Update deprecation version

* Improve comment

* Minor simplification

* Add reno note

* Remove deprecation warning

* Remove fallback in haystack/components/generators/utils.py

* FinishReason alphabetical import

* Add tool_call_results finish reason, adapt codebase

* Define finish_reason to be Optional[FinishReason]

* Add StreamingChunk finish_reason in HF generators

* Update reno note

* Repair merge issue

* Update tests for finish_reason

* Resolve mypy issues

* Lint issue

* Enhance HF finish_reason translation

* Remove irrelevant test

* PR comments

---------

Co-authored-by: Sebastian Husch Lee <sjrl423@gmail.com>
Vladimir Blagojevic 2025-06-25 11:06:01 +02:00 committed by GitHub
parent 1d1c13a8bc
commit 91094e1038
13 changed files with 192 additions and 10 deletions

View File

@ -18,6 +18,7 @@ from haystack.dataclasses import (
ToolCall,
select_streaming_callback,
)
from haystack.dataclasses.streaming_chunk import FinishReason
from haystack.lazy_imports import LazyImport
from haystack.tools import (
Tool,
@ -41,6 +42,7 @@ with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"
ChatCompletionOutput,
ChatCompletionOutputToolCall,
ChatCompletionStreamOutput,
ChatCompletionStreamOutputChoice,
InferenceClient,
)
@ -110,6 +112,43 @@ def _convert_tools_to_hfapi_tools(
return hf_tools
def _map_hf_finish_reason_to_haystack(choice: "ChatCompletionStreamOutputChoice") -> Optional[FinishReason]:
"""
Map HuggingFace finish reasons to Haystack FinishReason literals.
Uses the full choice object to detect tool calls and provide accurate mapping.
HuggingFace finish reasons (can be found here https://huggingface.github.io/text-generation-inference/ under
FinishReason):
- "length": number of generated tokens == `max_new_tokens`
- "eos_token": the model generated its end of sequence token
- "stop_sequence": the model generated a text included in `stop_sequences`
Additionally detects tool calls from delta.tool_calls or delta.tool_call_id.
:param choice: The HuggingFace ChatCompletionStreamOutputChoice object.
:returns: The corresponding Haystack FinishReason or None.
"""
if choice.finish_reason is None:
return None
# Check if this choice contains tool call information
has_tool_calls = choice.delta.tool_calls is not None or choice.delta.tool_call_id is not None
# If we detect tool calls, override the finish reason
if has_tool_calls:
return "tool_calls"
# Map HuggingFace finish reasons to Haystack standard ones
mapping: Dict[str, FinishReason] = {
"length": "length", # Direct match
"eos_token": "stop", # EOS token means natural stop
"stop_sequence": "stop", # Stop sequence means natural stop
}
return mapping.get(choice.finish_reason, "stop") # Default to "stop" for unknown reasons
def _convert_chat_completion_stream_output_to_streaming_chunk(
chunk: "ChatCompletionStreamOutput",
previous_chunks: List[StreamingChunk],
@ -133,6 +172,7 @@ def _convert_chat_completion_stream_output_to_streaming_chunk(
# the argument is probably allowed for compatibility with OpenAI
# see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
choice = chunk.choices[0]
mapped_finish_reason = _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None
stream_chunk = StreamingChunk(
content=choice.delta.content or "",
meta={"model": chunk.model, "received_at": datetime.now().isoformat(), "finish_reason": choice.finish_reason},
@ -141,6 +181,7 @@ def _convert_chat_completion_stream_output_to_streaming_chunk(
index=0 if choice.finish_reason is None else None,
# start is True at the very beginning since first chunk contains role information + first part of the answer.
start=len(previous_chunks) == 0,
finish_reason=mapped_finish_reason,
)
return stream_chunk
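
For illustration, a minimal standalone sketch of the mapping rules above (assuming the delta carries no tool calls); the dict mirrors the one in _map_hf_finish_reason_to_haystack, and the helper name here is only illustrative:

from typing import Dict, Optional

# Same rules as _map_hf_finish_reason_to_haystack, without the tool-call check.
hf_to_haystack: Dict[str, str] = {"length": "length", "eos_token": "stop", "stop_sequence": "stop"}

def map_reason(hf_reason: Optional[str]) -> Optional[str]:
    if hf_reason is None:
        return None
    return hf_to_haystack.get(hf_reason, "stop")  # unknown reasons fall back to "stop"

assert map_reason("eos_token") == "stop"
assert map_reason("length") == "length"
assert map_reason("some_future_reason") == "stop"
assert map_reason(None) is None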

View File

@ -18,6 +18,7 @@ from haystack.dataclasses import (
AsyncStreamingCallbackT,
ChatMessage,
ComponentInfo,
FinishReason,
StreamingCallbackT,
StreamingChunk,
SyncStreamingCallbackT,
@ -517,8 +518,15 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
generated the chunk, such as the component name and type.
:returns:
A list of StreamingChunk objects representing the content of the chunk from the OpenAI API.
A StreamingChunk object representing the content of the chunk from the OpenAI API.
"""
finish_reason_mapping: Dict[str, FinishReason] = {
"stop": "stop",
"length": "length",
"content_filter": "content_filter",
"tool_calls": "tool_calls",
"function_call": "tool_calls",
}
# On very first chunk so len(previous_chunks) == 0, the Choices field only provides role info (e.g. "assistant")
# Choices is empty if include_usage is set to True where the usage information is returned.
if len(chunk.choices) == 0:
@ -527,6 +535,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
component_info=component_info,
# Index is None since it's only set to an int when a content block is present
index=None,
finish_reason=None,
meta={
"model": chunk.model,
"received_at": datetime.now().isoformat(),
@ -556,6 +565,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
index=tool_calls_deltas[0].index,
tool_calls=tool_calls_deltas,
start=tool_calls_deltas[0].tool_name is not None,
finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
meta={
"model": chunk.model,
"index": choice.index,
@ -584,6 +594,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
# The first chunk is always a start message chunk that only contains role information, so if we reach here
# and previous_chunks is length 1 then this is the start of text content.
start=len(previous_chunks) == 1,
finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
meta={
"model": chunk.model,
"index": choice.index,

View File

@ -9,6 +9,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union, cast
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import (
ComponentInfo,
FinishReason,
StreamingCallbackT,
StreamingChunk,
SyncStreamingCallbackT,
@ -241,8 +242,21 @@ class HuggingFaceAPIGenerator:
if first_chunk_time is None:
first_chunk_time = datetime.now().isoformat()
mapping: Dict[str, FinishReason] = {
"length": "length", # Direct match
"eos_token": "stop", # EOS token means natural stop
"stop_sequence": "stop", # Stop sequence means natural stop
}
mapped_finish_reason = (
mapping.get(chunk_metadata["finish_reason"], "stop") if chunk_metadata.get("finish_reason") else None
)
stream_chunk = StreamingChunk(
content=token.text, meta=chunk_metadata, component_info=component_info, index=0, start=len(chunks) == 0
content=token.text,
meta=chunk_metadata,
component_info=component_info,
index=0,
start=len(chunks) == 0,
finish_reason=mapped_finish_reason,
)
chunks.append(stream_chunk)
streaming_callback(stream_chunk)
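
A minimal sketch of the same translation on the text-generation path: when the token metadata carries no finish_reason, the chunk's finish_reason stays None; otherwise TGI's values are mapped as above. The metadata dicts here are illustrative:

from typing import Dict, Optional

mapping: Dict[str, str] = {"length": "length", "eos_token": "stop", "stop_sequence": "stop"}

def translate(chunk_metadata: dict) -> Optional[str]:
    # None for intermediate chunks, mapped value (defaulting to "stop") for the final one.
    return mapping.get(chunk_metadata["finish_reason"], "stop") if chunk_metadata.get("finish_reason") else None

assert translate({"finish_reason": "stop_sequence"}) == "stop"
assert translate({}) is None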

View File

@ -65,7 +65,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
# End of LLM assistant message so we add two new lines
# This ensures spacing between multiple LLM messages (e.g. Agent) or multiple Tool Call Results
if chunk.meta.get("finish_reason") is not None:
if chunk.finish_reason is not None:
print("\n\n", flush=True, end="")
@ -121,9 +121,7 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
)
# finish_reason can appear in different places so we look for the last one
finish_reasons = [
chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
]
finish_reasons = [chunk.finish_reason for chunk in chunks if chunk.finish_reason]
finish_reason = finish_reasons[-1] if finish_reasons else None
meta = {
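
With the dedicated field in place, downstream helpers can read chunk.finish_reason directly instead of digging into meta. A small self-contained sketch of the "last one wins" selection used by _convert_streaming_chunks_to_chat_message:

from haystack.dataclasses import StreamingChunk

chunks = [
    StreamingChunk(content="Hello ", start=True),
    StreamingChunk(content="world", finish_reason="stop"),
]
finish_reasons = [c.finish_reason for c in chunks if c.finish_reason]
finish_reason = finish_reasons[-1] if finish_reasons else None
assert finish_reason == "stop"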

View File

@ -553,7 +553,11 @@ class ToolInvoker:
# We stream one more chunk that contains a finish_reason if tool_messages were generated
if len(tool_messages) > 0 and streaming_callback is not None:
streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))
streaming_callback(
StreamingChunk(
content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
)
)
return {"tool_messages": tool_messages, "state": state}
@ -685,7 +689,11 @@ class ToolInvoker:
# We stream one more chunk that contains a finish_reason if tool_messages were generated
if len(tool_messages) > 0 and streaming_callback is not None:
await streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))
await streaming_callback(
StreamingChunk(
content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
)
)
return {"tool_messages": tool_messages, "state": state}

View File

@ -17,6 +17,7 @@ _import_structure = {
"streaming_chunk": [
"AsyncStreamingCallbackT",
"ComponentInfo",
"FinishReason",
"StreamingCallbackT",
"StreamingChunk",
"SyncStreamingCallbackT",
@ -40,6 +41,7 @@ if TYPE_CHECKING:
from .state import State as State
from .streaming_chunk import AsyncStreamingCallbackT as AsyncStreamingCallbackT
from .streaming_chunk import ComponentInfo as ComponentInfo
from .streaming_chunk import FinishReason as FinishReason
from .streaming_chunk import StreamingCallbackT as StreamingCallbackT
from .streaming_chunk import StreamingChunk as StreamingChunk
from .streaming_chunk import SyncStreamingCallbackT as SyncStreamingCallbackT

View File

@ -9,6 +9,10 @@ from haystack.core.component import Component
from haystack.dataclasses.chat_message import ToolCallResult
from haystack.utils.asynchronous import is_callable_async_compatible
# Type alias for standard finish_reason values following OpenAI's convention
# plus Haystack-specific value ("tool_call_results")
FinishReason = Literal["stop", "length", "tool_calls", "content_filter", "tool_call_results"]
@dataclass
class ToolCallDelta:
@ -77,6 +81,9 @@ class StreamingChunk:
chunk.
:param tool_call_result: An optional ToolCallResult object representing the result of a tool call.
:param start: A boolean indicating whether this chunk marks the start of a content block.
:param finish_reason: An optional value indicating the reason the generation finished.
Standard values follow OpenAI's convention: "stop", "length", "tool_calls", "content_filter",
plus Haystack-specific value "tool_call_results".
"""
content: str
@ -86,6 +93,7 @@ class StreamingChunk:
tool_calls: Optional[List[ToolCallDelta]] = field(default=None)
tool_call_result: Optional[ToolCallResult] = field(default=None)
start: bool = field(default=False)
finish_reason: Optional[FinishReason] = field(default=None)
def __post_init__(self):
fields_set = sum(bool(x) for x in (self.content, self.tool_calls, self.tool_call_result))
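
Because finish_reason is typed as Optional[FinishReason], callers get literal-type checking from mypy. A small illustrative sketch (the describe helper is hypothetical):

from typing import Optional
from haystack.dataclasses import FinishReason, StreamingChunk

def describe(reason: Optional[FinishReason]) -> str:
    # Hypothetical helper, just to show the literal type in a signature.
    return reason or "still streaming"

last = StreamingChunk(content="done", finish_reason="stop")
assert describe(last.finish_reason) == "stop"
assert describe(StreamingChunk(content="partial").finish_reason) == "still streaming"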

View File

@ -0,0 +1,6 @@
---
features:
- |
Added a dedicated `finish_reason` field to the `StreamingChunk` class to improve type safety and enable more sophisticated streaming UI logic. The field uses a `FinishReason` type alias with the standard values "stop", "length", "tool_calls", and "content_filter", plus the Haystack-specific value "tool_call_results" (used by `ToolInvoker` to signal that tool execution has completed).
- |
Updated `ToolInvoker` component to use the new `finish_reason` field when streaming tool results. The component now sets `finish_reason="tool_call_results"` in the final streaming chunk to indicate that tool execution has completed, while maintaining backward compatibility by also setting the value in `meta["finish_reason"]`.
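
As an example of the streaming UI logic the note mentions, a callback can branch on chunk.finish_reason while a chat generator streams. This sketch assumes OpenAIChatGenerator with a configured OPENAI_API_KEY; the model choice and printed labels are arbitrary:

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

def ui_callback(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)
    if chunk.finish_reason == "tool_calls":
        print("\n[model requested tool calls]")
    elif chunk.finish_reason is not None:
        print(f"\n[finished: {chunk.finish_reason}]")

generator = OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=ui_callback)
generator.run(messages=[ChatMessage.from_user("Say hello in one word.")])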

View File

@ -711,6 +711,7 @@ class TestHuggingFaceAPIChatGenerator:
"model": "microsoft/Phi-3.5-mini-instruct",
"finish_reason": "stop",
},
finish_reason="stop",
),
[0],
),

View File

@ -1143,6 +1143,7 @@ def streaming_chunks():
"received_at": ANY,
"usage": None,
},
finish_reason="tool_calls",
),
StreamingChunk(
content="",
@ -1174,7 +1175,7 @@ class TestChatCompletionChunkConversion:
chunk=openai_chunk, previous_chunks=previous_chunks
)
assert stream_chunk == haystack_chunk
previous_chunks.append(openai_chunk)
previous_chunks.append(stream_chunk)
def test_handle_stream_response(self, chat_completion_chunks):
openai_chunks = chat_completion_chunks

View File

@ -266,6 +266,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
"received_at": "2025-02-19T16:02:55.948772",
},
component_info=ComponentInfo(name="test", type="test"),
finish_reason="tool_calls",
),
StreamingChunk(
content="",

View File

@ -203,6 +203,28 @@ class TestToolInvoker:
assert tool_call_result.origin == tool_call
assert not tool_call_result.error
def test_run_with_streaming_callback_finish_reason(self, invoker):
streaming_chunks = []
def streaming_callback(chunk: StreamingChunk) -> None:
streaming_chunks.append(chunk)
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
message = ChatMessage.from_assistant(tool_calls=[tool_call])
result = invoker.run(messages=[message], streaming_callback=streaming_callback)
assert "tool_messages" in result
assert len(result["tool_messages"]) == 1
# Check that we received streaming chunks
assert len(streaming_chunks) >= 2 # At least one for tool result and one for finish reason
# The last chunk should have finish_reason set to "tool_call_results"
final_chunk = streaming_chunks[-1]
assert final_chunk.finish_reason == "tool_call_results"
assert final_chunk.meta["finish_reason"] == "tool_call_results"
assert final_chunk.content == ""
@pytest.mark.asyncio
async def test_run_async_with_streaming_callback(self, thread_executor, weather_tool):
streaming_callback_called = False
@ -245,6 +267,36 @@ class TestToolInvoker:
# check we called the streaming callback
assert streaming_callback_called
@pytest.mark.asyncio
async def test_run_async_with_streaming_callback_finish_reason(self, thread_executor, weather_tool):
streaming_chunks = []
async def streaming_callback(chunk: StreamingChunk) -> None:
streaming_chunks.append(chunk)
tool_invoker = ToolInvoker(
tools=[weather_tool],
raise_on_failure=True,
convert_result_to_json_string=False,
async_executor=thread_executor,
)
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
message = ChatMessage.from_assistant(tool_calls=[tool_call])
result = await tool_invoker.run_async(messages=[message], streaming_callback=streaming_callback)
assert "tool_messages" in result
assert len(result["tool_messages"]) == 1
# Check that we received streaming chunks
assert len(streaming_chunks) >= 2 # At least one for tool result and one for finish reason
# The last chunk should have finish_reason set to "tool_call_results"
final_chunk = streaming_chunks[-1]
assert final_chunk.finish_reason == "tool_call_results"
assert final_chunk.meta["finish_reason"] == "tool_call_results"
assert final_chunk.content == ""
def test_run_with_toolset(self, tool_set):
tool_invoker = ToolInvoker(tools=tool_set, raise_on_failure=True, convert_result_to_json_string=False)
tool_call = ToolCall(tool_name="addition_tool", arguments={"num1": 5, "num2": 3})

View File

@ -4,7 +4,7 @@
import pytest
from haystack.dataclasses import StreamingChunk, ComponentInfo, ToolCallDelta, ToolCallResult, ToolCall
from haystack.dataclasses import StreamingChunk, ComponentInfo, ToolCallDelta, ToolCallResult, ToolCall, FinishReason
from haystack import component
from haystack import Pipeline
@ -102,3 +102,42 @@ def test_tool_call_delta():
def test_tool_call_delta_with_missing_fields():
with pytest.raises(ValueError):
_ = ToolCallDelta(id="123", index=0)
def test_create_chunk_with_finish_reason():
"""Test creating a chunk with the new finish_reason field."""
chunk = StreamingChunk(content="Test content", finish_reason="stop")
assert chunk.content == "Test content"
assert chunk.finish_reason == "stop"
assert chunk.meta == {}
def test_create_chunk_with_finish_reason_and_meta():
"""Test creating a chunk with both finish_reason field and meta."""
chunk = StreamingChunk(
content="Test content", finish_reason="stop", meta={"model": "gpt-4", "usage": {"tokens": 10}}
)
assert chunk.content == "Test content"
assert chunk.finish_reason == "stop"
assert chunk.meta["model"] == "gpt-4"
assert chunk.meta["usage"]["tokens"] == 10
def test_finish_reason_standard_values():
"""Test all standard finish_reason values including the new Haystack-specific ones."""
standard_values = ["stop", "length", "tool_calls", "content_filter", "tool_call_results"]
for value in standard_values:
chunk = StreamingChunk(content="Test content", finish_reason=value)
assert chunk.finish_reason == value
def test_finish_reason_tool_call_results():
"""Test specifically the new tool_call_results finish reason."""
chunk = StreamingChunk(content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"})
assert chunk.finish_reason == "tool_call_results"
assert chunk.meta["finish_reason"] == "tool_call_results"
assert chunk.content == ""