mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-04 02:57:34 +00:00
feat: add component name and type to StreamingChunk (#9426)
* Stream component name in openai * Fix type * PR comments * Update huggingface gen * Typing fix * Update huggingfacelocal gen * Fix errors * Remove model changes * Fix minor errors * Update releasenotes/notes/add-component-info-dataclass-be115dee2fa50abd.yaml Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * PR comments * update annotation * Update hf files * Fix linting * Add a from_component method * use add_component --------- Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
This commit is contained in:
parent
085c3add41
commit
64def6d41b
@ -7,7 +7,7 @@ from datetime import datetime
|
||||
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
|
||||
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk, ToolCall, select_streaming_callback
|
||||
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.tools import (
|
||||
@ -409,6 +409,10 @@ class HuggingFaceAPIChatGenerator:
|
||||
usage = None
|
||||
meta: Dict[str, Any] = {}
|
||||
|
||||
# get the component name and type
|
||||
component_info = ComponentInfo.from_component(self)
|
||||
|
||||
# Set up streaming handler
|
||||
for chunk in api_output:
|
||||
# The chunk with usage returns an empty array for choices
|
||||
if len(chunk.choices) > 0:
|
||||
@ -423,7 +427,7 @@ class HuggingFaceAPIChatGenerator:
|
||||
if choice.finish_reason:
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
stream_chunk = StreamingChunk(text, meta)
|
||||
stream_chunk = StreamingChunk(content=text, meta=meta, component_info=component_info)
|
||||
streaming_callback(stream_chunk)
|
||||
|
||||
if chunk.usage:
|
||||
@ -505,6 +509,9 @@ class HuggingFaceAPIChatGenerator:
|
||||
usage = None
|
||||
meta: Dict[str, Any] = {}
|
||||
|
||||
# get the component name and type
|
||||
component_info = ComponentInfo.from_component(self)
|
||||
|
||||
async for chunk in api_output:
|
||||
# The chunk with usage returns an empty array for choices
|
||||
if len(chunk.choices) > 0:
|
||||
@ -516,10 +523,7 @@ class HuggingFaceAPIChatGenerator:
|
||||
text = choice.delta.content or ""
|
||||
generated_text += text
|
||||
|
||||
if choice.finish_reason:
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
stream_chunk = StreamingChunk(text, meta)
|
||||
stream_chunk = StreamingChunk(content=text, meta=meta, component_info=component_info)
|
||||
await streaming_callback(stream_chunk) # type: ignore
|
||||
|
||||
if chunk.usage:
|
||||
|
||||
@ -10,7 +10,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses import ChatMessage, StreamingCallbackT, ToolCall, select_streaming_callback
|
||||
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall, select_streaming_callback
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.tools import (
|
||||
Tool,
|
||||
@ -384,8 +384,13 @@ class HuggingFaceLocalChatGenerator:
|
||||
)
|
||||
logger.warning(msg, num_responses=num_responses)
|
||||
generation_kwargs["num_return_sequences"] = 1
|
||||
|
||||
# Get component name and type
|
||||
component_info = ComponentInfo.from_component(self)
|
||||
# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
|
||||
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
|
||||
generation_kwargs["streamer"] = HFTokenStreamingHandler(
|
||||
tokenizer, streaming_callback, stop_words, component_info
|
||||
)
|
||||
|
||||
# convert messages to HF format
|
||||
hf_messages = [convert_message_to_hf_format(message) for message in messages]
|
||||
@ -573,8 +578,11 @@ class HuggingFaceLocalChatGenerator:
|
||||
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
# Set up streaming handler
|
||||
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
|
||||
# get the component name and type
|
||||
component_info = ComponentInfo.from_component(self)
|
||||
generation_kwargs["streamer"] = HFTokenStreamingHandler(
|
||||
tokenizer, streaming_callback, stop_words, component_info
|
||||
)
|
||||
|
||||
# Generate responses asynchronously
|
||||
output = await asyncio.get_running_loop().run_in_executor(
|
||||
|
||||
@ -16,6 +16,7 @@ from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses import (
|
||||
AsyncStreamingCallbackT,
|
||||
ChatMessage,
|
||||
ComponentInfo,
|
||||
StreamingCallbackT,
|
||||
StreamingChunk,
|
||||
SyncStreamingCallbackT,
|
||||
@ -570,14 +571,23 @@ class OpenAIChatGenerator:
|
||||
:returns:
|
||||
The StreamingChunk.
|
||||
"""
|
||||
# if there are no choices, return an empty chunk
|
||||
if len(chunk.choices) == 0:
|
||||
return StreamingChunk(content="", meta={"model": chunk.model, "received_at": datetime.now().isoformat()})
|
||||
|
||||
# get the component name and type
|
||||
component_info = ComponentInfo.from_component(self)
|
||||
|
||||
# we stream the content of the chunk if it's not a tool or function call
|
||||
# if there are no choices, return an empty chunk
|
||||
if len(chunk.choices) == 0:
|
||||
return StreamingChunk(
|
||||
content="",
|
||||
meta={"model": chunk.model, "received_at": datetime.now().isoformat()},
|
||||
component_info=component_info,
|
||||
)
|
||||
|
||||
choice: ChunkChoice = chunk.choices[0]
|
||||
content = choice.delta.content or ""
|
||||
chunk_message = StreamingChunk(content)
|
||||
chunk_message = StreamingChunk(content, component_info=component_info)
|
||||
|
||||
# but save the tool calls and function call in the meta if they are present
|
||||
# and then connect the chunks in the _convert_streaming_chunks_to_chat_message method
|
||||
chunk_message.meta.update(
|
||||
|
||||
@ -366,6 +366,7 @@ class PipelineBase:
|
||||
raise PipelineError(msg)
|
||||
|
||||
setattr(instance, "__haystack_added_to_pipeline__", self)
|
||||
setattr(instance, "__component_name__", name)
|
||||
|
||||
# Add component to the graph, disconnected
|
||||
logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance)
|
||||
|
||||
@ -20,6 +20,7 @@ _import_structure = {
|
||||
"StreamingCallbackT",
|
||||
"SyncStreamingCallbackT",
|
||||
"select_streaming_callback",
|
||||
"ComponentInfo",
|
||||
],
|
||||
}
|
||||
|
||||
@ -32,6 +33,7 @@ if TYPE_CHECKING:
|
||||
from .state import State
|
||||
from .streaming_chunk import (
|
||||
AsyncStreamingCallbackT,
|
||||
ComponentInfo,
|
||||
StreamingCallbackT,
|
||||
StreamingChunk,
|
||||
SyncStreamingCallbackT,
|
||||
|
||||
@ -5,22 +5,54 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Union
|
||||
|
||||
from haystack.core.component import Component
|
||||
from haystack.utils.asynchronous import is_callable_async_compatible
|
||||
|
||||
|
||||
@dataclass
class ComponentInfo:
    """
    Describes a component by its type and, when available, its name.

    :param type: The type of the component. `from_component` fills it with `<module>.<class name>`.
    :param name: The name of the component; `None` when the component carries no `__component_name__` attribute.
    """

    type: str
    name: Optional[str] = None

    @classmethod
    def from_component(cls, component: Component) -> "ComponentInfo":
        """
        Build a `ComponentInfo` describing the given `Component` instance.

        :param component:
            The `Component` instance to describe.
        :returns:
            A `ComponentInfo` whose `type` is the module-qualified class name of `component` and whose
            `name` is the value of its `__component_name__` attribute, or `None` if that attribute is missing.
        """
        klass = component.__class__
        return cls(
            type=f"{klass.__module__}.{klass.__name__}",
            name=getattr(component, "__component_name__", None),
        )
|
||||
|
||||
|
||||
@dataclass
class StreamingChunk:
    """
    The `StreamingChunk` class encapsulates a segment of streamed content along with associated metadata.

    This structure facilitates the handling and processing of streamed data in a systematic manner.

    :param content: The content of the message chunk as a string.
    :param meta: A dictionary containing metadata related to the message chunk.
    :param component_info: A `ComponentInfo` object containing information about the component that generated the chunk,
        such as the component name and type.
    """

    content: str
    # Each instance gets its own dict; `meta` is excluded from the generated hash.
    meta: Dict[str, Any] = field(default_factory=dict, hash=False)
    # Optional so chunks can still be built without naming a component.
    component_info: Optional[ComponentInfo] = field(default=None, hash=False)
|
||||
|
||||
|
||||
SyncStreamingCallbackT = Callable[[StreamingChunk], None]
|
||||
|
||||
@ -7,7 +7,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from haystack import logging
|
||||
from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk
|
||||
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, StreamingChunk
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.utils.device import ComponentDevice
|
||||
@ -359,13 +359,15 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transfor
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
stream_handler: StreamingCallbackT,
|
||||
stop_words: Optional[List[str]] = None,
|
||||
component_info: Optional[ComponentInfo] = None,
|
||||
):
|
||||
super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore
|
||||
self.token_handler = stream_handler
|
||||
self.stop_words = stop_words or []
|
||||
self.component_info = component_info
|
||||
|
||||
def on_finalized_text(self, word: str, stream_end: bool = False) -> None:
    """
    Callback function for handling the generated text.

    :param word: The finalized piece of text to forward.
    :param stream_end: When `True`, a newline is appended to `word` before it is forwarded.
    """
    word_to_send = word + "\n" if stream_end else word
    # Text that equals a stop word once surrounding whitespace is stripped is not forwarded.
    if word_to_send.strip() in self.stop_words:
        return
    # Forward exactly one chunk per piece of text, carrying the component info given at construction.
    self.token_handler(StreamingChunk(content=word_to_send, component_info=self.component_info))
|
||||
|
||||
@ -0,0 +1,7 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
- Add a `ComponentInfo` dataclass to the `haystack.dataclasses` module.
|
||||
This dataclass stores information about a component. We pass it to `StreamingChunk` so we can tell which component a stream is coming from.
|
||||
|
||||
- Pass the `component_info` to the `StreamingChunk` in the `OpenAIChatGenerator`, `AzureOpenAIChatGenerator`, `HuggingFaceAPIChatGenerator` and `HuggingFaceLocalChatGenerator`.
|
||||
@ -20,7 +20,7 @@ from openai.types.chat import chat_completion_chunk
|
||||
|
||||
from haystack import component
|
||||
from haystack.components.generators.utils import print_streaming_chunk
|
||||
from haystack.dataclasses import StreamingChunk
|
||||
from haystack.dataclasses import StreamingChunk, ComponentInfo
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.dataclasses import ChatMessage, ToolCall
|
||||
from haystack.tools import ComponentTool, Tool
|
||||
@ -625,6 +625,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.910076",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -644,6 +645,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.913919",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -661,6 +663,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.914439",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -678,6 +681,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.924146",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -695,6 +699,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.924420",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -712,6 +717,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.944398",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -729,6 +735,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.944958",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -763,6 +770,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.946018",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -782,6 +790,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.946578",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -799,6 +808,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.946981",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -816,6 +826,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.947411",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -833,6 +844,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.947643",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -850,6 +862,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": None,
|
||||
"received_at": "2025-02-19T16:02:55.947939",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
StreamingChunk(
|
||||
content="",
|
||||
@ -860,6 +873,7 @@ class TestOpenAIChatGenerator:
|
||||
"finish_reason": "tool_calls",
|
||||
"received_at": "2025-02-19T16:02:55.948772",
|
||||
},
|
||||
component_info=ComponentInfo(name="test", type="test"),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@ -2,9 +2,21 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.dataclasses import StreamingChunk
|
||||
from haystack.dataclasses import StreamingChunk, ComponentInfo
|
||||
from unittest.mock import Mock
|
||||
from haystack.core.component import Component
|
||||
from haystack import component
|
||||
from haystack import Pipeline
|
||||
|
||||
|
||||
@component
class TestComponent:
    """Minimal component used as a fixture by the `ComponentInfo` tests."""

    # The class name matches pytest's default `Test*` pattern; opt out of collection so pytest
    # does not warn about a test class that defines `__init__`.
    __test__ = False

    def __init__(self):
        self.name = "test_component"

    def run(self) -> str:
        """Return a fixed string."""
        return "Test content"
|
||||
|
||||
|
||||
def test_create_chunk_with_content_and_metadata():
|
||||
@ -30,3 +42,27 @@ def test_create_chunk_with_empty_content():
|
||||
chunk = StreamingChunk(content="")
|
||||
assert chunk.content == ""
|
||||
assert chunk.meta == {}
|
||||
|
||||
|
||||
def test_create_chunk_with_all_fields():
    """A chunk built with content, meta and component info exposes all three values."""
    info = ComponentInfo(type="test.component", name="test_component")

    streamed = StreamingChunk(content="Test content", meta={"key": "value"}, component_info=info)

    assert streamed.content == "Test content"
    assert streamed.meta == {"key": "value"}
    assert streamed.component_info == info
|
||||
|
||||
|
||||
def test_component_info_from_component():
    """`from_component` reports the module-qualified class name as the component type."""
    info = ComponentInfo.from_component(TestComponent())
    assert info.type == "test_streaming_chunk.TestComponent"
|
||||
|
||||
|
||||
def test_component_info_from_component_with_name_from_pipeline():
    """After a component is added to a pipeline, `from_component` reports the name it was added under."""
    instance = TestComponent()
    Pipeline().add_component("pipeline_component", instance)

    info = ComponentInfo.from_component(instance)

    assert info.type == "test_streaming_chunk.TestComponent"
    assert info.name == "pipeline_component"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user