From e1e797206d05912a2e41ee40d05a42518686f11f Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Mon, 7 Apr 2025 16:12:09 +0200
Subject: [PATCH] feat: Add Toolset support in ChatGenerator(s) (#9177)

* Add Toolset support in ChatGenerator(s)

* Add reno note

* Update azure test

* Updates

* Minor fix

* Add more tests

* Remove some integration tests

* PR feedback

* rm unused fixture

---------

Co-authored-by: anakin87
---
 haystack/components/generators/chat/azure.py  | 11 ++--
 .../generators/chat/hugging_face_api.py       | 35 ++++++-----
 .../generators/chat/hugging_face_local.py     | 34 ++++++----
 ...olset-initialization-1ccbc174abf16bd2.yaml |  4 ++
 test/components/generators/chat/test_azure.py | 60 +++++++++++++++++-
 .../generators/chat/test_hugging_face_api.py  | 63 ++++++++++++++++++-
 .../chat/test_hugging_face_local.py           | 62 ++++++++++++++++--
 7 files changed, 233 insertions(+), 36 deletions(-)
 create mode 100644 releasenotes/notes/support-toolset-initialization-1ccbc174abf16bd2.yaml

diff --git a/haystack/components/generators/chat/azure.py b/haystack/components/generators/chat/azure.py
index add5abfc3..0fa93a089 100644
--- a/haystack/components/generators/chat/azure.py
+++ b/haystack/components/generators/chat/azure.py
@@ -11,7 +11,9 @@ from haystack import component, default_from_dict, default_to_dict
 from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.dataclasses.streaming_chunk import StreamingCallbackT
 from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
+from haystack.tools.toolset import Toolset
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
+from haystack.utils.misc import serialize_tools_or_toolset
 
 
 @component
@@ -73,7 +75,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
         max_retries: Optional[int] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
         default_headers: Optional[Dict[str, str]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         tools_strict: bool = False,
         *,
         azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
@@ -115,7 +117,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
             values are the bias to add to that token.
         :param default_headers: Default headers to use for the AzureOpenAI client.
         :param tools:
-            A list of tools for which the model can prepare calls.
+            A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
+            list of `Tool` objects or a `Toolset` instance.
         :param tools_strict:
             Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
             the schema provided in the `parameters` field of the tool definition, but this may increase latency.
@@ -152,7 +155,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
         self.default_headers = default_headers or {}
         self.azure_ad_token_provider = azure_ad_token_provider
 
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))
         self.tools = tools
         self.tools_strict = tools_strict
 
@@ -196,7 +199,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
             api_key=self.api_key.to_dict() if self.api_key is not None else None,
             azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
             default_headers=self.default_headers,
-            tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
+            tools=serialize_tools_or_toolset(self.tools),
             tools_strict=self.tools_strict,
             azure_ad_token_provider=azure_ad_token_provider_name,
         )
diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py
index 09488f0a5..3889a5224 100644
--- a/haystack/components/generators/chat/hugging_face_api.py
+++ b/haystack/components/generators/chat/hugging_face_api.py
@@ -10,8 +10,10 @@ from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.dataclasses.streaming_chunk import StreamingCallbackT
 from haystack.lazy_imports import LazyImport
 from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
+from haystack.tools.toolset import Toolset
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
+from haystack.utils.misc import serialize_tools_or_toolset
 from haystack.utils.url_validation import is_valid_http_url
 
 with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
@@ -103,7 +105,7 @@ class HuggingFaceAPIChatGenerator:
         generation_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
         Initialize the HuggingFaceAPIChatGenerator instance.
@@ -130,10 +132,10 @@ class HuggingFaceAPIChatGenerator:
         :param streaming_callback:
             An optional callable for handling streaming responses.
         :param tools:
-            A list of tools for which the model can prepare calls.
+            A list of tools or a Toolset for which the model can prepare calls.
             The chosen model should support tool/function calling, according to the model card.
             Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
-            unexpected behavior.
+            unexpected behavior. This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
         """
         huggingface_hub_import.check()
 
@@ -166,7 +168,7 @@ class HuggingFaceAPIChatGenerator:
 
         if tools and streaming_callback is not None:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))
 
         # handle generation kwargs setup
         generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
@@ -191,7 +193,6 @@ class HuggingFaceAPIChatGenerator:
             A dictionary containing the serialized component.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
-        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
         return default_to_dict(
             self,
             api_type=str(self.api_type),
@@ -199,7 +200,7 @@ class HuggingFaceAPIChatGenerator:
             token=self.token.to_dict() if self.token else None,
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
-            tools=serialized_tools,
+            tools=serialize_tools_or_toolset(self.tools),
         )
 
     @classmethod
@@ -220,7 +221,7 @@ class HuggingFaceAPIChatGenerator:
         self,
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
     ):
         """
