From b8dff932008826135f04a221a25de9bd156bc30c Mon Sep 17 00:00:00 2001
From: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
Date: Mon, 26 May 2025 15:39:59 +0200
Subject: [PATCH] chore: Fix Streaming Callback types (#9441)

* Fix types

* Add select_streaming_callback
---
 haystack/components/generators/azure.py            |  6 +++---
 .../generators/chat/hugging_face_local.py          | 10 +++++-----
 haystack/components/generators/hugging_face_api.py | 14 ++++++++------
 .../components/generators/hugging_face_local.py    | 12 +++++++-----
 haystack/components/generators/openai.py           | 12 +++++++-----
 haystack/utils/hf.py                               |  8 ++++----
 6 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/haystack/components/generators/azure.py b/haystack/components/generators/azure.py
index 2926d9c73..358bab256 100644
--- a/haystack/components/generators/azure.py
+++ b/haystack/components/generators/azure.py
@@ -3,13 +3,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Dict, Optional
 
 from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
 
 from haystack import component, default_from_dict, default_to_dict
 from haystack.components.generators import OpenAIGenerator
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import StreamingCallbackT
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.http_client import init_http_client
 
@@ -63,7 +63,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
         api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
         azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
         organization: Optional[str] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         system_prompt: Optional[str] = None,
         timeout: Optional[float] = None,
         max_retries: Optional[int] = None,
diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py
index 85269c512..33d25e155 100644
--- a/haystack/components/generators/chat/hugging_face_local.py
+++ b/haystack/components/generators/chat/hugging_face_local.py
@@ -10,7 +10,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
+from haystack.dataclasses import ChatMessage, StreamingCallbackT, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools import (
     Tool,
@@ -130,7 +130,7 @@ class HuggingFaceLocalChatGenerator:
         generation_kwargs: Optional[Dict[str, Any]] = None,
         huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         tools: Optional[Union[List[Tool], Toolset]] = None,
         tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
         async_executor: Optional[ThreadPoolExecutor] = None,
@@ -330,7 +330,7 @@ class HuggingFaceLocalChatGenerator:
         self,
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
@@ -492,7 +492,7 @@ class HuggingFaceLocalChatGenerator:
         self,
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
@@ -546,7 +546,7 @@ class HuggingFaceLocalChatGenerator:
         tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
         generation_kwargs: Dict[str, Any],
         stop_words: Optional[List[str]],
-        streaming_callback: Callable[[StreamingChunk], None],
+        streaming_callback: StreamingCallbackT,
     ):
         """
         Handles async streaming generation of responses.
diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py
index f55db685d..a02ac6f9a 100644
--- a/haystack/components/generators/hugging_face_api.py
+++ b/haystack/components/generators/hugging_face_api.py
@@ -4,10 +4,10 @@
 
 from dataclasses import asdict
 from datetime import datetime
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union, cast
+from typing import Any, Dict, Iterable, List, Optional, Union, cast
 
 from haystack import component, default_from_dict, default_to_dict
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import StreamingCallbackT, StreamingChunk, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
@@ -80,7 +80,7 @@ class HuggingFaceAPIGenerator:
         token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
         generation_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
     ):
         """
         Initialize the HuggingFaceAPIGenerator instance.
@@ -180,7 +180,7 @@ class HuggingFaceAPIGenerator:
     def run(
         self,
         prompt: str,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -200,7 +200,9 @@ class HuggingFaceAPIGenerator:
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
         # check if streaming_callback is passed
-        streaming_callback = streaming_callback or self.streaming_callback
+        streaming_callback = select_streaming_callback(
+            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
+        )
 
         hf_output = self._client.text_generation(
             prompt, details=True, stream=streaming_callback is not None, **generation_kwargs
@@ -213,7 +215,7 @@ class HuggingFaceAPIGenerator:
         return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))
 
     def _stream_and_build_response(
-        self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None]
+        self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: StreamingCallbackT
     ):
         chunks: List[StreamingChunk] = []
         first_chunk_time = None
diff --git a/haystack/components/generators/hugging_face_local.py b/haystack/components/generators/hugging_face_local.py
index e7c3605c4..36191c8c6 100644
--- a/haystack/components/generators/hugging_face_local.py
+++ b/haystack/components/generators/hugging_face_local.py
@@ -2,10 +2,10 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Callable, Dict, List, Literal, Optional, cast
+from typing import Any, Dict, List, Literal, Optional, cast
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import StreamingCallbackT, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.utils import (
     ComponentDevice,
@@ -63,7 +63,7 @@ class HuggingFaceLocalGenerator:
         generation_kwargs: Optional[Dict[str, Any]] = None,
         huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
     ):
         """
         Creates an instance of a HuggingFaceLocalGenerator.
@@ -210,7 +210,7 @@ class HuggingFaceLocalGenerator:
     def run(
         self,
         prompt: str,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -239,7 +239,9 @@ class HuggingFaceLocalGenerator:
         updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
         # check if streaming_callback is passed
-        streaming_callback = streaming_callback or self.streaming_callback
+        streaming_callback = select_streaming_callback(
+            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
+        )
 
         if streaming_callback:
             num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py
index 72455a1b9..02ba20c8e 100644
--- a/haystack/components/generators/openai.py
+++ b/haystack/components/generators/openai.py
@@ -4,13 +4,13 @@
 
 import os
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from openai import OpenAI, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk, select_streaming_callback
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.http_client import init_http_client
 
@@ -54,7 +54,7 @@ class OpenAIGenerator:
         self,
         api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
         model: str = "gpt-4o-mini",
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         api_base_url: Optional[str] = None,
         organization: Optional[str] = None,
         system_prompt: Optional[str] = None,
@@ -178,7 +178,7 @@ class OpenAIGenerator:
         self,
         prompt: str,
         system_prompt: Optional[str] = None,
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -211,7 +211,9 @@ class OpenAIGenerator:
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
         # check if streaming_callback is passed
-        streaming_callback = streaming_callback or self.streaming_callback
+        streaming_callback = select_streaming_callback(
+            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
+        )
 
         # adapt ChatMessage(s) to the format expected by the OpenAI API
         openai_formatted_messages = [message.to_openai_dict_format() for message in messages]
diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py
index c5b2a612b..946bcaead 100644
--- a/haystack/utils/hf.py
+++ b/haystack/utils/hf.py
@@ -4,10 +4,10 @@
 
 import copy
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from haystack import logging
-from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk
 from haystack.lazy_imports import LazyImport
 from haystack.utils.auth import Secret
 from haystack.utils.device import ComponentDevice
@@ -349,7 +349,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transfor
         Streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.
 
         Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling streaming
-        of generated text via Haystack Callable[StreamingChunk, None] callbacks.
+        of generated text via Haystack StreamingCallbackT callbacks.
 
         Do not use this class directly.
         """
@@ -357,7 +357,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transfor
     def __init__(
         self,
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        stream_handler: Callable[[StreamingChunk], None],
+        stream_handler: StreamingCallbackT,
         stop_words: Optional[List[str]] = None,
     ):
         super().__init__(tokenizer=tokenizer, skip_prompt=True)  # type: ignore