mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-05 19:36:55 +00:00
chore: Fix Streaming Callback types (#9441)
* Fix types
* Add select_streaming_callback
This commit is contained in:
parent
c82a3377f2
commit
b8dff93200
@ -3,13 +3,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
|
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
|
||||||
|
|
||||||
from haystack import component, default_from_dict, default_to_dict
|
from haystack import component, default_from_dict, default_to_dict
|
||||||
from haystack.components.generators import OpenAIGenerator
|
from haystack.components.generators import OpenAIGenerator
|
||||||
from haystack.dataclasses import StreamingChunk
|
from haystack.dataclasses import StreamingCallbackT
|
||||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||||
from haystack.utils.http_client import init_http_client
|
from haystack.utils.http_client import init_http_client
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
|||||||
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
|
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
|
||||||
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
|
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
|
||||||
organization: Optional[str] = None,
|
organization: Optional[str] = None,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
max_retries: Optional[int] = None,
|
max_retries: Optional[int] = None,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
|
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
|
||||||
|
|
||||||
from haystack import component, default_from_dict, default_to_dict, logging
|
from haystack import component, default_from_dict, default_to_dict, logging
|
||||||
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
|
from haystack.dataclasses import ChatMessage, StreamingCallbackT, ToolCall, select_streaming_callback
|
||||||
from haystack.lazy_imports import LazyImport
|
from haystack.lazy_imports import LazyImport
|
||||||
from haystack.tools import (
|
from haystack.tools import (
|
||||||
Tool,
|
Tool,
|
||||||
@ -130,7 +130,7 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
stop_words: Optional[List[str]] = None,
|
stop_words: Optional[List[str]] = None,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||||
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
|
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
|
||||||
async_executor: Optional[ThreadPoolExecutor] = None,
|
async_executor: Optional[ThreadPoolExecutor] = None,
|
||||||
@ -330,7 +330,7 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
self,
|
self,
|
||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -492,7 +492,7 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
self,
|
self,
|
||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -546,7 +546,7 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
|
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
|
||||||
generation_kwargs: Dict[str, Any],
|
generation_kwargs: Dict[str, Any],
|
||||||
stop_words: Optional[List[str]],
|
stop_words: Optional[List[str]],
|
||||||
streaming_callback: Callable[[StreamingChunk], None],
|
streaming_callback: StreamingCallbackT,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Handles async streaming generation of responses.
|
Handles async streaming generation of responses.
|
||||||
|
|||||||
@ -4,10 +4,10 @@
|
|||||||
|
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, cast
|
from typing import Any, Dict, Iterable, List, Optional, Union, cast
|
||||||
|
|
||||||
from haystack import component, default_from_dict, default_to_dict
|
from haystack import component, default_from_dict, default_to_dict
|
||||||
from haystack.dataclasses import StreamingChunk
|
from haystack.dataclasses import StreamingCallbackT, StreamingChunk, select_streaming_callback
|
||||||
from haystack.lazy_imports import LazyImport
|
from haystack.lazy_imports import LazyImport
|
||||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||||
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
|
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
|
||||||
@ -80,7 +80,7 @@ class HuggingFaceAPIGenerator:
|
|||||||
token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
|
token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
stop_words: Optional[List[str]] = None,
|
stop_words: Optional[List[str]] = None,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the HuggingFaceAPIGenerator instance.
|
Initialize the HuggingFaceAPIGenerator instance.
|
||||||
@ -180,7 +180,7 @@ class HuggingFaceAPIGenerator:
|
|||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -200,7 +200,9 @@ class HuggingFaceAPIGenerator:
|
|||||||
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
|
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
|
||||||
|
|
||||||
# check if streaming_callback is passed
|
# check if streaming_callback is passed
|
||||||
streaming_callback = streaming_callback or self.streaming_callback
|
streaming_callback = select_streaming_callback(
|
||||||
|
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
|
||||||
|
)
|
||||||
|
|
||||||
hf_output = self._client.text_generation(
|
hf_output = self._client.text_generation(
|
||||||
prompt, details=True, stream=streaming_callback is not None, **generation_kwargs
|
prompt, details=True, stream=streaming_callback is not None, **generation_kwargs
|
||||||
@ -213,7 +215,7 @@ class HuggingFaceAPIGenerator:
|
|||||||
return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))
|
return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))
|
||||||
|
|
||||||
def _stream_and_build_response(
|
def _stream_and_build_response(
|
||||||
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None]
|
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: StreamingCallbackT
|
||||||
):
|
):
|
||||||
chunks: List[StreamingChunk] = []
|
chunks: List[StreamingChunk] = []
|
||||||
first_chunk_time = None
|
first_chunk_time = None
|
||||||
|
|||||||
@ -2,10 +2,10 @@
|
|||||||
#
|
#
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Literal, Optional, cast
|
from typing import Any, Dict, List, Literal, Optional, cast
|
||||||
|
|
||||||
from haystack import component, default_from_dict, default_to_dict, logging
|
from haystack import component, default_from_dict, default_to_dict, logging
|
||||||
from haystack.dataclasses import StreamingChunk
|
from haystack.dataclasses import StreamingCallbackT, select_streaming_callback
|
||||||
from haystack.lazy_imports import LazyImport
|
from haystack.lazy_imports import LazyImport
|
||||||
from haystack.utils import (
|
from haystack.utils import (
|
||||||
ComponentDevice,
|
ComponentDevice,
|
||||||
@ -63,7 +63,7 @@ class HuggingFaceLocalGenerator:
|
|||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
stop_words: Optional[List[str]] = None,
|
stop_words: Optional[List[str]] = None,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates an instance of a HuggingFaceLocalGenerator.
|
Creates an instance of a HuggingFaceLocalGenerator.
|
||||||
@ -210,7 +210,7 @@ class HuggingFaceLocalGenerator:
|
|||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -239,7 +239,9 @@ class HuggingFaceLocalGenerator:
|
|||||||
updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
|
updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
|
||||||
|
|
||||||
# check if streaming_callback is passed
|
# check if streaming_callback is passed
|
||||||
streaming_callback = streaming_callback or self.streaming_callback
|
streaming_callback = select_streaming_callback(
|
||||||
|
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
|
||||||
|
)
|
||||||
|
|
||||||
if streaming_callback:
|
if streaming_callback:
|
||||||
num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
|
num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
|
||||||
|
|||||||
@ -4,13 +4,13 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from openai import OpenAI, Stream
|
from openai import OpenAI, Stream
|
||||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||||
|
|
||||||
from haystack import component, default_from_dict, default_to_dict, logging
|
from haystack import component, default_from_dict, default_to_dict, logging
|
||||||
from haystack.dataclasses import ChatMessage, StreamingChunk
|
from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk, select_streaming_callback
|
||||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||||
from haystack.utils.http_client import init_http_client
|
from haystack.utils.http_client import init_http_client
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ class OpenAIGenerator:
|
|||||||
self,
|
self,
|
||||||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
||||||
model: str = "gpt-4o-mini",
|
model: str = "gpt-4o-mini",
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
api_base_url: Optional[str] = None,
|
api_base_url: Optional[str] = None,
|
||||||
organization: Optional[str] = None,
|
organization: Optional[str] = None,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
@ -178,7 +178,7 @@ class OpenAIGenerator:
|
|||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -211,7 +211,9 @@ class OpenAIGenerator:
|
|||||||
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
|
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
|
||||||
|
|
||||||
# check if streaming_callback is passed
|
# check if streaming_callback is passed
|
||||||
streaming_callback = streaming_callback or self.streaming_callback
|
streaming_callback = select_streaming_callback(
|
||||||
|
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
|
||||||
|
)
|
||||||
|
|
||||||
# adapt ChatMessage(s) to the format expected by the OpenAI API
|
# adapt ChatMessage(s) to the format expected by the OpenAI API
|
||||||
openai_formatted_messages = [message.to_openai_dict_format() for message in messages]
|
openai_formatted_messages = [message.to_openai_dict_format() for message in messages]
|
||||||
|
|||||||
@ -4,10 +4,10 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from haystack import logging
|
from haystack import logging
|
||||||
from haystack.dataclasses import ChatMessage, StreamingChunk
|
from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk
|
||||||
from haystack.lazy_imports import LazyImport
|
from haystack.lazy_imports import LazyImport
|
||||||
from haystack.utils.auth import Secret
|
from haystack.utils.auth import Secret
|
||||||
from haystack.utils.device import ComponentDevice
|
from haystack.utils.device import ComponentDevice
|
||||||
@ -349,7 +349,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transfor
|
|||||||
Streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.
|
Streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.
|
||||||
|
|
||||||
Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling streaming
|
Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling streaming
|
||||||
of generated text via Haystack Callable[StreamingChunk, None] callbacks.
|
of generated text via Haystack StreamingCallbackT callbacks.
|
||||||
|
|
||||||
Do not use this class directly.
|
Do not use this class directly.
|
||||||
"""
|
"""
|
||||||
@ -357,7 +357,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transfor
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||||
stream_handler: Callable[[StreamingChunk], None],
|
stream_handler: StreamingCallbackT,
|
||||||
stop_words: Optional[List[str]] = None,
|
stop_words: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore
|
super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user