# SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 from datetime import datetime import os from unittest.mock import MagicMock, Mock, AsyncMock, patch import pytest from haystack import Pipeline from haystack.dataclasses import StreamingChunk from haystack.utils.auth import Secret from haystack.utils.hf import HFGenerationAPIType from huggingface_hub import ( ChatCompletionOutput, ChatCompletionOutputComplete, ChatCompletionOutputFunctionDefinition, ChatCompletionOutputMessage, ChatCompletionOutputToolCall, ChatCompletionOutputUsage, ChatCompletionStreamOutput, ChatCompletionStreamOutputChoice, ChatCompletionStreamOutputDelta, ) from huggingface_hub.utils import RepositoryNotFoundError from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator, _convert_hfapi_tool_calls from haystack.tools import Tool from haystack.dataclasses import ChatMessage, ToolCall from haystack.tools.toolset import Toolset @pytest.fixture def chat_messages(): return [ ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"), ChatMessage.from_user("Tell me about Berlin"), ] def get_weather(city: str) -> str: """Get weather information for a city.""" return f"Weather info for {city}" @pytest.fixture def tools(): tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} tool = Tool( name="weather", description="useful to determine the weather in a given location", parameters=tool_parameters, function=get_weather, ) return [tool] @pytest.fixture def mock_check_valid_model(): with patch( "haystack.components.generators.chat.hugging_face_api.check_valid_model", MagicMock(return_value=None) ) as mock: yield mock @pytest.fixture def mock_chat_completion(): # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.example with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion: completion = ChatCompletionOutput( choices=[ ChatCompletionOutputComplete( finish_reason="eos_token", index=0, message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"), ) ], id="some_id", model="some_model", system_fingerprint="some_fingerprint", usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25), created=1710498360, ) mock_chat_completion.return_value = completion yield mock_chat_completion @pytest.fixture def mock_chat_completion_async(): with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion: completion = ChatCompletionOutput( choices=[ ChatCompletionOutputComplete( finish_reason="eos_token", index=0, message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"), ) ], id="some_id", model="some_model", system_fingerprint="some_fingerprint", usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25), created=1710498360, ) # Use AsyncMock to properly mock the async method mock_chat_completion.return_value = completion mock_chat_completion.__call__ = AsyncMock(return_value=completion) yield mock_chat_completion # used to test serialization of streaming_callback def streaming_callback_handler(x): return x class TestHuggingFaceAPIChatGenerator: def test_init_invalid_api_type(self): with pytest.raises(ValueError): HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={}) def test_init_serverless(self, mock_check_valid_model): model = "HuggingFaceH4/zephyr-7b-alpha" generation_kwargs = {"temperature": 0.6} stop_words = ["stop"] streaming_callback = None generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model}, token=None, generation_kwargs=generation_kwargs, stop_words=stop_words, streaming_callback=streaming_callback, ) assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API assert generator.api_params == {"model": model} assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} assert generator.streaming_callback == streaming_callback assert generator.tools is None # check that client and async_client are initialized assert generator._client.model == model assert generator._async_client.model == model def test_init_serverless_with_tools(self, mock_check_valid_model, tools): model = "HuggingFaceH4/zephyr-7b-alpha" generation_kwargs = {"temperature": 0.6} stop_words = ["stop"] streaming_callback = None generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model}, token=None, generation_kwargs=generation_kwargs, stop_words=stop_words, streaming_callback=streaming_callback, tools=tools, ) assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API assert generator.api_params == {"model": model} assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} assert generator.streaming_callback == streaming_callback assert generator.tools == tools assert generator._client.model == model assert generator._async_client.model == model def test_init_serverless_invalid_model(self, mock_check_valid_model): mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") with pytest.raises(RepositoryNotFoundError): HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"} ) def test_init_serverless_no_model(self): with pytest.raises(ValueError): HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"} ) def test_init_tgi(self): url = "https://some_model.com" generation_kwargs = {"temperature": 0.6} stop_words = ["stop"] streaming_callback = None generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": url}, token=None, generation_kwargs=generation_kwargs, stop_words=stop_words, streaming_callback=streaming_callback, ) assert generator.api_type == HFGenerationAPIType.TEXT_GENERATION_INFERENCE assert generator.api_params == {"url": url} assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} assert generator.streaming_callback == streaming_callback assert generator.tools is None assert generator._client.model == url assert generator._async_client.model == url def test_init_tgi_invalid_url(self): with pytest.raises(ValueError): HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": "invalid_url"} ) def test_init_tgi_no_url(self): with pytest.raises(ValueError): HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"} ) def test_init_fail_with_duplicate_tool_names(self, mock_check_valid_model, tools): duplicate_tools = [tools[0], tools[0]] with pytest.raises(ValueError): HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=duplicate_tools, ) def test_init_fail_with_tools_and_streaming(self, mock_check_valid_model, tools): with pytest.raises(ValueError): HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=tools, streaming_callback=streaming_callback_handler, ) def test_to_dict(self, mock_check_valid_model): tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], tools=[tool], ) result = generator.to_dict() init_params = result["init_parameters"] assert init_params["api_type"] == "serverless_inference_api" assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert init_params["token"] == {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"} assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} assert init_params["streaming_callback"] is None assert init_params["tools"] == [ { "type": "haystack.tools.tool.Tool", "data": { "description": "description", "function": "builtins.print", "inputs_from_state": None, "name": "name", "outputs_to_state": None, "outputs_to_string": None, "parameters": {"x": {"type": "string"}}, }, } ] def test_from_dict(self, mock_check_valid_model): tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], tools=[tool], ) result = generator.to_dict() # now deserialize, call from_dict generator_2 = HuggingFaceAPIChatGenerator.from_dict(result) assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False) assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} assert generator_2.streaming_callback is None assert generator_2.tools == [tool] def test_serde_in_pipeline(self, mock_check_valid_model): tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], tools=[tool], ) pipeline = Pipeline() pipeline.add_component("generator", generator) pipeline_dict = pipeline.to_dict() assert pipeline_dict == { "metadata": {}, "max_runs_per_component": 100, "connection_type_validation": True, "components": { "generator": { "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator", "init_parameters": { "api_type": "serverless_inference_api", "api_params": {"model": "HuggingFaceH4/zephyr-7b-beta"}, "token": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, "generation_kwargs": {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}, "streaming_callback": None, "tools": [ { "type": "haystack.tools.tool.Tool", "data": { "inputs_from_state": None, "name": "name", "outputs_to_state": None, "outputs_to_string": None, "description": "description", "parameters": {"x": {"type": "string"}}, "function": "builtins.print", }, } ], }, } }, "connections": [], } pipeline_yaml = pipeline.dumps() new_pipeline = Pipeline.loads(pipeline_yaml) assert new_pipeline == pipeline def test_run(self, mock_check_valid_model, mock_chat_completion, chat_messages): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], streaming_callback=None, ) response = generator.run(messages=chat_messages) # check kwargs passed to chat_completion _, kwargs = mock_chat_completion.call_args hf_messages = [ {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"}, {"role": "user", "content": "Tell me about Berlin"}, ] assert kwargs == { "temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512, "tools": None, "messages": hf_messages, } assert isinstance(response, dict) assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] def test_run_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages): streaming_call_count = 0 # Define the streaming callback function def streaming_callback_fn(chunk: StreamingChunk): nonlocal streaming_call_count streaming_call_count += 1 assert isinstance(chunk, StreamingChunk) generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, streaming_callback=streaming_callback_fn, ) # Create a fake streamed response # self needed here, don't remove def mock_iter(self): yield ChatCompletionStreamOutput( choices=[ ChatCompletionStreamOutputChoice( delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"), index=0, finish_reason=None, ) ], id="some_id", model="some_model", system_fingerprint="some_fingerprint", created=1710498504, ) yield ChatCompletionStreamOutput( choices=[ ChatCompletionStreamOutputChoice( delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length" ) ], id="some_id", model="some_model", system_fingerprint="some_fingerprint", created=1710498504, ) mock_response = Mock(**{"__iter__": mock_iter}) mock_chat_completion.return_value = mock_response # Generate text response with streaming callback response = generator.run(chat_messages) # check kwargs passed to text_generation _, kwargs = mock_chat_completion.call_args assert kwargs == {"stop": [], "stream": True, "max_tokens": 512} # Assert that the streaming callback was called twice assert streaming_call_count == 2 # Assert that the response contains the generated replies assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) > 0 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] def test_run_with_streaming_callback_in_run_method( self, mock_check_valid_model, mock_chat_completion, chat_messages ): streaming_call_count = 0 # Define the streaming callback function def streaming_callback_fn(chunk: StreamingChunk): nonlocal streaming_call_count streaming_call_count += 1 assert isinstance(chunk, StreamingChunk) generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, ) # Create a fake streamed response # self needed here, don't remove def mock_iter(self): yield ChatCompletionStreamOutput( choices=[ ChatCompletionStreamOutputChoice( delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"), index=0, finish_reason=None, ) ], id="some_id", model="some_model", system_fingerprint="some_fingerprint", created=1710498504, ) yield ChatCompletionStreamOutput( choices=[ ChatCompletionStreamOutputChoice( delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length" ) ], id="some_id", model="some_model", system_fingerprint="some_fingerprint", created=1710498504, ) mock_response = Mock(**{"__iter__": mock_iter}) mock_chat_completion.return_value = mock_response # Generate text response with streaming callback response = generator.run(chat_messages, streaming_callback=streaming_callback_fn) # check kwargs passed to text_generation _, kwargs = mock_chat_completion.call_args assert kwargs == {"stop": [], "stream": True, "max_tokens": 512} # Assert that the streaming callback was called twice assert streaming_call_count == 2 # Assert that the response contains the generated replies assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) > 0 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model): component = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, streaming_callback=streaming_callback_handler, ) with pytest.raises(ValueError): message = ChatMessage.from_user("irrelevant") component.run([message], tools=tools) def test_run_with_tools(self, mock_check_valid_model, tools): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"}, tools=tools, ) with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion: completion = ChatCompletionOutput( choices=[ ChatCompletionOutputComplete( finish_reason="stop", index=0, message=ChatCompletionOutputMessage( role="assistant", content=None, tool_calls=[ ChatCompletionOutputToolCall( function=ChatCompletionOutputFunctionDefinition( arguments={"city": "Paris"}, name="weather", description=None ), id="0", type="function", ) ], ), logprobs=None, ) ], created=1729074760, id="", model="meta-llama/Llama-3.1-70B-Instruct", system_fingerprint="2.3.2-dev0-sha-28bb7ae", usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456), ) mock_chat_completion.return_value = completion messages = [ChatMessage.from_user("What is the weather in Paris?")] response = generator.run(messages=messages) assert isinstance(response, dict) assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] assert response["replies"][0].tool_calls[0].tool_name == "weather" assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"} assert response["replies"][0].tool_calls[0].id == "0" assert response["replies"][0].meta == { "finish_reason": "stop", "index": 0, "model": "meta-llama/Llama-3.1-70B-Instruct", "usage": {"completion_tokens": 30, "prompt_tokens": 426}, } def test_convert_hfapi_tool_calls_empty(self): hfapi_tool_calls = None tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls) assert len(tool_calls) == 0 hfapi_tool_calls = [] tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls) assert len(tool_calls) == 0 def test_convert_hfapi_tool_calls_dict_arguments(self): hfapi_tool_calls = [ ChatCompletionOutputToolCall( function=ChatCompletionOutputFunctionDefinition( arguments={"city": "Paris"}, name="weather", description=None ), id="0", type="function", ) ] tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls) assert len(tool_calls) == 1 assert tool_calls[0].tool_name == "weather" assert tool_calls[0].arguments == {"city": "Paris"} assert tool_calls[0].id == "0" def test_convert_hfapi_tool_calls_str_arguments(self): hfapi_tool_calls = [ ChatCompletionOutputToolCall( function=ChatCompletionOutputFunctionDefinition( arguments='{"city": "Paris"}', name="weather", description=None ), id="0", type="function", ) ] tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls) assert len(tool_calls) == 1 assert tool_calls[0].tool_name == "weather" assert tool_calls[0].arguments == {"city": "Paris"} assert tool_calls[0].id == "0" def test_convert_hfapi_tool_calls_invalid_str_arguments(self): hfapi_tool_calls = [ ChatCompletionOutputToolCall( function=ChatCompletionOutputFunctionDefinition( arguments="not a valid JSON string", name="weather", description=None ), id="0", type="function", ) ] tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls) assert len(tool_calls) == 0 def test_convert_hfapi_tool_calls_invalid_type_arguments(self): hfapi_tool_calls = [ ChatCompletionOutputToolCall( function=ChatCompletionOutputFunctionDefinition( arguments=["this", "is", "a", "list"], name="weather", description=None ), id="0", type="function", ) ] tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls) assert len(tool_calls) == 0 @pytest.mark.integration @pytest.mark.slow @pytest.mark.skipif( not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) @pytest.mark.flaky(reruns=3, reruns_delay=10) def test_live_run_serverless(self): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "microsoft/Phi-3.5-mini-instruct"}, generation_kwargs={"max_tokens": 20}, ) # No need for instruction tokens here since we use the chat_completion endpoint which handles the chat # templating for us. messages = [ ChatMessage.from_user("What is the capital of France? Be concise only provide the capital, nothing else.") ] response = generator.run(messages=messages) assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) > 0 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] assert "usage" in response["replies"][0].meta assert "prompt_tokens" in response["replies"][0].meta["usage"] assert "completion_tokens" in response["replies"][0].meta["usage"] @pytest.mark.integration @pytest.mark.slow @pytest.mark.skipif( not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) @pytest.mark.flaky(reruns=3, reruns_delay=10) def test_live_run_serverless_streaming(self): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "microsoft/Phi-3.5-mini-instruct"}, generation_kwargs={"max_tokens": 20}, streaming_callback=streaming_callback_handler, ) # No need for instruction tokens here since we use the chat_completion endpoint which handles the chat # templating for us. messages = [ ChatMessage.from_user("What is the capital of France? Be concise only provide the capital, nothing else.") ] response = generator.run(messages=messages) assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) > 0 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] response_meta = response["replies"][0].meta assert "completion_start_time" in response_meta assert datetime.fromisoformat(response_meta["completion_start_time"]) <= datetime.now() assert "usage" in response_meta assert "prompt_tokens" in response_meta["usage"] assert "completion_tokens" in response_meta["usage"] @pytest.mark.integration @pytest.mark.slow @pytest.mark.skipif( not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) def test_live_run_with_tools(self, tools): """ We test the round trip: generate tool call, pass tool message, generate response. The model used here (Hermes-3-Llama-3.1-8B) is not gated and kept in a warm state. """ chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")] generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "NousResearch/Hermes-3-Llama-3.1-8B"}, generation_kwargs={"temperature": 0.5}, ) results = generator.run(chat_messages, tools=tools) assert len(results["replies"]) == 1 message = results["replies"][0] assert message.tool_calls tool_call = message.tool_call assert isinstance(tool_call, ToolCall) assert tool_call.tool_name == "weather" assert "city" in tool_call.arguments assert "Paris" in tool_call.arguments["city"] assert message.meta["finish_reason"] == "stop" new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)] # the model tends to make tool calls if provided with tools, so we don't pass them here results = generator.run(new_messages, generation_kwargs={"max_tokens": 50}) assert len(results["replies"]) == 1 final_message = results["replies"][0] assert not final_message.tool_calls assert len(final_message.text) > 0 assert "paris" in final_message.text.lower() @pytest.mark.asyncio async def test_run_async(self, mock_check_valid_model, mock_chat_completion_async, chat_messages): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], streaming_callback=None, ) response = await generator.run_async(messages=chat_messages) # check kwargs passed to chat_completion _, kwargs = mock_chat_completion_async.call_args hf_messages = [ {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"}, {"role": "user", "content": "Tell me about Berlin"}, ] assert kwargs == { "temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512, "tools": None, "messages": hf_messages, } assert isinstance(response, dict) assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] @pytest.mark.asyncio async def test_run_async_with_streaming(self, mock_check_valid_model, mock_chat_completion_async, chat_messages): streaming_call_count = 0 async def streaming_callback_fn(chunk: StreamingChunk): nonlocal streaming_call_count streaming_call_count += 1 assert isinstance(chunk, StreamingChunk) # Create a fake streamed response async def mock_aiter(self): yield ChatCompletionStreamOutput( choices=[ ChatCompletionStreamOutputChoice( delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"), index=0, finish_reason=None, ) ], id="some_id", model="some_model", system_fingerprint="some_fingerprint", created=1710498504, ) yield ChatCompletionStreamOutput( choices=[ ChatCompletionStreamOutputChoice( delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length" ) ], id="some_id", model="some_model", system_fingerprint="some_fingerprint", created=1710498504, ) mock_response = Mock(**{"__aiter__": mock_aiter}) mock_chat_completion_async.return_value = mock_response generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, streaming_callback=streaming_callback_fn, ) response = await generator.run_async(messages=chat_messages) # check kwargs passed to chat_completion _, kwargs = mock_chat_completion_async.call_args assert kwargs == {"stop": [], "stream": True, "max_tokens": 512} # Assert that the streaming callback was called twice assert streaming_call_count == 2 # Assert that the response contains the generated replies assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) > 0 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] @pytest.mark.asyncio async def test_run_async_with_tools(self, tools, mock_check_valid_model): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"}, tools=tools, ) with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion_async: completion = ChatCompletionOutput( choices=[ ChatCompletionOutputComplete( finish_reason="stop", index=0, message=ChatCompletionOutputMessage( role="assistant", content=None, tool_calls=[ ChatCompletionOutputToolCall( function=ChatCompletionOutputFunctionDefinition( arguments={"city": "Paris"}, name="weather", description=None ), id="0", type="function", ) ], ), logprobs=None, ) ], created=1729074760, id="", model="meta-llama/Llama-3.1-70B-Instruct", system_fingerprint="2.3.2-dev0-sha-28bb7ae", usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456), ) mock_chat_completion_async.return_value = completion messages = [ChatMessage.from_user("What is the weather in Paris?")] response = await generator.run_async(messages=messages) assert isinstance(response, dict) assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] assert response["replies"][0].tool_calls[0].tool_name == "weather" assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"} assert response["replies"][0].tool_calls[0].id == "0" assert response["replies"][0].meta == { "finish_reason": "stop", "index": 0, "model": "meta-llama/Llama-3.1-70B-Instruct", "usage": {"completion_tokens": 30, "prompt_tokens": 426}, } @pytest.mark.integration @pytest.mark.slow @pytest.mark.skipif( not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) @pytest.mark.flaky(reruns=3, reruns_delay=10) @pytest.mark.asyncio async def test_live_run_async_serverless(self): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "microsoft/Phi-3.5-mini-instruct"}, generation_kwargs={"max_tokens": 20}, ) messages = [ ChatMessage.from_user("What is the capital of France? Be concise only provide the capital, nothing else.") ] try: response = await generator.run_async(messages=messages) assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) > 0 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] assert "usage" in response["replies"][0].meta assert "prompt_tokens" in response["replies"][0].meta["usage"] assert "completion_tokens" in response["replies"][0].meta["usage"] finally: await generator._async_client.close() def test_hugging_face_api_generator_with_toolset_initialization(self, mock_check_valid_model, tools): """Test that the HuggingFaceAPIChatGenerator can be initialized with a Toolset.""" toolset = Toolset(tools) generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset ) assert generator.tools == toolset def test_from_dict_with_toolset(self, mock_check_valid_model, tools): """Test that the HuggingFaceAPIChatGenerator can be deserialized from a dictionary with a Toolset.""" toolset = Toolset(tools) component = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset ) data = component.to_dict() deserialized_component = HuggingFaceAPIChatGenerator.from_dict(data) assert isinstance(deserialized_component.tools, Toolset) assert len(deserialized_component.tools) == len(tools) assert all(isinstance(tool, Tool) for tool in deserialized_component.tools) def test_to_dict_with_toolset(self, mock_check_valid_model, tools): """Test that the HuggingFaceAPIChatGenerator can be serialized to a dictionary with a Toolset.""" toolset = Toolset(tools) generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset ) data = generator.to_dict() expected_tools_data = { "type": "haystack.tools.toolset.Toolset", "data": { "tools": [ { "type": "haystack.tools.tool.Tool", "data": { "name": "weather", "description": "useful to determine the weather in a given location", "parameters": { "type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"], }, "function": "generators.chat.test_hugging_face_api.get_weather", "outputs_to_string": None, "inputs_from_state": None, "outputs_to_state": None, }, } ] }, } assert data["init_parameters"]["tools"] == expected_tools_data