@@ -231,8 +232,9 @@ class HuggingFaceAPIChatGenerator:
         :param generation_kwargs:
             Additional keyword arguments for text generation.
         :param tools:
-            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
-            during component initialization.
+            A list of tools or a Toolset for which the model can prepare calls. If set, it will override
+            the `tools` parameter set during component initialization. This parameter can accept either a
+            list of `Tool` objects or a `Toolset` instance.
         :param streaming_callback:
             An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
             parameter set during component initialization.
@@ -248,7 +250,7 @@ class HuggingFaceAPIChatGenerator:
         tools = tools or self.tools
         if tools and self.streaming_callback:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))
 
         # validate and select the streaming callback
         streaming_callback = select_streaming_callback(
@@ -260,6 +262,8 @@ class HuggingFaceAPIChatGenerator:
 
         hf_tools = None
         if tools:
+            if isinstance(tools, Toolset):
+                tools = list(tools)
             hf_tools = [
                 ChatCompletionInputTool(
                     function=ChatCompletionInputFunctionDefinition(
@@ -276,7 +280,7 @@ class HuggingFaceAPIChatGenerator:
         self,
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
     ):
         """
@@ -290,8 +294,9 @@ class HuggingFaceAPIChatGenerator:
         :param generation_kwargs:
             Additional keyword arguments for text generation.
         :param tools:
-            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
-            during component initialization.
+            A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
+            parameter set during component initialization. This parameter can accept either a list of `Tool` objects
+            or a `Toolset` instance.
         :param streaming_callback:
             An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
             parameter set during component initialization.
@@ -307,7 +312,7 @@ class HuggingFaceAPIChatGenerator:
         tools = tools or self.tools
         if tools and self.streaming_callback:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))
 
         # validate and select the streaming callback
         streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
@@ -317,6 +322,8 @@ class HuggingFaceAPIChatGenerator:
 
         hf_tools = None
         if tools:
+            if isinstance(tools, Toolset):
+                tools = list(tools)
             hf_tools = [
                 ChatCompletionInputTool(
                     function=ChatCompletionInputFunctionDefinition(
diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py
index e4ad4ea80..d32f03966 100644
--- a/haystack/components/generators/chat/hugging_face_local.py
+++ b/haystack/components/generators/chat/hugging_face_local.py
@@ -13,6 +13,7 @@ from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
+from haystack.tools.toolset import Toolset
 from haystack.utils import (
     ComponentDevice,
     Secret,
@@ -20,6 +21,7 @@ from haystack.utils import (
     deserialize_secrets_inplace,
     serialize_callable,
 )
+from haystack.utils.misc import serialize_tools_or_toolset
 
 logger = logging.getLogger(__name__)
 
@@ -123,7 +125,7 @@ class HuggingFaceLocalChatGenerator:
         huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
         async_executor: Optional[ThreadPoolExecutor] = None,
     ):
@@ -164,7 +166,8 @@ class HuggingFaceLocalChatGenerator:
             For some chat models, the output includes both the new text and the original prompt.
             In these cases, make sure your prompt has no stop words.
         :param streaming_callback: An optional callable for handling streaming responses.
-        :param tools: A list of tools for which the model can prepare calls.
+        :param tools: A list of tools or a Toolset for which the model can prepare calls.
+            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
         :param tool_parsing_function:
             A callable that takes a string and returns a list of ToolCall objects or None.
             If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
@@ -176,7 +179,7 @@ class HuggingFaceLocalChatGenerator:
 
         if tools and streaming_callback is not None:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))
 
         huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
         generation_kwargs = generation_kwargs or {}
@@ -273,7 +276,6 @@ class HuggingFaceLocalChatGenerator:
             Dictionary with serialized data.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
-        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
         serialization_dict = default_to_dict(
             self,
             huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
@@ -281,7 +283,7 @@ class HuggingFaceLocalChatGenerator:
             streaming_callback=callback_name,
             token=self.token.to_dict() if self.token else None,
             chat_template=self.chat_template,
-            tools=serialized_tools,
+            tools=serialize_tools_or_toolset(self.tools),
             tool_parsing_function=serialize_callable(self.tool_parsing_function),
         )
 
@@ -323,7 +325,7 @@ class HuggingFaceLocalChatGenerator:
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
         Invoke text generation inference based on the provided messages and generation parameters.
@@ -332,8 +334,9 @@ class HuggingFaceLocalChatGenerator:
         :param generation_kwargs: Additional keyword arguments for text generation.
         :param streaming_callback: An optional callable for handling streaming responses.
         :param tools:
-            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter
-            provided during initialization.
+            A list of tools or a Toolset for which the model can prepare calls. If set, it will override
+            the `tools` parameter provided during initialization. This parameter can accept either a list
+            of `Tool` objects or a `Toolset` instance.
         :returns:
             A list containing the generated responses as ChatMessage instances.
         """
@@ -377,6 +380,10 @@ class HuggingFaceLocalChatGenerator:
 
         # convert messages to HF format
         hf_messages = [convert_message_to_hf_format(message) for message in messages]
+
+        if isinstance(tools, Toolset):
+            tools = list(tools)
+
         prepared_prompt = tokenizer.apply_chat_template(
             hf_messages,
             tokenize=False,
@@ -480,7 +487,7 @@ class HuggingFaceLocalChatGenerator:
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
         Asynchronously invokes text generation inference based on the provided messages and generation parameters.
@@ -491,7 +498,8 @@ class HuggingFaceLocalChatGenerator:
         :param messages: A list of ChatMessage objects representing the input messages.
         :param generation_kwargs: Additional keyword arguments for text generation.
         :param streaming_callback: An optional callable for handling streaming responses.
-        :param tools: A list of tools for which the model can prepare calls.
+        :param tools: A list of tools or a Toolset for which the model can prepare calls.
+            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
         :returns: A dictionary with the following keys:
             - `replies`: A list containing the generated responses as ChatMessage instances.
         """
@@ -576,13 +584,17 @@ class HuggingFaceLocalChatGenerator:
         tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
         generation_kwargs: Dict[str, Any],
         stop_words: Optional[List[str]],
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
         Handles async non-streaming generation of responses.
         """
         # convert messages to HF format
         hf_messages = [convert_message_to_hf_format(message) for message in messages]
+
+        if isinstance(tools, Toolset):
+            tools = list(tools)
+
         prepared_prompt = tokenizer.apply_chat_template(
             hf_messages,
             tokenize=False,
diff --git a/releasenotes/notes/support-toolset-initialization-1ccbc174abf16bd2.yaml b/releasenotes/notes/support-toolset-initialization-1ccbc174abf16bd2.yaml
new file mode 100644
index 000000000..2c700548b
--- /dev/null
+++ b/releasenotes/notes/support-toolset-initialization-1ccbc174abf16bd2.yaml
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Add support for initializing chat generators with a Toolset, allowing for more flexible tool management. The tools parameter can now accept either a list of Tool objects or a Toolset instance.
diff --git a/test/components/generators/chat/test_azure.py b/test/components/generators/chat/test_azure.py
index d1f91b15e..4e01cbfd7 100644
--- a/test/components/generators/chat/test_azure.py
+++ b/test/components/generators/chat/test_azure.py
@@ -11,10 +11,16 @@ from haystack.components.generators.chat import AzureOpenAIChatGenerator
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.tools.tool import Tool
+from haystack.tools.toolset import Toolset
 from haystack.utils.auth import Secret
 from haystack.utils.azure import default_azure_ad_token_provider
 
 
+def get_weather(city: str) -> str:
+    """Get weather information for a city."""
+    return f"Weather info for {city}"
+
+
 @pytest.fixture
 def tools():
     tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
@@ -22,7 +28,7 @@ def tools():
         name="weather",
         description="useful to determine the weather in a given location",
         parameters=tool_parameters,
-        function=lambda x: x,
+        function=get_weather,
     )
     return [tool]
 
@@ -228,6 +234,26 @@ class TestAzureOpenAIChatGenerator:
         q = Pipeline.loads(p_str)
         assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."
 
+    def test_azure_chat_generator_with_toolset_initialization(self, tools, monkeypatch):
+        """Test that the AzureOpenAIChatGenerator can be initialized with a Toolset."""
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
+        toolset = Toolset(tools)
+        generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
+        assert generator.tools == toolset
+
+    def test_from_dict_with_toolset(self, tools, monkeypatch):
+        """Test that the AzureOpenAIChatGenerator can be deserialized from a dictionary with a Toolset."""
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
+        toolset = Toolset(tools)
+        component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
+        data = component.to_dict()
+
+        deserialized_component = AzureOpenAIChatGenerator.from_dict(data)
+
+        assert isinstance(deserialized_component.tools, Toolset)
+        assert len(deserialized_component.tools) == len(tools)
+        assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
+
     @pytest.mark.integration
     @pytest.mark.skipif(
         not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
@@ -272,7 +298,37 @@ class TestAzureOpenAIChatGenerator:
         assert tool_call.arguments == {"city": "Paris"}
         assert message.meta["finish_reason"] == "tool_calls"
 
-    # additional tests intentionally omitted as they are covered by test_openai.py
+    def test_to_dict_with_toolset(self, tools, monkeypatch):
+        """Test that the AzureOpenAIChatGenerator can be serialized to a dictionary with a Toolset."""
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
+        toolset = Toolset(tools)
+        component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
+        data = component.to_dict()
+
+        expected_tools_data = {
+            "type": "haystack.tools.toolset.Toolset",
+            "data": {
+                "tools": [
+                    {
+                        "type": "haystack.tools.tool.Tool",
+                        "data": {
+                            "name": "weather",
+                            "description": "useful to determine the weather in a given location",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {"city": {"type": "string"}},
+                                "required": ["city"],
+                            },
+                            "function": "generators.chat.test_azure.get_weather",
+                            "outputs_to_string": None,
+                            "inputs_from_state": None,
+                            "outputs_to_state": None,
+                        },
+                    }
+                ]
+            },
+        }
+        assert data["init_parameters"]["tools"] == expected_tools_data
 
 
 class TestAzureOpenAIChatGeneratorAsync:
diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py
index 39a62e408..c16997c9e 100644
--- a/test/components/generators/chat/test_hugging_face_api.py
+++ b/test/components/generators/chat/test_hugging_face_api.py
@@ -26,6 +26,7 @@ from huggingface_hub.utils import RepositoryNotFoundError
 from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
 from haystack.tools import Tool
 from haystack.dataclasses import ChatMessage, ToolCall
+from haystack.tools.toolset import Toolset
 
 
 @pytest.fixture
@@ -36,6 +37,11 @@ def chat_messages():
     ]
 
 
+def get_weather(city: str) -> str:
+    """Get weather information for a city."""
+    return f"Weather info for {city}"
+
+
 @pytest.fixture
 def tools():
     tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
@@ -43,7 +49,7 @@ def tools():
         name="weather",
         description="useful to determine the weather in a given location",
         parameters=tool_parameters,
-        function=lambda x: x,
+        function=get_weather,
     )
     return [tool]
 
@@ -851,3 +857,58 @@ class TestHuggingFaceAPIChatGenerator:
             assert "completion_tokens" in response["replies"][0].meta["usage"]
         finally:
             await generator._async_client.close()
+
+    def test_hugging_face_api_generator_with_toolset_initialization(self, mock_check_valid_model, tools):
+        """Test that the HuggingFaceAPIChatGenerator can be initialized with a Toolset."""
+        toolset = Toolset(tools)
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
+        )
+        assert generator.tools == toolset
+
+    def test_from_dict_with_toolset(self, mock_check_valid_model, tools):
+        """Test that the HuggingFaceAPIChatGenerator can be deserialized from a dictionary with a Toolset."""
+        toolset = Toolset(tools)
+        component = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
+        )
+        data = component.to_dict()
+
+        deserialized_component = HuggingFaceAPIChatGenerator.from_dict(data)
+
+        assert isinstance(deserialized_component.tools, Toolset)
+        assert len(deserialized_component.tools) == len(tools)
+        assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
+
+    def test_to_dict_with_toolset(self, mock_check_valid_model, tools):
+        """Test that the HuggingFaceAPIChatGenerator can be serialized to a dictionary with a Toolset."""
+        toolset = Toolset(tools)
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
+        )
+        data = generator.to_dict()
+
+        expected_tools_data = {
+            "type": "haystack.tools.toolset.Toolset",
+            "data": {
+                "tools": [
+                    {
+                        "type": "haystack.tools.tool.Tool",
+                        "data": {
+                            "name": "weather",
+                            "description": "useful to determine the weather in a given location",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {"city": {"type": "string"}},
+                                "required": ["city"],
+                            },
+                            "function": "generators.chat.test_hugging_face_api.get_weather",
+                            "outputs_to_string": None,
+                            "inputs_from_state": None,
+                            "outputs_to_state": None,
+                        },
+                    }
+                ]
+            },
+        }
+        assert data["init_parameters"]["tools"] == expected_tools_data
diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py
index c44e73bbe..fa3893c93 100644
--- a/test/components/generators/chat/test_hugging_face_local.py
+++ b/test/components/generators/chat/test_hugging_face_local.py
@@ -16,6 +16,7 @@ from haystack.dataclasses.streaming_chunk import StreamingChunk
 from haystack.tools import Tool
 from haystack.utils import ComponentDevice
 from haystack.utils.auth import Secret
