Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-06-26 22:00:13 +00:00)
feat: Add finish_reason field to StreamingChunk (#9536)
* Initial commit
* Update deprecation version
* Improve comment
* Minor simplification
* Add reno note
* Remove deprecation warning
* Remove fallback in haystack/components/generators/utils.py
* FinishReason alphabetical import
* Add tool_call_results finish reason, adapt codebase
* Define finish_reason to be Optional[FinishReason]
* Add StreamingChunk finish_reason in HF generators
* Update reno note
* Repair merge issue
* Update tests for finish_reason
* Resolve mypy issues
* Lint issue
* Enhance HF finish_reason translation
* Remove irrelevant test
* PR comments

---------

Co-authored-by: Sebastian Husch Lee <sjrl423@gmail.com>
parent 1d1c13a8bc
commit 91094e1038
HuggingFace API chat generator:

@@ -18,6 +18,7 @@ from haystack.dataclasses import (
     ToolCall,
     select_streaming_callback,
 )
+from haystack.dataclasses.streaming_chunk import FinishReason
 from haystack.lazy_imports import LazyImport
 from haystack.tools import (
     Tool,
@@ -41,6 +42,7 @@ with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"
     ChatCompletionOutput,
     ChatCompletionOutputToolCall,
     ChatCompletionStreamOutput,
+    ChatCompletionStreamOutputChoice,
     InferenceClient,
 )
 
@@ -110,6 +112,43 @@ def _convert_tools_to_hfapi_tools(
     return hf_tools
 
 
+def _map_hf_finish_reason_to_haystack(choice: "ChatCompletionStreamOutputChoice") -> Optional[FinishReason]:
+    """
+    Map HuggingFace finish reasons to Haystack FinishReason literals.
+
+    Uses the full choice object to detect tool calls and provide accurate mapping.
+
+    HuggingFace finish reasons (can be found here https://huggingface.github.io/text-generation-inference/ under
+    FinishReason):
+    - "length": number of generated tokens == `max_new_tokens`
+    - "eos_token": the model generated its end of sequence token
+    - "stop_sequence": the model generated a text included in `stop_sequences`
+
+    Additionally detects tool calls from delta.tool_calls or delta.tool_call_id.
+
+    :param choice: The HuggingFace ChatCompletionStreamOutputChoice object.
+    :returns: The corresponding Haystack FinishReason or None.
+    """
+    if choice.finish_reason is None:
+        return None
+
+    # Check if this choice contains tool call information
+    has_tool_calls = choice.delta.tool_calls is not None or choice.delta.tool_call_id is not None
+
+    # If we detect tool calls, override the finish reason
+    if has_tool_calls:
+        return "tool_calls"
+
+    # Map HuggingFace finish reasons to Haystack standard ones
+    mapping: Dict[str, FinishReason] = {
+        "length": "length",  # Direct match
+        "eos_token": "stop",  # EOS token means natural stop
+        "stop_sequence": "stop",  # Stop sequence means natural stop
+    }
+
+    return mapping.get(choice.finish_reason, "stop")  # Default to "stop" for unknown reasons
+
+
 def _convert_chat_completion_stream_output_to_streaming_chunk(
     chunk: "ChatCompletionStreamOutput",
     previous_chunks: List[StreamingChunk],
@@ -133,6 +172,7 @@ def _convert_chat_completion_stream_output_to_streaming_chunk(
     # the argument is probably allowed for compatibility with OpenAI
     # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
     choice = chunk.choices[0]
+    mapped_finish_reason = _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None
     stream_chunk = StreamingChunk(
         content=choice.delta.content or "",
         meta={"model": chunk.model, "received_at": datetime.now().isoformat(), "finish_reason": choice.finish_reason},
@@ -141,6 +181,7 @@ def _convert_chat_completion_stream_output_to_streaming_chunk(
         index=0 if choice.finish_reason is None else None,
         # start is True at the very beginning since first chunk contains role information + first part of the answer.
         start=len(previous_chunks) == 0,
+        finish_reason=mapped_finish_reason,
     )
     return stream_chunk
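To make the translation above concrete, here is a minimal, self-contained sketch of the same mapping behavior. The FakeDelta and FakeChoice classes are hypothetical stand-ins for huggingface_hub's streaming objects and are not part of this diff.

from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-ins for huggingface_hub's streaming delta/choice objects
@dataclass
class FakeDelta:
    tool_calls: Optional[list] = None
    tool_call_id: Optional[str] = None

@dataclass
class FakeChoice:
    delta: FakeDelta
    finish_reason: Optional[str] = None

def map_finish_reason(choice: FakeChoice) -> Optional[str]:
    # Mirrors the logic of _map_hf_finish_reason_to_haystack shown above
    if choice.finish_reason is None:
        return None
    if choice.delta.tool_calls is not None or choice.delta.tool_call_id is not None:
        return "tool_calls"
    return {"length": "length", "eos_token": "stop", "stop_sequence": "stop"}.get(choice.finish_reason, "stop")

assert map_finish_reason(FakeChoice(delta=FakeDelta(), finish_reason="eos_token")) == "stop"
assert map_finish_reason(FakeChoice(delta=FakeDelta(tool_call_id="abc"), finish_reason="stop_sequence")) == "tool_calls"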
OpenAI chat generator:

@@ -18,6 +18,7 @@ from haystack.dataclasses import (
     AsyncStreamingCallbackT,
     ChatMessage,
     ComponentInfo,
+    FinishReason,
     StreamingCallbackT,
     StreamingChunk,
     SyncStreamingCallbackT,
@@ -517,8 +518,15 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
         generated the chunk, such as the component name and type.
 
     :returns:
-        A list of StreamingChunk objects representing the content of the chunk from the OpenAI API.
+        A StreamingChunk object representing the content of the chunk from the OpenAI API.
     """
+    finish_reason_mapping: Dict[str, FinishReason] = {
+        "stop": "stop",
+        "length": "length",
+        "content_filter": "content_filter",
+        "tool_calls": "tool_calls",
+        "function_call": "tool_calls",
+    }
     # On very first chunk so len(previous_chunks) == 0, the Choices field only provides role info (e.g. "assistant")
     # Choices is empty if include_usage is set to True where the usage information is returned.
     if len(chunk.choices) == 0:
@@ -527,6 +535,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
             component_info=component_info,
             # Index is None since it's only set to an int when a content block is present
             index=None,
+            finish_reason=None,
             meta={
                 "model": chunk.model,
                 "received_at": datetime.now().isoformat(),
@@ -556,6 +565,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
             index=tool_calls_deltas[0].index,
             tool_calls=tool_calls_deltas,
             start=tool_calls_deltas[0].tool_name is not None,
+            finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
             meta={
                 "model": chunk.model,
                 "index": choice.index,
@@ -584,6 +594,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
         # The first chunk is always a start message chunk that only contains role information, so if we reach here
         # and previous_chunks is length 1 then this is the start of text content.
         start=len(previous_chunks) == 1,
+        finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
         meta={
             "model": chunk.model,
             "index": choice.index,
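The OpenAI-side translation is a plain dictionary lookup. A small sketch using the same values shown in the diff; the helper name to_haystack_finish_reason is illustrative and not part of the Haystack API.

from typing import Dict, Literal, Optional

FinishReason = Literal["stop", "length", "tool_calls", "content_filter", "tool_call_results"]

FINISH_REASON_MAPPING: Dict[str, FinishReason] = {
    "stop": "stop",
    "length": "length",
    "content_filter": "content_filter",
    "tool_calls": "tool_calls",
    "function_call": "tool_calls",  # legacy OpenAI value folded into "tool_calls"
}

def to_haystack_finish_reason(raw: Optional[str]) -> Optional[FinishReason]:
    # None stays None; known values map through the dict, as in the diff above
    return FINISH_REASON_MAPPING.get(raw) if raw else None

assert to_haystack_finish_reason("function_call") == "tool_calls"
assert to_haystack_finish_reason(None) is None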
HuggingFaceAPIGenerator (text generation):

@@ -9,6 +9,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union, cast
 from haystack import component, default_from_dict, default_to_dict
 from haystack.dataclasses import (
     ComponentInfo,
+    FinishReason,
     StreamingCallbackT,
     StreamingChunk,
     SyncStreamingCallbackT,
@@ -241,8 +242,21 @@ class HuggingFaceAPIGenerator:
             if first_chunk_time is None:
                 first_chunk_time = datetime.now().isoformat()
 
+            mapping: Dict[str, FinishReason] = {
+                "length": "length",  # Direct match
+                "eos_token": "stop",  # EOS token means natural stop
+                "stop_sequence": "stop",  # Stop sequence means natural stop
+            }
+            mapped_finish_reason = (
+                mapping.get(chunk_metadata["finish_reason"], "stop") if chunk_metadata.get("finish_reason") else None
+            )
             stream_chunk = StreamingChunk(
-                content=token.text, meta=chunk_metadata, component_info=component_info, index=0, start=len(chunks) == 0
+                content=token.text,
+                meta=chunk_metadata,
+                component_info=component_info,
+                index=0,
+                start=len(chunks) == 0,
+                finish_reason=mapped_finish_reason,
             )
             chunks.append(stream_chunk)
             streaming_callback(stream_chunk)
Streaming utilities:

@@ -65,7 +65,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
 
     # End of LLM assistant message so we add two new lines
     # This ensures spacing between multiple LLM messages (e.g. Agent) or multiple Tool Call Results
-    if chunk.meta.get("finish_reason") is not None:
+    if chunk.finish_reason is not None:
         print("\n\n", flush=True, end="")
 
 
@@ -121,9 +121,7 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
     )
 
     # finish_reason can appear in different places so we look for the last one
-    finish_reasons = [
-        chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
-    ]
+    finish_reasons = [chunk.finish_reason for chunk in chunks if chunk.finish_reason]
     finish_reason = finish_reasons[-1] if finish_reasons else None
 
     meta = {
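With the typed field, custom streaming callbacks no longer need to dig into chunk.meta. A minimal sketch, assuming a Haystack version that includes this change:

from haystack.dataclasses import StreamingChunk

def my_streaming_callback(chunk: StreamingChunk) -> None:
    # Content arrives incrementally; the typed field replaces chunk.meta.get("finish_reason")
    print(chunk.content, end="", flush=True)
    if chunk.finish_reason is not None:
        print(f"\n[generation finished: {chunk.finish_reason}]")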
ToolInvoker:

@@ -553,7 +553,11 @@ class ToolInvoker:
 
         # We stream one more chunk that contains a finish_reason if tool_messages were generated
         if len(tool_messages) > 0 and streaming_callback is not None:
-            streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))
+            streaming_callback(
+                StreamingChunk(
+                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
+                )
+            )
 
         return {"tool_messages": tool_messages, "state": state}
 
@@ -685,7 +689,11 @@ class ToolInvoker:
 
         # We stream one more chunk that contains a finish_reason if tool_messages were generated
        if len(tool_messages) > 0 and streaming_callback is not None:
-            await streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))
+            await streaming_callback(
+                StreamingChunk(
+                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
+                )
+            )
 
         return {"tool_messages": tool_messages, "state": state}
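Downstream consumers (for example an agent UI) can key off the Haystack-specific value to know when ToolInvoker has finished streaming tool results. A hedged sketch; the callback name and printed messages are illustrative only:

from haystack.dataclasses import StreamingChunk

def agent_ui_callback(chunk: StreamingChunk) -> None:
    # Close a hypothetical "running tools" indicator once ToolInvoker emits
    # its final chunk with finish_reason="tool_call_results"
    if chunk.finish_reason == "tool_call_results":
        print("[all tool calls finished]")
    elif chunk.tool_call_result is not None:
        print(f"[tool result] {chunk.tool_call_result.result}")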
Dataclasses package exports:

@@ -17,6 +17,7 @@ _import_structure = {
     "streaming_chunk": [
         "AsyncStreamingCallbackT",
         "ComponentInfo",
+        "FinishReason",
         "StreamingCallbackT",
         "StreamingChunk",
         "SyncStreamingCallbackT",
@@ -40,6 +41,7 @@ if TYPE_CHECKING:
     from .state import State as State
     from .streaming_chunk import AsyncStreamingCallbackT as AsyncStreamingCallbackT
     from .streaming_chunk import ComponentInfo as ComponentInfo
+    from .streaming_chunk import FinishReason as FinishReason
     from .streaming_chunk import StreamingCallbackT as StreamingCallbackT
     from .streaming_chunk import StreamingChunk as StreamingChunk
     from .streaming_chunk import SyncStreamingCallbackT as SyncStreamingCallbackT
StreamingChunk dataclass (haystack.dataclasses.streaming_chunk):

@@ -9,6 +9,10 @@ from haystack.core.component import Component
 from haystack.dataclasses.chat_message import ToolCallResult
 from haystack.utils.asynchronous import is_callable_async_compatible
 
+# Type alias for standard finish_reason values following OpenAI's convention
+# plus Haystack-specific value ("tool_call_results")
+FinishReason = Literal["stop", "length", "tool_calls", "content_filter", "tool_call_results"]
+
 
 @dataclass
 class ToolCallDelta:
@@ -77,6 +81,9 @@ class StreamingChunk:
         chunk.
     :param tool_call_result: An optional ToolCallResult object representing the result of a tool call.
     :param start: A boolean indicating whether this chunk marks the start of a content block.
+    :param finish_reason: An optional value indicating the reason the generation finished.
+        Standard values follow OpenAI's convention: "stop", "length", "tool_calls", "content_filter",
+        plus Haystack-specific value "tool_call_results".
     """
 
     content: str
@@ -86,6 +93,7 @@ class StreamingChunk:
     tool_calls: Optional[List[ToolCallDelta]] = field(default=None)
     tool_call_result: Optional[ToolCallResult] = field(default=None)
     start: bool = field(default=False)
+    finish_reason: Optional[FinishReason] = field(default=None)
 
     def __post_init__(self):
        fields_set = sum(bool(x) for x in (self.content, self.tool_calls, self.tool_call_result))
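Because FinishReason is a Literal alias, static type checkers can reject values outside the allowed set. A minimal sketch; the mypy behavior described in the comments is an assumption, not shown in the diff:

from haystack.dataclasses import StreamingChunk

chunk = StreamingChunk(content="done", finish_reason="stop")  # valid Literal value
# StreamingChunk(content="done", finish_reason="finished")    # would be flagged by mypy: not a FinishReason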
Release note (new file):

@@ -0,0 +1,6 @@
+---
+features:
+  - |
+    Added dedicated `finish_reason` field to `StreamingChunk` class to improve type safety and enable sophisticated streaming UI logic. The field uses a `FinishReason` type alias with standard values: "stop", "length", "tool_calls", "content_filter", plus Haystack-specific value "tool_call_results" (used by ToolInvoker to indicate tool execution completion).
+  - |
+    Updated `ToolInvoker` component to use the new `finish_reason` field when streaming tool results. The component now sets `finish_reason="tool_call_results"` in the final streaming chunk to indicate that tool execution has completed, while maintaining backward compatibility by also setting the value in `meta["finish_reason"]`.
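A hedged end-to-end sketch of what the release note describes, assuming an OpenAI API key is available in the environment; the model name is illustrative:

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

seen_reasons = []

def callback(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)
    if chunk.finish_reason:
        seen_reasons.append(chunk.finish_reason)

generator = OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=callback)
generator.run(messages=[ChatMessage.from_user("Write one sentence about streaming.")])
print("\nfinal finish_reason:", seen_reasons[-1] if seen_reasons else None)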
Tests: HuggingFaceAPIChatGenerator:

@@ -711,6 +711,7 @@ class TestHuggingFaceAPIChatGenerator:
                     "model": "microsoft/Phi-3.5-mini-instruct",
                     "finish_reason": "stop",
                 },
+                finish_reason="stop",
            ),
            [0],
        ),
Tests: OpenAI chat completion chunk conversion:

@@ -1143,6 +1143,7 @@ def streaming_chunks():
                "received_at": ANY,
                "usage": None,
            },
+            finish_reason="tool_calls",
        ),
        StreamingChunk(
            content="",
@@ -1174,7 +1175,7 @@ class TestChatCompletionChunkConversion:
                chunk=openai_chunk, previous_chunks=previous_chunks
            )
            assert stream_chunk == haystack_chunk
-            previous_chunks.append(openai_chunk)
+            previous_chunks.append(stream_chunk)
 
    def test_handle_stream_response(self, chat_completion_chunks):
        openai_chunks = chat_completion_chunks
Tests: _convert_streaming_chunks_to_chat_message:

@@ -266,6 +266,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
                "received_at": "2025-02-19T16:02:55.948772",
            },
            component_info=ComponentInfo(name="test", type="test"),
+            finish_reason="tool_calls",
        ),
        StreamingChunk(
            content="",
Tests: ToolInvoker:

@@ -203,6 +203,28 @@ class TestToolInvoker:
         assert tool_call_result.origin == tool_call
         assert not tool_call_result.error
 
+    def test_run_with_streaming_callback_finish_reason(self, invoker):
+        streaming_chunks = []
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            streaming_chunks.append(chunk)
+
+        tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
+        message = ChatMessage.from_assistant(tool_calls=[tool_call])
+
+        result = invoker.run(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result
+        assert len(result["tool_messages"]) == 1
+
+        # Check that we received streaming chunks
+        assert len(streaming_chunks) >= 2  # At least one for tool result and one for finish reason
+
+        # The last chunk should have finish_reason set to "tool_call_results"
+        final_chunk = streaming_chunks[-1]
+        assert final_chunk.finish_reason == "tool_call_results"
+        assert final_chunk.meta["finish_reason"] == "tool_call_results"
+        assert final_chunk.content == ""
+
     @pytest.mark.asyncio
     async def test_run_async_with_streaming_callback(self, thread_executor, weather_tool):
         streaming_callback_called = False
@@ -245,6 +267,36 @@ class TestToolInvoker:
         # check we called the streaming callback
         assert streaming_callback_called
 
+    @pytest.mark.asyncio
+    async def test_run_async_with_streaming_callback_finish_reason(self, thread_executor, weather_tool):
+        streaming_chunks = []
+
+        async def streaming_callback(chunk: StreamingChunk) -> None:
+            streaming_chunks.append(chunk)
+
+        tool_invoker = ToolInvoker(
+            tools=[weather_tool],
+            raise_on_failure=True,
+            convert_result_to_json_string=False,
+            async_executor=thread_executor,
+        )
+
+        tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
+        message = ChatMessage.from_assistant(tool_calls=[tool_call])
+
+        result = await tool_invoker.run_async(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result
+        assert len(result["tool_messages"]) == 1
+
+        # Check that we received streaming chunks
+        assert len(streaming_chunks) >= 2  # At least one for tool result and one for finish reason
+
+        # The last chunk should have finish_reason set to "tool_call_results"
+        final_chunk = streaming_chunks[-1]
+        assert final_chunk.finish_reason == "tool_call_results"
+        assert final_chunk.meta["finish_reason"] == "tool_call_results"
+        assert final_chunk.content == ""
+
     def test_run_with_toolset(self, tool_set):
         tool_invoker = ToolInvoker(tools=tool_set, raise_on_failure=True, convert_result_to_json_string=False)
         tool_call = ToolCall(tool_name="addition_tool", arguments={"num1": 5, "num2": 3})
Tests: StreamingChunk dataclass:

@@ -4,7 +4,7 @@
 
 import pytest
 
-from haystack.dataclasses import StreamingChunk, ComponentInfo, ToolCallDelta, ToolCallResult, ToolCall
+from haystack.dataclasses import StreamingChunk, ComponentInfo, ToolCallDelta, ToolCallResult, ToolCall, FinishReason
 from haystack import component
 from haystack import Pipeline
 
@@ -102,3 +102,42 @@ def test_tool_call_delta():
 def test_tool_call_delta_with_missing_fields():
     with pytest.raises(ValueError):
         _ = ToolCallDelta(id="123", index=0)
+
+
+def test_create_chunk_with_finish_reason():
+    """Test creating a chunk with the new finish_reason field."""
+    chunk = StreamingChunk(content="Test content", finish_reason="stop")
+
+    assert chunk.content == "Test content"
+    assert chunk.finish_reason == "stop"
+    assert chunk.meta == {}
+
+
+def test_create_chunk_with_finish_reason_and_meta():
+    """Test creating a chunk with both finish_reason field and meta."""
+    chunk = StreamingChunk(
+        content="Test content", finish_reason="stop", meta={"model": "gpt-4", "usage": {"tokens": 10}}
+    )
+
+    assert chunk.content == "Test content"
+    assert chunk.finish_reason == "stop"
+    assert chunk.meta["model"] == "gpt-4"
+    assert chunk.meta["usage"]["tokens"] == 10
+
+
+def test_finish_reason_standard_values():
+    """Test all standard finish_reason values including the new Haystack-specific ones."""
+    standard_values = ["stop", "length", "tool_calls", "content_filter", "tool_call_results"]
+
+    for value in standard_values:
+        chunk = StreamingChunk(content="Test content", finish_reason=value)
+        assert chunk.finish_reason == value
+
+
+def test_finish_reason_tool_call_results():
+    """Test specifically the new tool_call_results finish reason."""
+    chunk = StreamingChunk(content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"})
+
+    assert chunk.finish_reason == "tool_call_results"
+    assert chunk.meta["finish_reason"] == "tool_call_results"
+    assert chunk.content == ""