feat: add component name and type to StreamingChunk (#9426)

* Stream component name in openai

* Fix type

* PR comments

* Update huggingface gen

* Typing fix

* Update huggingfacelocal gen

* Fix errors

* Remove model changes

* Fix minor errors

* Update releasenotes/notes/add-component-info-dataclass-be115dee2fa50abd.yaml

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>

* PR comments

* update annotation

* Update hf files

* Fix linting

* Add a from_component method

* use add_component

---------

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
Amna Mubashar 2025-05-27 12:23:40 +02:00 committed by GitHub
parent 085c3add41
commit 64def6d41b
10 changed files with 136 additions and 20 deletions

View File

@@ -7,7 +7,7 @@ from datetime import datetime
 from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
+from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.dataclasses.streaming_chunk import StreamingCallbackT
 from haystack.lazy_imports import LazyImport
 from haystack.tools import (
@@ -409,6 +409,10 @@ class HuggingFaceAPIChatGenerator:
         usage = None
         meta: Dict[str, Any] = {}
+
+        # get the component name and type
+        component_info = ComponentInfo.from_component(self)
+
         # Set up streaming handler
         for chunk in api_output:
             # The chunk with usage returns an empty array for choices
             if len(chunk.choices) > 0:
@@ -423,7 +427,7 @@
                 if choice.finish_reason:
                     finish_reason = choice.finish_reason
-                stream_chunk = StreamingChunk(text, meta)
+                stream_chunk = StreamingChunk(content=text, meta=meta, component_info=component_info)
                 streaming_callback(stream_chunk)
             if chunk.usage:
@@ -505,6 +509,9 @@
         usage = None
         meta: Dict[str, Any] = {}
+
+        # get the component name and type
+        component_info = ComponentInfo.from_component(self)
         async for chunk in api_output:
             # The chunk with usage returns an empty array for choices
             if len(chunk.choices) > 0:
@@ -516,10 +523,7 @@
                 text = choice.delta.content or ""
                 generated_text += text
                 if choice.finish_reason:
                     finish_reason = choice.finish_reason
-                stream_chunk = StreamingChunk(text, meta)
+                stream_chunk = StreamingChunk(content=text, meta=meta, component_info=component_info)
                 await streaming_callback(stream_chunk)  # type: ignore
             if chunk.usage:

View File

@@ -10,7 +10,7 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingCallbackT, ToolCall, select_streaming_callback
+from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools import (
     Tool,
@@ -384,8 +384,13 @@ class HuggingFaceLocalChatGenerator:
             )
             logger.warning(msg, num_responses=num_responses)
             generation_kwargs["num_return_sequences"] = 1
+
+        # Get component name and type
+        component_info = ComponentInfo.from_component(self)
+
         # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
-        generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
+        generation_kwargs["streamer"] = HFTokenStreamingHandler(
+            tokenizer, streaming_callback, stop_words, component_info
+        )
         # convert messages to HF format
         hf_messages = [convert_message_to_hf_format(message) for message in messages]
@@ -573,8 +578,11 @@ class HuggingFaceLocalChatGenerator:
             generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
         )
         # Set up streaming handler
-        generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
+        # get the component name and type
+        component_info = ComponentInfo.from_component(self)
+        generation_kwargs["streamer"] = HFTokenStreamingHandler(
+            tokenizer, streaming_callback, stop_words, component_info
+        )
         # Generate responses asynchronously
         output = await asyncio.get_running_loop().run_in_executor(
View File

@@ -16,6 +16,7 @@ from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import (
     AsyncStreamingCallbackT,
     ChatMessage,
+    ComponentInfo,
     StreamingCallbackT,
     StreamingChunk,
     SyncStreamingCallbackT,
@@ -570,14 +571,23 @@ class OpenAIChatGenerator:
         :returns:
             The StreamingChunk.
         """
-        # if there are no choices, return an empty chunk
-        if len(chunk.choices) == 0:
-            return StreamingChunk(content="", meta={"model": chunk.model, "received_at": datetime.now().isoformat()})
+        # get the component name and type
+        component_info = ComponentInfo.from_component(self)
+
         # we stream the content of the chunk if it's not a tool or function call
+        # if there are no choices, return an empty chunk
+        if len(chunk.choices) == 0:
+            return StreamingChunk(
+                content="",
+                meta={"model": chunk.model, "received_at": datetime.now().isoformat()},
+                component_info=component_info,
+            )
+
         choice: ChunkChoice = chunk.choices[0]
         content = choice.delta.content or ""
-        chunk_message = StreamingChunk(content)
+        chunk_message = StreamingChunk(content, component_info=component_info)
         # but save the tool calls and function call in the meta if they are present
         # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method
         chunk_message.meta.update(

View File

@@ -366,6 +366,7 @@ class PipelineBase:
             raise PipelineError(msg)
         setattr(instance, "__haystack_added_to_pipeline__", self)
+        setattr(instance, "__component_name__", name)
         # Add component to the graph, disconnected
         logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance)

View File

@@ -20,6 +20,7 @@ _import_structure = {
         "StreamingCallbackT",
         "SyncStreamingCallbackT",
         "select_streaming_callback",
+        "ComponentInfo",
     ],
 }
@@ -32,6 +33,7 @@ if TYPE_CHECKING:
     from .state import State
     from .streaming_chunk import (
         AsyncStreamingCallbackT,
+        ComponentInfo,
         StreamingCallbackT,
         StreamingChunk,
         SyncStreamingCallbackT,

View File

@@ -5,22 +5,54 @@
 from dataclasses import dataclass, field
 from typing import Any, Awaitable, Callable, Dict, Optional, Union

+from haystack.core.component import Component
 from haystack.utils.asynchronous import is_callable_async_compatible

+
+@dataclass
+class ComponentInfo:
+    """
+    The `ComponentInfo` class encapsulates information about a component.
+
+    :param type: The type of the component.
+    :param name: The name of the component assigned when adding it to a pipeline.
+    """
+
+    type: str
+    name: Optional[str] = field(default=None)
+
+    @classmethod
+    def from_component(cls, component: Component) -> "ComponentInfo":
+        """
+        Create a `ComponentInfo` object from a `Component` instance.
+
+        :param component:
+            The `Component` instance.
+        :returns:
+            The `ComponentInfo` object with the type and name of the given component.
+        """
+        component_type = f"{component.__class__.__module__}.{component.__class__.__name__}"
+        component_name = getattr(component, "__component_name__", None)
+        return cls(type=component_type, name=component_name)
+
+
 @dataclass
 class StreamingChunk:
     """
-    The StreamingChunk class encapsulates a segment of streamed content along with associated metadata.
+    The `StreamingChunk` class encapsulates a segment of streamed content along with associated metadata.

     This structure facilitates the handling and processing of streamed data in a systematic manner.

     :param content: The content of the message chunk as a string.
     :param meta: A dictionary containing metadata related to the message chunk.
+    :param component_info: A `ComponentInfo` object containing information about the component that generated
+        the chunk, such as the component name and type.
     """

     content: str
     meta: Dict[str, Any] = field(default_factory=dict, hash=False)
+    component_info: Optional[ComponentInfo] = field(default=None, hash=False)


 SyncStreamingCallbackT = Callable[[StreamingChunk], None]
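
To make the new API concrete, here is a minimal usage sketch assuming the diff above is applied; the `EchoGenerator` component and the name "echo" are hypothetical, made up for illustration:

from haystack import Pipeline, component
from haystack.dataclasses import ComponentInfo, StreamingChunk

@component
class EchoGenerator:  # hypothetical toy component
    @component.output_types(text=str)
    def run(self, text: str):
        return {"text": text}

echo = EchoGenerator()
pipeline = Pipeline()
# add_component sets __component_name__ on the instance (see the PipelineBase hunk above),
# which from_component picks up via getattr
pipeline.add_component("echo", echo)

info = ComponentInfo.from_component(echo)
assert info.type.endswith(".EchoGenerator")  # type is "<module>.<class name>"
assert info.name == "echo"

chunk = StreamingChunk(content="Hello", component_info=info)
print(chunk.component_info.name)  # -> echo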

View File

@@ -7,7 +7,7 @@ from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 from haystack import logging
-from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk
+from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, StreamingChunk
 from haystack.lazy_imports import LazyImport
 from haystack.utils.auth import Secret
 from haystack.utils.device import ComponentDevice
@@ -359,13 +359,15 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transfor
             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
             stream_handler: StreamingCallbackT,
             stop_words: Optional[List[str]] = None,
+            component_info: Optional[ComponentInfo] = None,
         ):
             super().__init__(tokenizer=tokenizer, skip_prompt=True)  # type: ignore
             self.token_handler = stream_handler
             self.stop_words = stop_words or []
+            self.component_info = component_info

         def on_finalized_text(self, word: str, stream_end: bool = False) -> None:
             """Callback function for handling the generated text."""
             word_to_send = word + "\n" if stream_end else word
             if word_to_send.strip() not in self.stop_words:
-                self.token_handler(StreamingChunk(content=word_to_send))
+                self.token_handler(StreamingChunk(content=word_to_send, component_info=self.component_info))

View File

@@ -0,0 +1,7 @@
+---
+features:
+  - |
+    - Add a `ComponentInfo` dataclass to the `haystack.dataclasses` module.
+      This dataclass is used to store information about the component. We pass it to `StreamingChunk` so we can tell which component a stream is coming from.
+    - Pass the `component_info` to the `StreamingChunk` in the `OpenAIChatGenerator`, `AzureOpenAIChatGenerator`, `HuggingFaceAPIChatGenerator` and `HuggingFaceLocalChatGenerator`.
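
As a usage illustration to go with this note (not part of the diff): a streaming callback can now attribute chunks to their producer. A minimal sketch, assuming an OpenAI key is configured; the model name is only an example:

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

def labelled_printer(chunk: StreamingChunk) -> None:
    # component_info is Optional, so guard against chunks created without it
    source = "unknown"
    if chunk.component_info is not None:
        source = chunk.component_info.name or chunk.component_info.type
    print(f"[{source}] {chunk.content}", end="", flush=True)

llm = OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=labelled_printer)
llm.run([ChatMessage.from_user("Say hello.")])  # requires OPENAI_API_KEY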

View File

@@ -20,7 +20,7 @@ from openai.types.chat import chat_completion_chunk
 from haystack import component
 from haystack.components.generators.utils import print_streaming_chunk
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import StreamingChunk, ComponentInfo
 from haystack.utils.auth import Secret
 from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.tools import ComponentTool, Tool
@@ -625,6 +625,7 @@ class TestOpenAIChatGenerator:
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.910076",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -644,6 +645,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.913919",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -661,6 +663,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.914439",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -678,6 +681,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.924146",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -695,6 +699,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.924420",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -712,6 +717,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.944398",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -729,6 +735,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.944958",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -763,6 +770,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.946018",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -782,6 +790,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.946578",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -799,6 +808,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.946981",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -816,6 +826,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.947411",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -833,6 +844,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.947643",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -850,6 +862,7 @@
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.947939",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -860,6 +873,7 @@
                     "finish_reason": "tool_calls",
                     "received_at": "2025-02-19T16:02:55.948772",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
         ]

View File

@@ -2,9 +2,21 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 import pytest

-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import StreamingChunk, ComponentInfo
+from unittest.mock import Mock
+from haystack.core.component import Component
+from haystack import component
+from haystack import Pipeline
+
+
+@component
+class TestComponent:
+    def __init__(self):
+        self.name = "test_component"
+
+    def run(self) -> str:
+        return "Test content"


 def test_create_chunk_with_content_and_metadata():
@@ -30,3 +42,27 @@ def test_create_chunk_with_empty_content():
     chunk = StreamingChunk(content="")

     assert chunk.content == ""
     assert chunk.meta == {}
+
+
+def test_create_chunk_with_all_fields():
+    component_info = ComponentInfo(type="test.component", name="test_component")
+    chunk = StreamingChunk(content="Test content", meta={"key": "value"}, component_info=component_info)
+
+    assert chunk.content == "Test content"
+    assert chunk.meta == {"key": "value"}
+    assert chunk.component_info == component_info
+
+
+def test_component_info_from_component():
+    component = TestComponent()
+    component_info = ComponentInfo.from_component(component)
+    assert component_info.type == "test_streaming_chunk.TestComponent"
+
+
+def test_component_info_from_component_with_name_from_pipeline():
+    pipeline = Pipeline()
+    component = TestComponent()
+    pipeline.add_component("pipeline_component", component)
+    component_info = ComponentInfo.from_component(component)
+    assert component_info.type == "test_streaming_chunk.TestComponent"
+    assert component_info.name == "pipeline_component"
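
One behavioural detail the last two tests pin down: `name` is only populated once the component has been added to a pipeline, because `__component_name__` is set by `add_component`. A short sketch reusing the `TestComponent` defined above:

component = TestComponent()
assert ComponentInfo.from_component(component).name is None  # not in a pipeline yet

pipeline = Pipeline()
pipeline.add_component("pipeline_component", component)
assert ComponentInfo.from_component(component).name == "pipeline_component"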