mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-04 19:06:44 +00:00
chore: Fix Streaming Callback types (#9441)
* Fix types * Add select_streaming_callback
This commit is contained in:
parent
c82a3377f2
commit
b8dff93200
@ -3,13 +3,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.components.generators import OpenAIGenerator
|
||||
from haystack.dataclasses import StreamingChunk
|
||||
from haystack.dataclasses import StreamingCallbackT
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
from haystack.utils.http_client import init_http_client
|
||||
|
||||
@ -63,7 +63,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
|
||||
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
|
||||
organization: Optional[str] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
|
||||
@ -10,7 +10,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
|
||||
from haystack.dataclasses import ChatMessage, StreamingCallbackT, ToolCall, select_streaming_callback
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.tools import (
|
||||
Tool,
|
||||
@ -130,7 +130,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
||||
stop_words: Optional[List[str]] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
|
||||
async_executor: Optional[ThreadPoolExecutor] = None,
|
||||
@ -330,7 +330,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
):
|
||||
"""
|
||||
@ -492,7 +492,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
):
|
||||
"""
|
||||
@ -546,7 +546,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
|
||||
generation_kwargs: Dict[str, Any],
|
||||
stop_words: Optional[List[str]],
|
||||
streaming_callback: Callable[[StreamingChunk], None],
|
||||
streaming_callback: StreamingCallbackT,
|
||||
):
|
||||
"""
|
||||
Handles async streaming generation of responses.
|
||||
|
||||
@ -4,10 +4,10 @@
|
||||
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, cast
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union, cast
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.dataclasses import StreamingChunk
|
||||
from haystack.dataclasses import StreamingCallbackT, StreamingChunk, select_streaming_callback
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
|
||||
@ -80,7 +80,7 @@ class HuggingFaceAPIGenerator:
|
||||
token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
stop_words: Optional[List[str]] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the HuggingFaceAPIGenerator instance.
|
||||
@ -180,7 +180,7 @@ class HuggingFaceAPIGenerator:
|
||||
def run(
|
||||
self,
|
||||
prompt: str,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
@ -200,7 +200,9 @@ class HuggingFaceAPIGenerator:
|
||||
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
|
||||
|
||||
# check if streaming_callback is passed
|
||||
streaming_callback = streaming_callback or self.streaming_callback
|
||||
streaming_callback = select_streaming_callback(
|
||||
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
|
||||
)
|
||||
|
||||
hf_output = self._client.text_generation(
|
||||
prompt, details=True, stream=streaming_callback is not None, **generation_kwargs
|
||||
@ -213,7 +215,7 @@ class HuggingFaceAPIGenerator:
|
||||
return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))
|
||||
|
||||
def _stream_and_build_response(
|
||||
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None]
|
||||
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: StreamingCallbackT
|
||||
):
|
||||
chunks: List[StreamingChunk] = []
|
||||
first_chunk_time = None
|
||||
|
||||
@ -2,10 +2,10 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, cast
|
||||
from typing import Any, Dict, List, Literal, Optional, cast
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses import StreamingChunk
|
||||
from haystack.dataclasses import StreamingCallbackT, select_streaming_callback
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import (
|
||||
ComponentDevice,
|
||||
@ -63,7 +63,7 @@ class HuggingFaceLocalGenerator:
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
||||
stop_words: Optional[List[str]] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
):
|
||||
"""
|
||||
Creates an instance of a HuggingFaceLocalGenerator.
|
||||
@ -210,7 +210,7 @@ class HuggingFaceLocalGenerator:
|
||||
def run(
|
||||
self,
|
||||
prompt: str,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
@ -239,7 +239,9 @@ class HuggingFaceLocalGenerator:
|
||||
updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
|
||||
|
||||
# check if streaming_callback is passed
|
||||
streaming_callback = streaming_callback or self.streaming_callback
|
||||
streaming_callback = select_streaming_callback(
|
||||
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
|
||||
)
|
||||
|
||||
if streaming_callback:
|
||||
num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
|
||||
|
||||
@ -4,13 +4,13 @@
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from openai import OpenAI, Stream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses import ChatMessage, StreamingChunk
|
||||
from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk, select_streaming_callback
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
from haystack.utils.http_client import init_http_client
|
||||
|
||||
@ -54,7 +54,7 @@ class OpenAIGenerator:
|
||||
self,
|
||||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
||||
model: str = "gpt-4o-mini",
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
@ -178,7 +178,7 @@ class OpenAIGenerator:
|
||||
self,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
@ -211,7 +211,9 @@ class OpenAIGenerator:
|
||||
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
|
||||
|
||||
# check if streaming_callback is passed
|
||||
streaming_callback = streaming_callback or self.streaming_callback
|
||||
streaming_callback = select_streaming_callback(
|
||||
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
|
||||
)
|
||||
|
||||
# adapt ChatMessage(s) to the format expected by the OpenAI API
|
||||
openai_formatted_messages = [message.to_openai_dict_format() for message in messages]
|
||||
|
||||
@ -4,10 +4,10 @@
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from haystack import logging
|
||||
from haystack.dataclasses import ChatMessage, StreamingChunk
|
||||
from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.utils.device import ComponentDevice
|
||||
@ -349,7 +349,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transfor
|
||||
Streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.
|
||||
|
||||
Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling streaming
|
||||
of generated text via Haystack Callable[StreamingChunk, None] callbacks.
|
||||
of generated text via Haystack StreamingCallbackT callbacks.
|
||||
|
||||
Do not use this class directly.
|
||||
"""
|
||||
@ -357,7 +357,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transfor
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
stream_handler: Callable[[StreamingChunk], None],
|
||||
stream_handler: StreamingCallbackT,
|
||||
stop_words: Optional[List[str]] = None,
|
||||
):
|
||||
super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user