mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-15 17:43:55 +00:00
feat: Add Toolset support in ChatGenerator(s) (#9177)
* Add Toolset support in ChatGenerator(s) * Add reno note * Update azure test * Updates * Minor fix * Add more tests * Remove some integration tests * PR feedback * rm unused fixture --------- Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
This commit is contained in:
parent
63781afd8f
commit
e1e797206d
@ -11,7 +11,9 @@ from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
|
||||
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||
from haystack.tools.toolset import Toolset
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
from haystack.utils.misc import serialize_tools_or_toolset
|
||||
|
||||
|
||||
@component
|
||||
@ -73,7 +75,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
max_retries: Optional[int] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
tools_strict: bool = False,
|
||||
*,
|
||||
azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
|
||||
@ -115,7 +117,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
values are the bias to add to that token.
|
||||
:param default_headers: Default headers to use for the AzureOpenAI client.
|
||||
:param tools:
|
||||
A list of tools for which the model can prepare calls.
|
||||
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
|
||||
list of `Tool` objects or a `Toolset` instance.
|
||||
:param tools_strict:
|
||||
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
|
||||
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
|
||||
@ -152,7 +155,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
self.default_headers = default_headers or {}
|
||||
self.azure_ad_token_provider = azure_ad_token_provider
|
||||
|
||||
_check_duplicate_tool_names(tools)
|
||||
_check_duplicate_tool_names(list(tools or []))
|
||||
self.tools = tools
|
||||
self.tools_strict = tools_strict
|
||||
|
||||
@ -196,7 +199,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
api_key=self.api_key.to_dict() if self.api_key is not None else None,
|
||||
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
|
||||
default_headers=self.default_headers,
|
||||
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
|
||||
tools=serialize_tools_or_toolset(self.tools),
|
||||
tools_strict=self.tools_strict,
|
||||
azure_ad_token_provider=azure_ad_token_provider_name,
|
||||
)
|
||||
|
||||
@ -10,8 +10,10 @@ from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_s
|
||||
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||
from haystack.tools.toolset import Toolset
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
|
||||
from haystack.utils.misc import serialize_tools_or_toolset
|
||||
from haystack.utils.url_validation import is_valid_http_url
|
||||
|
||||
with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
|
||||
@ -103,7 +105,7 @@ class HuggingFaceAPIChatGenerator:
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
stop_words: Optional[List[str]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the HuggingFaceAPIChatGenerator instance.
|
||||
@ -130,10 +132,10 @@ class HuggingFaceAPIChatGenerator:
|
||||
:param streaming_callback:
|
||||
An optional callable for handling streaming responses.
|
||||
:param tools:
|
||||
A list of tools for which the model can prepare calls.
|
||||
A list of tools or a Toolset for which the model can prepare calls.
|
||||
The chosen model should support tool/function calling, according to the model card.
|
||||
Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
|
||||
unexpected behavior.
|
||||
unexpected behavior. This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
|
||||
"""
|
||||
|
||||
huggingface_hub_import.check()
|
||||
@ -166,7 +168,7 @@ class HuggingFaceAPIChatGenerator:
|
||||
|
||||
if tools and streaming_callback is not None:
|
||||
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
||||
_check_duplicate_tool_names(tools)
|
||||
_check_duplicate_tool_names(list(tools or []))
|
||||
|
||||
# handle generation kwargs setup
|
||||
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
|
||||
@ -191,7 +193,6 @@ class HuggingFaceAPIChatGenerator:
|
||||
A dictionary containing the serialized component.
|
||||
"""
|
||||
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
|
||||
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
|
||||
return default_to_dict(
|
||||
self,
|
||||
api_type=str(self.api_type),
|
||||
@ -199,7 +200,7 @@ class HuggingFaceAPIChatGenerator:
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
generation_kwargs=self.generation_kwargs,
|
||||
streaming_callback=callback_name,
|
||||
tools=serialized_tools,
|
||||
tools=serialize_tools_or_toolset(self.tools),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -220,7 +221,7 @@ class HuggingFaceAPIChatGenerator:
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
):
|
||||
"""
|
||||
@ -231,8 +232,9 @@ class HuggingFaceAPIChatGenerator:
|
||||
:param generation_kwargs:
|
||||
Additional keyword arguments for text generation.
|
||||
:param tools:
|
||||
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
|
||||
during component initialization.
|
||||
A list of tools or a Toolset for which the model can prepare calls. If set, it will override
|
||||
the `tools` parameter set during component initialization. This parameter can accept either a
|
||||
list of `Tool` objects or a `Toolset` instance.
|
||||
:param streaming_callback:
|
||||
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
|
||||
parameter set during component initialization.
|
||||
@ -248,7 +250,7 @@ class HuggingFaceAPIChatGenerator:
|
||||
tools = tools or self.tools
|
||||
if tools and self.streaming_callback:
|
||||
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
||||
_check_duplicate_tool_names(tools)
|
||||
_check_duplicate_tool_names(list(tools or []))
|
||||
|
||||
# validate and select the streaming callback
|
||||
streaming_callback = select_streaming_callback(
|
||||
@ -260,6 +262,8 @@ class HuggingFaceAPIChatGenerator:
|
||||
|
||||
hf_tools = None
|
||||
if tools:
|
||||
if isinstance(tools, Toolset):
|
||||
tools = list(tools)
|
||||
hf_tools = [
|
||||
ChatCompletionInputTool(
|
||||
function=ChatCompletionInputFunctionDefinition(
|
||||
@ -276,7 +280,7 @@ class HuggingFaceAPIChatGenerator:
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
):
|
||||
"""
|
||||
@ -290,8 +294,9 @@ class HuggingFaceAPIChatGenerator:
|
||||
:param generation_kwargs:
|
||||
Additional keyword arguments for text generation.
|
||||
:param tools:
|
||||
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
|
||||
during component initialization.
|
||||
A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
|
||||
parameter set during component initialization. This parameter can accept either a list of `Tool` objects
|
||||
or a `Toolset` instance.
|
||||
:param streaming_callback:
|
||||
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
|
||||
parameter set during component initialization.
|
||||
@ -307,7 +312,7 @@ class HuggingFaceAPIChatGenerator:
|
||||
tools = tools or self.tools
|
||||
if tools and self.streaming_callback:
|
||||
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
||||
_check_duplicate_tool_names(tools)
|
||||
_check_duplicate_tool_names(list(tools or []))
|
||||
|
||||
# validate and select the streaming callback
|
||||
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
|
||||
@ -317,6 +322,8 @@ class HuggingFaceAPIChatGenerator:
|
||||
|
||||
hf_tools = None
|
||||
if tools:
|
||||
if isinstance(tools, Toolset):
|
||||
tools = list(tools)
|
||||
hf_tools = [
|
||||
ChatCompletionInputTool(
|
||||
function=ChatCompletionInputFunctionDefinition(
|
||||
|
||||
@ -13,6 +13,7 @@ from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||
from haystack.tools.toolset import Toolset
|
||||
from haystack.utils import (
|
||||
ComponentDevice,
|
||||
Secret,
|
||||
@ -20,6 +21,7 @@ from haystack.utils import (
|
||||
deserialize_secrets_inplace,
|
||||
serialize_callable,
|
||||
)
|
||||
from haystack.utils.misc import serialize_tools_or_toolset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -123,7 +125,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
||||
stop_words: Optional[List[str]] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
|
||||
async_executor: Optional[ThreadPoolExecutor] = None,
|
||||
):
|
||||
@ -164,7 +166,8 @@ class HuggingFaceLocalChatGenerator:
|
||||
For some chat models, the output includes both the new text and the original prompt.
|
||||
In these cases, make sure your prompt has no stop words.
|
||||
:param streaming_callback: An optional callable for handling streaming responses.
|
||||
:param tools: A list of tools for which the model can prepare calls.
|
||||
:param tools: A list of tools or a Toolset for which the model can prepare calls.
|
||||
This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
|
||||
:param tool_parsing_function:
|
||||
A callable that takes a string and returns a list of ToolCall objects or None.
|
||||
If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
|
||||
@ -176,7 +179,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
|
||||
if tools and streaming_callback is not None:
|
||||
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
||||
_check_duplicate_tool_names(tools)
|
||||
_check_duplicate_tool_names(list(tools or []))
|
||||
|
||||
huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
|
||||
generation_kwargs = generation_kwargs or {}
|
||||
@ -273,7 +276,6 @@ class HuggingFaceLocalChatGenerator:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
|
||||
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
|
||||
serialization_dict = default_to_dict(
|
||||
self,
|
||||
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
|
||||
@ -281,7 +283,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
streaming_callback=callback_name,
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
chat_template=self.chat_template,
|
||||
tools=serialized_tools,
|
||||
tools=serialize_tools_or_toolset(self.tools),
|
||||
tool_parsing_function=serialize_callable(self.tool_parsing_function),
|
||||
)
|
||||
|
||||
@ -323,7 +325,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
messages: List[ChatMessage],
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
):
|
||||
"""
|
||||
Invoke text generation inference based on the provided messages and generation parameters.
|
||||
@ -332,8 +334,9 @@ class HuggingFaceLocalChatGenerator:
|
||||
:param generation_kwargs: Additional keyword arguments for text generation.
|
||||
:param streaming_callback: An optional callable for handling streaming responses.
|
||||
:param tools:
|
||||
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter
|
||||
provided during initialization.
|
||||
A list of tools or a Toolset for which the model can prepare calls. If set, it will override
|
||||
the `tools` parameter provided during initialization. This parameter can accept either a list
|
||||
of `Tool` objects or a `Toolset` instance.
|
||||
:returns:
|
||||
A list containing the generated responses as ChatMessage instances.
|
||||
"""
|
||||
@ -377,6 +380,10 @@ class HuggingFaceLocalChatGenerator:
|
||||
|
||||
# convert messages to HF format
|
||||
hf_messages = [convert_message_to_hf_format(message) for message in messages]
|
||||
|
||||
if isinstance(tools, Toolset):
|
||||
tools = list(tools)
|
||||
|
||||
prepared_prompt = tokenizer.apply_chat_template(
|
||||
hf_messages,
|
||||
tokenize=False,
|
||||
@ -480,7 +487,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
messages: List[ChatMessage],
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
):
|
||||
"""
|
||||
Asynchronously invokes text generation inference based on the provided messages and generation parameters.
|
||||
@ -491,7 +498,8 @@ class HuggingFaceLocalChatGenerator:
|
||||
:param messages: A list of ChatMessage objects representing the input messages.
|
||||
:param generation_kwargs: Additional keyword arguments for text generation.
|
||||
:param streaming_callback: An optional callable for handling streaming responses.
|
||||
:param tools: A list of tools for which the model can prepare calls.
|
||||
:param tools: A list of tools or a Toolset for which the model can prepare calls.
|
||||
This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
|
||||
:returns: A dictionary with the following keys:
|
||||
- `replies`: A list containing the generated responses as ChatMessage instances.
|
||||
"""
|
||||
@ -576,13 +584,17 @@ class HuggingFaceLocalChatGenerator:
|
||||
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
|
||||
generation_kwargs: Dict[str, Any],
|
||||
stop_words: Optional[List[str]],
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
):
|
||||
"""
|
||||
Handles async non-streaming generation of responses.
|
||||
"""
|
||||
# convert messages to HF format
|
||||
hf_messages = [convert_message_to_hf_format(message) for message in messages]
|
||||
|
||||
if isinstance(tools, Toolset):
|
||||
tools = list(tools)
|
||||
|
||||
prepared_prompt = tokenizer.apply_chat_template(
|
||||
hf_messages,
|
||||
tokenize=False,
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Add support for initializing chat generators with a Toolset, allowing for more flexible tool management. The tools parameter can now accept either a list of Tool objects or a Toolset instance.
|
||||
@ -11,10 +11,16 @@ from haystack.components.generators.chat import AzureOpenAIChatGenerator
|
||||
from haystack.components.generators.utils import print_streaming_chunk
|
||||
from haystack.dataclasses import ChatMessage, ToolCall
|
||||
from haystack.tools.tool import Tool
|
||||
from haystack.tools.toolset import Toolset
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.utils.azure import default_azure_ad_token_provider
|
||||
|
||||
|
||||
def get_weather(city: str) -> str:
|
||||
"""Get weather information for a city."""
|
||||
return f"Weather info for {city}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tools():
|
||||
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
@ -22,7 +28,7 @@ def tools():
|
||||
name="weather",
|
||||
description="useful to determine the weather in a given location",
|
||||
parameters=tool_parameters,
|
||||
function=lambda x: x,
|
||||
function=get_weather,
|
||||
)
|
||||
|
||||
return [tool]
|
||||
@ -228,6 +234,26 @@ class TestAzureOpenAIChatGenerator:
|
||||
q = Pipeline.loads(p_str)
|
||||
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."
|
||||
|
||||
def test_azure_chat_generator_with_toolset_initialization(self, tools, monkeypatch):
|
||||
"""Test that the AzureOpenAIChatGenerator can be initialized with a Toolset."""
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||
toolset = Toolset(tools)
|
||||
generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
|
||||
assert generator.tools == toolset
|
||||
|
||||
def test_from_dict_with_toolset(self, tools, monkeypatch):
|
||||
"""Test that the AzureOpenAIChatGenerator can be deserialized from a dictionary with a Toolset."""
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||
toolset = Toolset(tools)
|
||||
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
|
||||
data = component.to_dict()
|
||||
|
||||
deserialized_component = AzureOpenAIChatGenerator.from_dict(data)
|
||||
|
||||
assert isinstance(deserialized_component.tools, Toolset)
|
||||
assert len(deserialized_component.tools) == len(tools)
|
||||
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||
@ -272,7 +298,37 @@ class TestAzureOpenAIChatGenerator:
|
||||
assert tool_call.arguments == {"city": "Paris"}
|
||||
assert message.meta["finish_reason"] == "tool_calls"
|
||||
|
||||
# additional tests intentionally omitted as they are covered by test_openai.py
|
||||
def test_to_dict_with_toolset(self, tools, monkeypatch):
|
||||
"""Test that the AzureOpenAIChatGenerator can be serialized to a dictionary with a Toolset."""
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||
toolset = Toolset(tools)
|
||||
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
|
||||
data = component.to_dict()
|
||||
|
||||
expected_tools_data = {
|
||||
"type": "haystack.tools.toolset.Toolset",
|
||||
"data": {
|
||||
"tools": [
|
||||
{
|
||||
"type": "haystack.tools.tool.Tool",
|
||||
"data": {
|
||||
"name": "weather",
|
||||
"description": "useful to determine the weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
"function": "generators.chat.test_azure.get_weather",
|
||||
"outputs_to_string": None,
|
||||
"inputs_from_state": None,
|
||||
"outputs_to_state": None,
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
assert data["init_parameters"]["tools"] == expected_tools_data
|
||||
|
||||
|
||||
class TestAzureOpenAIChatGeneratorAsync:
|
||||
|
||||
@ -26,6 +26,7 @@ from huggingface_hub.utils import RepositoryNotFoundError
|
||||
from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
|
||||
from haystack.tools import Tool
|
||||
from haystack.dataclasses import ChatMessage, ToolCall
|
||||
from haystack.tools.toolset import Toolset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -36,6 +37,11 @@ def chat_messages():
|
||||
]
|
||||
|
||||
|
||||
def get_weather(city: str) -> str:
|
||||
"""Get weather information for a city."""
|
||||
return f"Weather info for {city}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tools():
|
||||
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
@ -43,7 +49,7 @@ def tools():
|
||||
name="weather",
|
||||
description="useful to determine the weather in a given location",
|
||||
parameters=tool_parameters,
|
||||
function=lambda x: x,
|
||||
function=get_weather,
|
||||
)
|
||||
|
||||
return [tool]
|
||||
@ -851,3 +857,58 @@ class TestHuggingFaceAPIChatGenerator:
|
||||
assert "completion_tokens" in response["replies"][0].meta["usage"]
|
||||
finally:
|
||||
await generator._async_client.close()
|
||||
|
||||
def test_hugging_face_api_generator_with_toolset_initialization(self, mock_check_valid_model, tools):
|
||||
"""Test that the HuggingFaceAPIChatGenerator can be initialized with a Toolset."""
|
||||
toolset = Toolset(tools)
|
||||
generator = HuggingFaceAPIChatGenerator(
|
||||
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
|
||||
)
|
||||
assert generator.tools == toolset
|
||||
|
||||
def test_from_dict_with_toolset(self, mock_check_valid_model, tools):
|
||||
"""Test that the HuggingFaceAPIChatGenerator can be deserialized from a dictionary with a Toolset."""
|
||||
toolset = Toolset(tools)
|
||||
component = HuggingFaceAPIChatGenerator(
|
||||
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
|
||||
)
|
||||
data = component.to_dict()
|
||||
|
||||
deserialized_component = HuggingFaceAPIChatGenerator.from_dict(data)
|
||||
|
||||
assert isinstance(deserialized_component.tools, Toolset)
|
||||
assert len(deserialized_component.tools) == len(tools)
|
||||
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
|
||||
|
||||
def test_to_dict_with_toolset(self, mock_check_valid_model, tools):
|
||||
"""Test that the HuggingFaceAPIChatGenerator can be serialized to a dictionary with a Toolset."""
|
||||
toolset = Toolset(tools)
|
||||
generator = HuggingFaceAPIChatGenerator(
|
||||
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
|
||||
)
|
||||
data = generator.to_dict()
|
||||
|
||||
expected_tools_data = {
|
||||
"type": "haystack.tools.toolset.Toolset",
|
||||
"data": {
|
||||
"tools": [
|
||||
{
|
||||
"type": "haystack.tools.tool.Tool",
|
||||
"data": {
|
||||
"name": "weather",
|
||||
"description": "useful to determine the weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
"function": "generators.chat.test_hugging_face_api.get_weather",
|
||||
"outputs_to_string": None,
|
||||
"inputs_from_state": None,
|
||||
"outputs_to_state": None,
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
assert data["init_parameters"]["tools"] == expected_tools_data
|
||||
|
||||
@ -16,6 +16,7 @@ from haystack.dataclasses.streaming_chunk import StreamingChunk
|
||||
from haystack.tools import Tool
|
||||
from haystack.utils import ComponentDevice
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.tools.toolset import Toolset
|
||||
|
||||
|
||||
# used to test serialization of streaming_callback
|
||||
@ -94,7 +95,7 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
|
||||
assert generator.streaming_callback == streaming_callback
|
||||
|
||||
def test_init_custom_token(self):
|
||||
def test_init_custom_token(self, model_info_mock):
|
||||
generator = HuggingFaceLocalChatGenerator(
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
task="text2text-generation",
|
||||
@ -109,7 +110,7 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
"device": "cpu",
|
||||
}
|
||||
|
||||
def test_init_custom_device(self):
|
||||
def test_init_custom_device(self, model_info_mock):
|
||||
generator = HuggingFaceLocalChatGenerator(
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
task="text2text-generation",
|
||||
@ -124,7 +125,7 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
"device": "cpu",
|
||||
}
|
||||
|
||||
def test_init_task_parameter(self):
|
||||
def test_init_task_parameter(self, model_info_mock):
|
||||
generator = HuggingFaceLocalChatGenerator(
|
||||
task="text2text-generation", device=ComponentDevice.from_str("cpu"), token=None
|
||||
)
|
||||
@ -136,7 +137,7 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
"device": "cpu",
|
||||
}
|
||||
|
||||
def test_init_task_in_huggingface_pipeline_kwargs(self):
|
||||
def test_init_task_in_huggingface_pipeline_kwargs(self, model_info_mock):
|
||||
generator = HuggingFaceLocalChatGenerator(
|
||||
huggingface_pipeline_kwargs={"task": "text2text-generation"},
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
@ -553,3 +554,56 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
del generator
|
||||
gc.collect()
|
||||
mock_shutdown.assert_called_once_with(wait=True)
|
||||
|
||||
def test_hugging_face_local_generator_with_toolset_initialization(
|
||||
self, model_info_mock, mock_pipeline_tokenizer, tools
|
||||
):
|
||||
"""Test that the HuggingFaceLocalChatGenerator can be initialized with a Toolset."""
|
||||
toolset = Toolset(tools)
|
||||
generator = HuggingFaceLocalChatGenerator(model="irrelevant", tools=toolset)
|
||||
generator.pipeline = mock_pipeline_tokenizer
|
||||
assert generator.tools == toolset
|
||||
|
||||
def test_from_dict_with_toolset(self, model_info_mock, tools):
|
||||
"""Test that the HuggingFaceLocalChatGenerator can be deserialized from a dictionary with a Toolset."""
|
||||
toolset = Toolset(tools)
|
||||
component = HuggingFaceLocalChatGenerator(model="irrelevant", tools=toolset)
|
||||
data = component.to_dict()
|
||||
|
||||
deserialized_component = HuggingFaceLocalChatGenerator.from_dict(data)
|
||||
|
||||
assert isinstance(deserialized_component.tools, Toolset)
|
||||
assert len(deserialized_component.tools) == len(tools)
|
||||
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
|
||||
|
||||
def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_tokenizer, tools):
|
||||
"""Test that the HuggingFaceLocalChatGenerator can be serialized to a dictionary with a Toolset."""
|
||||
toolset = Toolset(tools)
|
||||
generator = HuggingFaceLocalChatGenerator(huggingface_pipeline_kwargs={"model": "irrelevant"}, tools=toolset)
|
||||
generator.pipeline = mock_pipeline_tokenizer
|
||||
data = generator.to_dict()
|
||||
|
||||
expected_tools_data = {
|
||||
"type": "haystack.tools.toolset.Toolset",
|
||||
"data": {
|
||||
"tools": [
|
||||
{
|
||||
"type": "haystack.tools.tool.Tool",
|
||||
"data": {
|
||||
"name": "weather",
|
||||
"description": "useful to determine the weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
"function": "generators.chat.test_hugging_face_local.get_weather",
|
||||
"outputs_to_string": None,
|
||||
"inputs_from_state": None,
|
||||
"outputs_to_state": None,
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
assert data["init_parameters"]["tools"] == expected_tools_data
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user