+from haystack.tools.toolset import Toolset
 
 
 # used to test serialization of streaming_callback
@@ -94,7 +95,7 @@ class TestHuggingFaceLocalChatGenerator:
         assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
         assert generator.streaming_callback == streaming_callback
 
-    def test_init_custom_token(self):
+    def test_init_custom_token(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
             model="mistralai/Mistral-7B-Instruct-v0.2",
             task="text2text-generation",
@@ -109,7 +110,7 @@ class TestHuggingFaceLocalChatGenerator:
             "device": "cpu",
         }
 
-    def test_init_custom_device(self):
+    def test_init_custom_device(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
             model="mistralai/Mistral-7B-Instruct-v0.2",
             task="text2text-generation",
@@ -124,7 +125,7 @@ class TestHuggingFaceLocalChatGenerator:
             "device": "cpu",
         }
 
-    def test_init_task_parameter(self):
+    def test_init_task_parameter(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
             task="text2text-generation", device=ComponentDevice.from_str("cpu"), token=None
         )
@@ -136,7 +137,7 @@ class TestHuggingFaceLocalChatGenerator:
             "device": "cpu",
         }
 
-    def test_init_task_in_huggingface_pipeline_kwargs(self):
+    def test_init_task_in_huggingface_pipeline_kwargs(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
             huggingface_pipeline_kwargs={"task": "text2text-generation"},
             device=ComponentDevice.from_str("cpu"),
@@ -553,3 +554,56 @@ class TestHuggingFaceLocalChatGenerator:
         del generator
         gc.collect()
         mock_shutdown.assert_called_once_with(wait=True)
+
+    def test_hugging_face_local_generator_with_toolset_initialization(
+        self, model_info_mock, mock_pipeline_tokenizer, tools
+    ):
+        """Test that the HuggingFaceLocalChatGenerator can be initialized with a Toolset."""
+        toolset = Toolset(tools)
+        generator = HuggingFaceLocalChatGenerator(model="irrelevant", tools=toolset)
+        generator.pipeline = mock_pipeline_tokenizer
+        assert generator.tools == toolset
+
+    def test_from_dict_with_toolset(self, model_info_mock, tools):
+        """Test that the HuggingFaceLocalChatGenerator can be deserialized from a dictionary with a Toolset."""
+        toolset = Toolset(tools)
+        component = HuggingFaceLocalChatGenerator(model="irrelevant", tools=toolset)
+        data = component.to_dict()
+
+        deserialized_component = HuggingFaceLocalChatGenerator.from_dict(data)
+
+        assert isinstance(deserialized_component.tools, Toolset)
+        assert len(deserialized_component.tools) == len(tools)
+        assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
+
+    def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_tokenizer, tools):
+        """Test that the HuggingFaceLocalChatGenerator can be serialized to a dictionary with a Toolset."""
+        toolset = Toolset(tools)
+        generator = HuggingFaceLocalChatGenerator(huggingface_pipeline_kwargs={"model": "irrelevant"}, tools=toolset)
+        generator.pipeline = mock_pipeline_tokenizer
+        data = generator.to_dict()
+
+        expected_tools_data = {
+            "type": "haystack.tools.toolset.Toolset",
+            "data": {
+                "tools": [
+                    {
+                        "type": "haystack.tools.tool.Tool",
+                        "data": {
+                            "name": "weather",
+                            "description": "useful to determine the weather in a given location",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {"city": {"type": "string"}},
+                                "required": ["city"],
+                            },
+                            "function": "generators.chat.test_hugging_face_local.get_weather",
+                            "outputs_to_string": None,
+                            "inputs_from_state": None,
+                            "outputs_to_state": None,
+                        },
+                    }
+                ]
+            },
+        }
+        assert data["init_parameters"]["tools"] == expected_tools_data