mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-16 18:13:54 +00:00
feat: Add Toolset support in ChatGenerator(s) (#9177)
* Add Toolset support in ChatGenerator(s) * Add reno note * Update azure test * Updates * Minor fix * Add more tests * Remove some integration tests * PR feedback * rm unused fixture --------- Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
This commit is contained in:
parent
63781afd8f
commit
e1e797206d
@ -11,7 +11,9 @@ from haystack import component, default_from_dict, default_to_dict
|
|||||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||||
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
|
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
|
||||||
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||||
|
from haystack.tools.toolset import Toolset
|
||||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||||
|
from haystack.utils.misc import serialize_tools_or_toolset
|
||||||
|
|
||||||
|
|
||||||
@component
|
@component
|
||||||
@ -73,7 +75,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
|||||||
max_retries: Optional[int] = None,
|
max_retries: Optional[int] = None,
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
default_headers: Optional[Dict[str, str]] = None,
|
default_headers: Optional[Dict[str, str]] = None,
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||||
tools_strict: bool = False,
|
tools_strict: bool = False,
|
||||||
*,
|
*,
|
||||||
azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
|
azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
|
||||||
@ -115,7 +117,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
|||||||
values are the bias to add to that token.
|
values are the bias to add to that token.
|
||||||
:param default_headers: Default headers to use for the AzureOpenAI client.
|
:param default_headers: Default headers to use for the AzureOpenAI client.
|
||||||
:param tools:
|
:param tools:
|
||||||
A list of tools for which the model can prepare calls.
|
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
|
||||||
|
list of `Tool` objects or a `Toolset` instance.
|
||||||
:param tools_strict:
|
:param tools_strict:
|
||||||
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
|
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
|
||||||
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
|
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
|
||||||
@ -152,7 +155,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
|||||||
self.default_headers = default_headers or {}
|
self.default_headers = default_headers or {}
|
||||||
self.azure_ad_token_provider = azure_ad_token_provider
|
self.azure_ad_token_provider = azure_ad_token_provider
|
||||||
|
|
||||||
_check_duplicate_tool_names(tools)
|
_check_duplicate_tool_names(list(tools or []))
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
self.tools_strict = tools_strict
|
self.tools_strict = tools_strict
|
||||||
|
|
||||||
@ -196,7 +199,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
|||||||
api_key=self.api_key.to_dict() if self.api_key is not None else None,
|
api_key=self.api_key.to_dict() if self.api_key is not None else None,
|
||||||
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
|
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
|
||||||
default_headers=self.default_headers,
|
default_headers=self.default_headers,
|
||||||
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
|
tools=serialize_tools_or_toolset(self.tools),
|
||||||
tools_strict=self.tools_strict,
|
tools_strict=self.tools_strict,
|
||||||
azure_ad_token_provider=azure_ad_token_provider_name,
|
azure_ad_token_provider=azure_ad_token_provider_name,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -10,8 +10,10 @@ from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_s
|
|||||||
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
|
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
|
||||||
from haystack.lazy_imports import LazyImport
|
from haystack.lazy_imports import LazyImport
|
||||||
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||||
|
from haystack.tools.toolset import Toolset
|
||||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||||
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
|
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
|
||||||
|
from haystack.utils.misc import serialize_tools_or_toolset
|
||||||
from haystack.utils.url_validation import is_valid_http_url
|
from haystack.utils.url_validation import is_valid_http_url
|
||||||
|
|
||||||
with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
|
with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
|
||||||
@ -103,7 +105,7 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
stop_words: Optional[List[str]] = None,
|
stop_words: Optional[List[str]] = None,
|
||||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the HuggingFaceAPIChatGenerator instance.
|
Initialize the HuggingFaceAPIChatGenerator instance.
|
||||||
@ -130,10 +132,10 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
:param streaming_callback:
|
:param streaming_callback:
|
||||||
An optional callable for handling streaming responses.
|
An optional callable for handling streaming responses.
|
||||||
:param tools:
|
:param tools:
|
||||||
A list of tools for which the model can prepare calls.
|
A list of tools or a Toolset for which the model can prepare calls.
|
||||||
The chosen model should support tool/function calling, according to the model card.
|
The chosen model should support tool/function calling, according to the model card.
|
||||||
Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
|
Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
|
||||||
unexpected behavior.
|
unexpected behavior. This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
huggingface_hub_import.check()
|
huggingface_hub_import.check()
|
||||||
@ -166,7 +168,7 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
|
|
||||||
if tools and streaming_callback is not None:
|
if tools and streaming_callback is not None:
|
||||||
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
||||||
_check_duplicate_tool_names(tools)
|
_check_duplicate_tool_names(list(tools or []))
|
||||||
|
|
||||||
# handle generation kwargs setup
|
# handle generation kwargs setup
|
||||||
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
|
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
|
||||||
@ -191,7 +193,6 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
A dictionary containing the serialized component.
|
A dictionary containing the serialized component.
|
||||||
"""
|
"""
|
||||||
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
|
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
|
||||||
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
|
|
||||||
return default_to_dict(
|
return default_to_dict(
|
||||||
self,
|
self,
|
||||||
api_type=str(self.api_type),
|
api_type=str(self.api_type),
|
||||||
@ -199,7 +200,7 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
token=self.token.to_dict() if self.token else None,
|
token=self.token.to_dict() if self.token else None,
|
||||||
generation_kwargs=self.generation_kwargs,
|
generation_kwargs=self.generation_kwargs,
|
||||||
streaming_callback=callback_name,
|
streaming_callback=callback_name,
|
||||||
tools=serialized_tools,
|
tools=serialize_tools_or_toolset(self.tools),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -220,7 +221,7 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
self,
|
self,
|
||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -231,8 +232,9 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
:param generation_kwargs:
|
:param generation_kwargs:
|
||||||
Additional keyword arguments for text generation.
|
Additional keyword arguments for text generation.
|
||||||
:param tools:
|
:param tools:
|
||||||
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
|
A list of tools or a Toolset for which the model can prepare calls. If set, it will override
|
||||||
during component initialization.
|
the `tools` parameter set during component initialization. This parameter can accept either a
|
||||||
|
list of `Tool` objects or a `Toolset` instance.
|
||||||
:param streaming_callback:
|
:param streaming_callback:
|
||||||
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
|
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
|
||||||
parameter set during component initialization.
|
parameter set during component initialization.
|
||||||
@ -248,7 +250,7 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
tools = tools or self.tools
|
tools = tools or self.tools
|
||||||
if tools and self.streaming_callback:
|
if tools and self.streaming_callback:
|
||||||
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
||||||
_check_duplicate_tool_names(tools)
|
_check_duplicate_tool_names(list(tools or []))
|
||||||
|
|
||||||
# validate and select the streaming callback
|
# validate and select the streaming callback
|
||||||
streaming_callback = select_streaming_callback(
|
streaming_callback = select_streaming_callback(
|
||||||
@ -260,6 +262,8 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
|
|
||||||
hf_tools = None
|
hf_tools = None
|
||||||
if tools:
|
if tools:
|
||||||
|
if isinstance(tools, Toolset):
|
||||||
|
tools = list(tools)
|
||||||
hf_tools = [
|
hf_tools = [
|
||||||
ChatCompletionInputTool(
|
ChatCompletionInputTool(
|
||||||
function=ChatCompletionInputFunctionDefinition(
|
function=ChatCompletionInputFunctionDefinition(
|
||||||
@ -276,7 +280,7 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
self,
|
self,
|
||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -290,8 +294,9 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
:param generation_kwargs:
|
:param generation_kwargs:
|
||||||
Additional keyword arguments for text generation.
|
Additional keyword arguments for text generation.
|
||||||
:param tools:
|
:param tools:
|
||||||
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
|
A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
|
||||||
during component initialization.
|
parameter set during component initialization. This parameter can accept either a list of `Tool` objects
|
||||||
|
or a `Toolset` instance.
|
||||||
:param streaming_callback:
|
:param streaming_callback:
|
||||||
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
|
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
|
||||||
parameter set during component initialization.
|
parameter set during component initialization.
|
||||||
@ -307,7 +312,7 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
tools = tools or self.tools
|
tools = tools or self.tools
|
||||||
if tools and self.streaming_callback:
|
if tools and self.streaming_callback:
|
||||||
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
||||||
_check_duplicate_tool_names(tools)
|
_check_duplicate_tool_names(list(tools or []))
|
||||||
|
|
||||||
# validate and select the streaming callback
|
# validate and select the streaming callback
|
||||||
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
|
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
|
||||||
@ -317,6 +322,8 @@ class HuggingFaceAPIChatGenerator:
|
|||||||
|
|
||||||
hf_tools = None
|
hf_tools = None
|
||||||
if tools:
|
if tools:
|
||||||
|
if isinstance(tools, Toolset):
|
||||||
|
tools = list(tools)
|
||||||
hf_tools = [
|
hf_tools = [
|
||||||
ChatCompletionInputTool(
|
ChatCompletionInputTool(
|
||||||
function=ChatCompletionInputFunctionDefinition(
|
function=ChatCompletionInputFunctionDefinition(
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from haystack import component, default_from_dict, default_to_dict, logging
|
|||||||
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
|
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
|
||||||
from haystack.lazy_imports import LazyImport
|
from haystack.lazy_imports import LazyImport
|
||||||
from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||||
|
from haystack.tools.toolset import Toolset
|
||||||
from haystack.utils import (
|
from haystack.utils import (
|
||||||
ComponentDevice,
|
ComponentDevice,
|
||||||
Secret,
|
Secret,
|
||||||
@ -20,6 +21,7 @@ from haystack.utils import (
|
|||||||
deserialize_secrets_inplace,
|
deserialize_secrets_inplace,
|
||||||
serialize_callable,
|
serialize_callable,
|
||||||
)
|
)
|
||||||
|
from haystack.utils.misc import serialize_tools_or_toolset
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -123,7 +125,7 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
stop_words: Optional[List[str]] = None,
|
stop_words: Optional[List[str]] = None,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||||
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
|
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
|
||||||
async_executor: Optional[ThreadPoolExecutor] = None,
|
async_executor: Optional[ThreadPoolExecutor] = None,
|
||||||
):
|
):
|
||||||
@ -164,7 +166,8 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
For some chat models, the output includes both the new text and the original prompt.
|
For some chat models, the output includes both the new text and the original prompt.
|
||||||
In these cases, make sure your prompt has no stop words.
|
In these cases, make sure your prompt has no stop words.
|
||||||
:param streaming_callback: An optional callable for handling streaming responses.
|
:param streaming_callback: An optional callable for handling streaming responses.
|
||||||
:param tools: A list of tools for which the model can prepare calls.
|
:param tools: A list of tools or a Toolset for which the model can prepare calls.
|
||||||
|
This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
|
||||||
:param tool_parsing_function:
|
:param tool_parsing_function:
|
||||||
A callable that takes a string and returns a list of ToolCall objects or None.
|
A callable that takes a string and returns a list of ToolCall objects or None.
|
||||||
If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
|
If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
|
||||||
@ -176,7 +179,7 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
|
|
||||||
if tools and streaming_callback is not None:
|
if tools and streaming_callback is not None:
|
||||||
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
||||||
_check_duplicate_tool_names(tools)
|
_check_duplicate_tool_names(list(tools or []))
|
||||||
|
|
||||||
huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
|
huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
|
||||||
generation_kwargs = generation_kwargs or {}
|
generation_kwargs = generation_kwargs or {}
|
||||||
@ -273,7 +276,6 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
Dictionary with serialized data.
|
Dictionary with serialized data.
|
||||||
"""
|
"""
|
||||||
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
|
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
|
||||||
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
|
|
||||||
serialization_dict = default_to_dict(
|
serialization_dict = default_to_dict(
|
||||||
self,
|
self,
|
||||||
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
|
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
|
||||||
@ -281,7 +283,7 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
streaming_callback=callback_name,
|
streaming_callback=callback_name,
|
||||||
token=self.token.to_dict() if self.token else None,
|
token=self.token.to_dict() if self.token else None,
|
||||||
chat_template=self.chat_template,
|
chat_template=self.chat_template,
|
||||||
tools=serialized_tools,
|
tools=serialize_tools_or_toolset(self.tools),
|
||||||
tool_parsing_function=serialize_callable(self.tool_parsing_function),
|
tool_parsing_function=serialize_callable(self.tool_parsing_function),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -323,7 +325,7 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Invoke text generation inference based on the provided messages and generation parameters.
|
Invoke text generation inference based on the provided messages and generation parameters.
|
||||||
@ -332,8 +334,9 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
:param generation_kwargs: Additional keyword arguments for text generation.
|
:param generation_kwargs: Additional keyword arguments for text generation.
|
||||||
:param streaming_callback: An optional callable for handling streaming responses.
|
:param streaming_callback: An optional callable for handling streaming responses.
|
||||||
:param tools:
|
:param tools:
|
||||||
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter
|
A list of tools or a Toolset for which the model can prepare calls. If set, it will override
|
||||||
provided during initialization.
|
the `tools` parameter provided during initialization. This parameter can accept either a list
|
||||||
|
of `Tool` objects or a `Toolset` instance.
|
||||||
:returns:
|
:returns:
|
||||||
A list containing the generated responses as ChatMessage instances.
|
A list containing the generated responses as ChatMessage instances.
|
||||||
"""
|
"""
|
||||||
@ -377,6 +380,10 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
|
|
||||||
# convert messages to HF format
|
# convert messages to HF format
|
||||||
hf_messages = [convert_message_to_hf_format(message) for message in messages]
|
hf_messages = [convert_message_to_hf_format(message) for message in messages]
|
||||||
|
|
||||||
|
if isinstance(tools, Toolset):
|
||||||
|
tools = list(tools)
|
||||||
|
|
||||||
prepared_prompt = tokenizer.apply_chat_template(
|
prepared_prompt = tokenizer.apply_chat_template(
|
||||||
hf_messages,
|
hf_messages,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
@ -480,7 +487,7 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Asynchronously invokes text generation inference based on the provided messages and generation parameters.
|
Asynchronously invokes text generation inference based on the provided messages and generation parameters.
|
||||||
@ -491,7 +498,8 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
:param messages: A list of ChatMessage objects representing the input messages.
|
:param messages: A list of ChatMessage objects representing the input messages.
|
||||||
:param generation_kwargs: Additional keyword arguments for text generation.
|
:param generation_kwargs: Additional keyword arguments for text generation.
|
||||||
:param streaming_callback: An optional callable for handling streaming responses.
|
:param streaming_callback: An optional callable for handling streaming responses.
|
||||||
:param tools: A list of tools for which the model can prepare calls.
|
:param tools: A list of tools or a Toolset for which the model can prepare calls.
|
||||||
|
This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
|
||||||
:returns: A dictionary with the following keys:
|
:returns: A dictionary with the following keys:
|
||||||
- `replies`: A list containing the generated responses as ChatMessage instances.
|
- `replies`: A list containing the generated responses as ChatMessage instances.
|
||||||
"""
|
"""
|
||||||
@ -576,13 +584,17 @@ class HuggingFaceLocalChatGenerator:
|
|||||||
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
|
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
|
||||||
generation_kwargs: Dict[str, Any],
|
generation_kwargs: Dict[str, Any],
|
||||||
stop_words: Optional[List[str]],
|
stop_words: Optional[List[str]],
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Handles async non-streaming generation of responses.
|
Handles async non-streaming generation of responses.
|
||||||
"""
|
"""
|
||||||
# convert messages to HF format
|
# convert messages to HF format
|
||||||
hf_messages = [convert_message_to_hf_format(message) for message in messages]
|
hf_messages = [convert_message_to_hf_format(message) for message in messages]
|
||||||
|
|
||||||
|
if isinstance(tools, Toolset):
|
||||||
|
tools = list(tools)
|
||||||
|
|
||||||
prepared_prompt = tokenizer.apply_chat_template(
|
prepared_prompt = tokenizer.apply_chat_template(
|
||||||
hf_messages,
|
hf_messages,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
Add support for initializing chat generators with a Toolset, allowing for more flexible tool management. The tools parameter can now accept either a list of Tool objects or a Toolset instance.
|
||||||
@ -11,10 +11,16 @@ from haystack.components.generators.chat import AzureOpenAIChatGenerator
|
|||||||
from haystack.components.generators.utils import print_streaming_chunk
|
from haystack.components.generators.utils import print_streaming_chunk
|
||||||
from haystack.dataclasses import ChatMessage, ToolCall
|
from haystack.dataclasses import ChatMessage, ToolCall
|
||||||
from haystack.tools.tool import Tool
|
from haystack.tools.tool import Tool
|
||||||
|
from haystack.tools.toolset import Toolset
|
||||||
from haystack.utils.auth import Secret
|
from haystack.utils.auth import Secret
|
||||||
from haystack.utils.azure import default_azure_ad_token_provider
|
from haystack.utils.azure import default_azure_ad_token_provider
|
||||||
|
|
||||||
|
|
||||||
|
def get_weather(city: str) -> str:
|
||||||
|
"""Get weather information for a city."""
|
||||||
|
return f"Weather info for {city}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tools():
|
def tools():
|
||||||
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||||
@ -22,7 +28,7 @@ def tools():
|
|||||||
name="weather",
|
name="weather",
|
||||||
description="useful to determine the weather in a given location",
|
description="useful to determine the weather in a given location",
|
||||||
parameters=tool_parameters,
|
parameters=tool_parameters,
|
||||||
function=lambda x: x,
|
function=get_weather,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [tool]
|
return [tool]
|
||||||
@ -228,6 +234,26 @@ class TestAzureOpenAIChatGenerator:
|
|||||||
q = Pipeline.loads(p_str)
|
q = Pipeline.loads(p_str)
|
||||||
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."
|
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."
|
||||||
|
|
||||||
|
def test_azure_chat_generator_with_toolset_initialization(self, tools, monkeypatch):
|
||||||
|
"""Test that the AzureOpenAIChatGenerator can be initialized with a Toolset."""
|
||||||
|
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||||
|
toolset = Toolset(tools)
|
||||||
|
generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
|
||||||
|
assert generator.tools == toolset
|
||||||
|
|
||||||
|
def test_from_dict_with_toolset(self, tools, monkeypatch):
|
||||||
|
"""Test that the AzureOpenAIChatGenerator can be deserialized from a dictionary with a Toolset."""
|
||||||
|
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||||
|
toolset = Toolset(tools)
|
||||||
|
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
|
||||||
|
data = component.to_dict()
|
||||||
|
|
||||||
|
deserialized_component = AzureOpenAIChatGenerator.from_dict(data)
|
||||||
|
|
||||||
|
assert isinstance(deserialized_component.tools, Toolset)
|
||||||
|
assert len(deserialized_component.tools) == len(tools)
|
||||||
|
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||||
@ -272,7 +298,37 @@ class TestAzureOpenAIChatGenerator:
|
|||||||
assert tool_call.arguments == {"city": "Paris"}
|
assert tool_call.arguments == {"city": "Paris"}
|
||||||
assert message.meta["finish_reason"] == "tool_calls"
|
assert message.meta["finish_reason"] == "tool_calls"
|
||||||
|
|
||||||
# additional tests intentionally omitted as they are covered by test_openai.py
|
def test_to_dict_with_toolset(self, tools, monkeypatch):
|
||||||
|
"""Test that the AzureOpenAIChatGenerator can be serialized to a dictionary with a Toolset."""
|
||||||
|
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||||
|
toolset = Toolset(tools)
|
||||||
|
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
|
||||||
|
data = component.to_dict()
|
||||||
|
|
||||||
|
expected_tools_data = {
|
||||||
|
"type": "haystack.tools.toolset.Toolset",
|
||||||
|
"data": {
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "haystack.tools.tool.Tool",
|
||||||
|
"data": {
|
||||||
|
"name": "weather",
|
||||||
|
"description": "useful to determine the weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
"function": "generators.chat.test_azure.get_weather",
|
||||||
|
"outputs_to_string": None,
|
||||||
|
"inputs_from_state": None,
|
||||||
|
"outputs_to_state": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert data["init_parameters"]["tools"] == expected_tools_data
|
||||||
|
|
||||||
|
|
||||||
class TestAzureOpenAIChatGeneratorAsync:
|
class TestAzureOpenAIChatGeneratorAsync:
|
||||||
|
|||||||
@ -26,6 +26,7 @@ from huggingface_hub.utils import RepositoryNotFoundError
|
|||||||
from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
|
from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
|
||||||
from haystack.tools import Tool
|
from haystack.tools import Tool
|
||||||
from haystack.dataclasses import ChatMessage, ToolCall
|
from haystack.dataclasses import ChatMessage, ToolCall
|
||||||
|
from haystack.tools.toolset import Toolset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -36,6 +37,11 @@ def chat_messages():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_weather(city: str) -> str:
|
||||||
|
"""Get weather information for a city."""
|
||||||
|
return f"Weather info for {city}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tools():
|
def tools():
|
||||||
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||||
@ -43,7 +49,7 @@ def tools():
|
|||||||
name="weather",
|
name="weather",
|
||||||
description="useful to determine the weather in a given location",
|
description="useful to determine the weather in a given location",
|
||||||
parameters=tool_parameters,
|
parameters=tool_parameters,
|
||||||
function=lambda x: x,
|
function=get_weather,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [tool]
|
return [tool]
|
||||||
@ -851,3 +857,58 @@ class TestHuggingFaceAPIChatGenerator:
|
|||||||
assert "completion_tokens" in response["replies"][0].meta["usage"]
|
assert "completion_tokens" in response["replies"][0].meta["usage"]
|
||||||
finally:
|
finally:
|
||||||
await generator._async_client.close()
|
await generator._async_client.close()
|
||||||
|
|
||||||
|
def test_hugging_face_api_generator_with_toolset_initialization(self, mock_check_valid_model, tools):
|
||||||
|
"""Test that the HuggingFaceAPIChatGenerator can be initialized with a Toolset."""
|
||||||
|
toolset = Toolset(tools)
|
||||||
|
generator = HuggingFaceAPIChatGenerator(
|
||||||
|
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
|
||||||
|
)
|
||||||
|
assert generator.tools == toolset
|
||||||
|
|
||||||
|
def test_from_dict_with_toolset(self, mock_check_valid_model, tools):
|
||||||
|
"""Test that the HuggingFaceAPIChatGenerator can be deserialized from a dictionary with a Toolset."""
|
||||||
|
toolset = Toolset(tools)
|
||||||
|
component = HuggingFaceAPIChatGenerator(
|
||||||
|
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
|
||||||
|
)
|
||||||
|
data = component.to_dict()
|
||||||
|
|
||||||
|
deserialized_component = HuggingFaceAPIChatGenerator.from_dict(data)
|
||||||
|
|
||||||
|
assert isinstance(deserialized_component.tools, Toolset)
|
||||||
|
assert len(deserialized_component.tools) == len(tools)
|
||||||
|
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
|
||||||
|
|
||||||
|
def test_to_dict_with_toolset(self, mock_check_valid_model, tools):
|
||||||
|
"""Test that the HuggingFaceAPIChatGenerator can be serialized to a dictionary with a Toolset."""
|
||||||
|
toolset = Toolset(tools)
|
||||||
|
generator = HuggingFaceAPIChatGenerator(
|
||||||
|
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
|
||||||
|
)
|
||||||
|
data = generator.to_dict()
|
||||||
|
|
||||||
|
expected_tools_data = {
|
||||||
|
"type": "haystack.tools.toolset.Toolset",
|
||||||
|
"data": {
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "haystack.tools.tool.Tool",
|
||||||
|
"data": {
|
||||||
|
"name": "weather",
|
||||||
|
"description": "useful to determine the weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
"function": "generators.chat.test_hugging_face_api.get_weather",
|
||||||
|
"outputs_to_string": None,
|
||||||
|
"inputs_from_state": None,
|
||||||
|
"outputs_to_state": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert data["init_parameters"]["tools"] == expected_tools_data
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from haystack.dataclasses.streaming_chunk import StreamingChunk
|
|||||||
from haystack.tools import Tool
|
from haystack.tools import Tool
|
||||||
from haystack.utils import ComponentDevice
|
from haystack.utils import ComponentDevice
|
||||||
from haystack.utils.auth import Secret
|
from haystack.utils.auth import Secret
|
||||||
|
from haystack.tools.toolset import Toolset
|
||||||
|
|
||||||
|
|
||||||
# used to test serialization of streaming_callback
|
# used to test serialization of streaming_callback
|
||||||
@ -94,7 +95,7 @@ class TestHuggingFaceLocalChatGenerator:
|
|||||||
assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
|
assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
|
||||||
assert generator.streaming_callback == streaming_callback
|
assert generator.streaming_callback == streaming_callback
|
||||||
|
|
||||||
def test_init_custom_token(self):
|
def test_init_custom_token(self, model_info_mock):
|
||||||
generator = HuggingFaceLocalChatGenerator(
|
generator = HuggingFaceLocalChatGenerator(
|
||||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||||
task="text2text-generation",
|
task="text2text-generation",
|
||||||
@ -109,7 +110,7 @@ class TestHuggingFaceLocalChatGenerator:
|
|||||||
"device": "cpu",
|
"device": "cpu",
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_init_custom_device(self):
|
def test_init_custom_device(self, model_info_mock):
|
||||||
generator = HuggingFaceLocalChatGenerator(
|
generator = HuggingFaceLocalChatGenerator(
|
||||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||||
task="text2text-generation",
|
task="text2text-generation",
|
||||||
@ -124,7 +125,7 @@ class TestHuggingFaceLocalChatGenerator:
|
|||||||
"device": "cpu",
|
"device": "cpu",
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_init_task_parameter(self):
|
def test_init_task_parameter(self, model_info_mock):
|
||||||
generator = HuggingFaceLocalChatGenerator(
|
generator = HuggingFaceLocalChatGenerator(
|
||||||
task="text2text-generation", device=ComponentDevice.from_str("cpu"), token=None
|
task="text2text-generation", device=ComponentDevice.from_str("cpu"), token=None
|
||||||
)
|
)
|
||||||
@ -136,7 +137,7 @@ class TestHuggingFaceLocalChatGenerator:
|
|||||||
"device": "cpu",
|
"device": "cpu",
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_init_task_in_huggingface_pipeline_kwargs(self):
|
def test_init_task_in_huggingface_pipeline_kwargs(self, model_info_mock):
|
||||||
generator = HuggingFaceLocalChatGenerator(
|
generator = HuggingFaceLocalChatGenerator(
|
||||||
huggingface_pipeline_kwargs={"task": "text2text-generation"},
|
huggingface_pipeline_kwargs={"task": "text2text-generation"},
|
||||||
device=ComponentDevice.from_str("cpu"),
|
device=ComponentDevice.from_str("cpu"),
|
||||||
@ -553,3 +554,56 @@ class TestHuggingFaceLocalChatGenerator:
|
|||||||
del generator
|
del generator
|
||||||
gc.collect()
|
gc.collect()
|
||||||
mock_shutdown.assert_called_once_with(wait=True)
|
mock_shutdown.assert_called_once_with(wait=True)
|
||||||
|
|
||||||
|
def test_hugging_face_local_generator_with_toolset_initialization(
|
||||||
|
self, model_info_mock, mock_pipeline_tokenizer, tools
|
||||||
|
):
|
||||||
|
"""Test that the HuggingFaceLocalChatGenerator can be initialized with a Toolset."""
|
||||||
|
toolset = Toolset(tools)
|
||||||
|
generator = HuggingFaceLocalChatGenerator(model="irrelevant", tools=toolset)
|
||||||
|
generator.pipeline = mock_pipeline_tokenizer
|
||||||
|
assert generator.tools == toolset
|
||||||
|
|
||||||
|
def test_from_dict_with_toolset(self, model_info_mock, tools):
|
||||||
|
"""Test that the HuggingFaceLocalChatGenerator can be deserialized from a dictionary with a Toolset."""
|
||||||
|
toolset = Toolset(tools)
|
||||||
|
component = HuggingFaceLocalChatGenerator(model="irrelevant", tools=toolset)
|
||||||
|
data = component.to_dict()
|
||||||
|
|
||||||
|
deserialized_component = HuggingFaceLocalChatGenerator.from_dict(data)
|
||||||
|
|
||||||
|
assert isinstance(deserialized_component.tools, Toolset)
|
||||||
|
assert len(deserialized_component.tools) == len(tools)
|
||||||
|
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
|
||||||
|
|
||||||
|
def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_tokenizer, tools):
|
||||||
|
"""Test that the HuggingFaceLocalChatGenerator can be serialized to a dictionary with a Toolset."""
|
||||||
|
toolset = Toolset(tools)
|
||||||
|
generator = HuggingFaceLocalChatGenerator(huggingface_pipeline_kwargs={"model": "irrelevant"}, tools=toolset)
|
||||||
|
generator.pipeline = mock_pipeline_tokenizer
|
||||||
|
data = generator.to_dict()
|
||||||
|
|
||||||
|
expected_tools_data = {
|
||||||
|
"type": "haystack.tools.toolset.Toolset",
|
||||||
|
"data": {
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "haystack.tools.tool.Tool",
|
||||||
|
"data": {
|
||||||
|
"name": "weather",
|
||||||
|
"description": "useful to determine the weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
"function": "generators.chat.test_hugging_face_local.get_weather",
|
||||||
|
"outputs_to_string": None,
|
||||||
|
"inputs_from_state": None,
|
||||||
|
"outputs_to_state": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert data["init_parameters"]["tools"] == expected_tools_data
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user