chore: Fix Streaming Callback types (#9441)

* Fix types

* Add select_streaming_callback
Sebastian Husch Lee, 2025-05-26 15:39:59 +02:00, committed by GitHub
parent c82a3377f2
commit b8dff93200
6 changed files with 34 additions and 28 deletions

View File

@@ -3,13 +3,13 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Dict, Optional
 from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
 from haystack import component, default_from_dict, default_to_dict
 from haystack.components.generators import OpenAIGenerator
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import StreamingCallbackT
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.http_client import init_http_client
@@ -63,7 +63,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
         api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
         azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
         organization: Optional[str] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         system_prompt: Optional[str] = None,
         timeout: Optional[float] = None,
         max_retries: Optional[int] = None,

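For reference, StreamingCallbackT, the alias this commit substitutes for Callable[[StreamingChunk], None] throughout, is exported from haystack.dataclasses. A rough sketch of its assumed shape (for illustration; not code from this commit):

    from typing import Awaitable, Callable, Union

    from haystack.dataclasses import StreamingChunk

    # A streaming callback receives each StreamingChunk as it is produced
    # and may be either a plain function or a coroutine function.
    SyncStreamingCallbackT = Callable[[StreamingChunk], None]
    AsyncStreamingCallbackT = Callable[[StreamingChunk], Awaitable[None]]
    StreamingCallbackT = Union[SyncStreamingCallbackT, AsyncStreamingCallbackT]

The old annotation was too narrow: it rejected async callbacks that the async code paths already accept at runtime.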
View File

@@ -10,7 +10,7 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
+from haystack.dataclasses import ChatMessage, StreamingCallbackT, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools import (
     Tool,
@@ -130,7 +130,7 @@ class HuggingFaceLocalChatGenerator:
         generation_kwargs: Optional[Dict[str, Any]] = None,
         huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         tools: Optional[Union[List[Tool], Toolset]] = None,
         tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
         async_executor: Optional[ThreadPoolExecutor] = None,
@@ -330,7 +330,7 @@ class HuggingFaceLocalChatGenerator:
         self,
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
@@ -492,7 +492,7 @@ class HuggingFaceLocalChatGenerator:
         self,
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
@@ -546,7 +546,7 @@ class HuggingFaceLocalChatGenerator:
         tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
         generation_kwargs: Dict[str, Any],
         stop_words: Optional[List[str]],
-        streaming_callback: Callable[[StreamingChunk], None],
+        streaming_callback: StreamingCallbackT,
     ):
         """
         Handles async streaming generation of responses.

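Because the union type admits coroutine functions, the async streaming path above no longer has to pretend its callback is a plain Callable[[StreamingChunk], None]. A hypothetical async callback, shown only to illustrate the widened type:

    import asyncio

    from haystack.dataclasses import StreamingChunk

    collected: list = []

    async def collect_chunk(chunk: StreamingChunk) -> None:
        # An async callback may await I/O (a queue, a websocket) between
        # tokens; here we just yield to the event loop and store the text.
        await asyncio.sleep(0)
        collected.append(chunk.content)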
View File

@@ -4,10 +4,10 @@
 from dataclasses import asdict
 from datetime import datetime
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union, cast
+from typing import Any, Dict, Iterable, List, Optional, Union, cast
 from haystack import component, default_from_dict, default_to_dict
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import StreamingCallbackT, StreamingChunk, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
@@ -80,7 +80,7 @@ class HuggingFaceAPIGenerator:
         token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
         generation_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
     ):
         """
         Initialize the HuggingFaceAPIGenerator instance.
@@ -180,7 +180,7 @@ class HuggingFaceAPIGenerator:
     def run(
         self,
         prompt: str,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -200,7 +200,9 @@ class HuggingFaceAPIGenerator:
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
         # check if streaming_callback is passed
-        streaming_callback = streaming_callback or self.streaming_callback
+        streaming_callback = select_streaming_callback(
+            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
+        )
         hf_output = self._client.text_generation(
             prompt, details=True, stream=streaming_callback is not None, **generation_kwargs
@@ -213,7 +215,7 @@ class HuggingFaceAPIGenerator:
         return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))
     def _stream_and_build_response(
-        self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None]
+        self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: StreamingCallbackT
     ):
         chunks: List[StreamingChunk] = []
         first_chunk_time = None

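The select_streaming_callback helper that replaces the plain "streaming_callback or self.streaming_callback" fallback also validates that the chosen callback fits the execution context. A minimal sketch of its assumed behavior (the real implementation in haystack.dataclasses may differ in details such as error messages):

    import inspect

    def select_streaming_callback(init_callback, runtime_callback, requires_async):
        # Sketch only: the callback passed to run() takes precedence over
        # the one set in __init__.
        callback = runtime_callback or init_callback
        if callback is None:
            return None
        is_async = inspect.iscoroutinefunction(callback)
        if requires_async and not is_async:
            raise ValueError("An asynchronous streaming callback is required here.")
        if not requires_async and is_async:
            raise ValueError("An asynchronous streaming callback cannot be used in a synchronous run().")
        return callback

Calling it with requires_async=False, as the synchronous run() methods in this commit do, rejects coroutine callbacks up front instead of failing later inside the streaming loop.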
View File

@@ -2,10 +2,10 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-from typing import Any, Callable, Dict, List, Literal, Optional, cast
+from typing import Any, Dict, List, Literal, Optional, cast
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import StreamingCallbackT, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.utils import (
     ComponentDevice,
@@ -63,7 +63,7 @@ class HuggingFaceLocalGenerator:
         generation_kwargs: Optional[Dict[str, Any]] = None,
         huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
     ):
         """
         Creates an instance of a HuggingFaceLocalGenerator.
@@ -210,7 +210,7 @@ class HuggingFaceLocalGenerator:
     def run(
         self,
         prompt: str,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -239,7 +239,9 @@ class HuggingFaceLocalGenerator:
         updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
         # check if streaming_callback is passed
-        streaming_callback = streaming_callback or self.streaming_callback
+        streaming_callback = select_streaming_callback(
+            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
+        )
         if streaming_callback:
             num_responses = updated_generation_kwargs.get("num_return_sequences", 1)

View File

@@ -4,13 +4,13 @@
 import os
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 from openai import OpenAI, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk, select_streaming_callback
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.http_client import init_http_client
@@ -54,7 +54,7 @@ class OpenAIGenerator:
         self,
         api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
         model: str = "gpt-4o-mini",
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         api_base_url: Optional[str] = None,
         organization: Optional[str] = None,
         system_prompt: Optional[str] = None,
@@ -178,7 +178,7 @@ class OpenAIGenerator:
         self,
         prompt: str,
         system_prompt: Optional[str] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -211,7 +211,9 @@ class OpenAIGenerator:
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
         # check if streaming_callback is passed
-        streaming_callback = streaming_callback or self.streaming_callback
+        streaming_callback = select_streaming_callback(
+            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
+        )
         # adapt ChatMessage(s) to the format expected by the OpenAI API
         openai_formatted_messages = [message.to_openai_dict_format() for message in messages]

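From a caller's perspective, existing synchronous callbacks keep working unchanged; the annotation is simply more permissive. A hypothetical usage snippet (not part of this diff; assumes OPENAI_API_KEY is set in the environment):

    from haystack.components.generators import OpenAIGenerator
    from haystack.dataclasses import StreamingChunk

    def print_chunk(chunk: StreamingChunk) -> None:
        # Print each token as soon as it arrives.
        print(chunk.content, end="", flush=True)

    generator = OpenAIGenerator(streaming_callback=print_chunk)
    result = generator.run(prompt="One sentence on streaming callbacks.")

Because this run() selects the callback with requires_async=False, passing a coroutine function here would raise; the async-capable entry points are where the wider union type pays off.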
View File

@@ -4,10 +4,10 @@
 import copy
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 from haystack import logging
-from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk
 from haystack.lazy_imports import LazyImport
 from haystack.utils.auth import Secret
 from haystack.utils.device import ComponentDevice
@@ -349,7 +349,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transfor
         Streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.
         Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling streaming
-        of generated text via Haystack Callable[StreamingChunk, None] callbacks.
+        of generated text via Haystack StreamingCallbackT callbacks.
         Do not use this class directly.
         """
@@ -357,7 +357,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transfor
        def __init__(
            self,
            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-           stream_handler: Callable[[StreamingChunk], None],
+           stream_handler: StreamingCallbackT,
            stop_words: Optional[List[str]] = None,
        ):
            super().__init__(tokenizer=tokenizer, skip_prompt=True)  # type: ignore