diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py
index 9a858a2f9..8ddf38530 100644
--- a/haystack/components/generators/chat/hugging_face_api.py
+++ b/haystack/components/generators/chat/hugging_face_api.py
@@ -3,10 +3,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from datetime import datetime
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
@@ -15,6 +15,7 @@ from haystack.utils.url_validation import is_valid_http_url
 
 with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import (
+        AsyncInferenceClient,
         ChatCompletionInputFunctionDefinition,
         ChatCompletionInputTool,
         ChatCompletionOutput,
@@ -181,6 +182,7 @@ class HuggingFaceAPIChatGenerator:
         self.generation_kwargs = generation_kwargs
         self.streaming_callback = streaming_callback
         self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
+        self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
         self.tools = tools
 
     def to_dict(self) -> Dict[str, Any]:
@@ -250,7 +252,11 @@ class HuggingFaceAPIChatGenerator:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
         _check_duplicate_tool_names(tools)
 
-        streaming_callback = streaming_callback or self.streaming_callback
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(
+            self.streaming_callback, streaming_callback, requires_async=False
+        )  # type: ignore
+
         if streaming_callback:
             return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)
 
@@ -267,6 +273,63 @@ class HuggingFaceAPIChatGenerator:
             ]
         return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
 
+    @component.output_types(replies=List[ChatMessage])
+    async def run_async(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+    ):
+        """
+        Asynchronously invokes the text generation inference based on the provided messages and generation parameters.
+
+        This is the asynchronous version of the `run` method. It has the same parameters and return values
+        but can be used with `await` in asynchronous code.
+
+        :param messages:
+            A list of ChatMessage objects representing the input messages.
+        :param generation_kwargs:
+            Additional keyword arguments for text generation.
+        :param tools:
+            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
+            during component initialization.
+        :param streaming_callback:
+            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
+            parameter set during component initialization.
+        :returns: A dictionary with the following keys:
+            - `replies`: A list containing the generated responses as ChatMessage objects.
+        """
+
+        # update generation kwargs by merging with the default ones
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        formatted_messages = [convert_message_to_hf_format(message) for message in messages]
+
+        tools = tools or self.tools
+        if tools and self.streaming_callback:
+            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+        _check_duplicate_tool_names(tools)
+
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)  # type: ignore
+
+        if streaming_callback:
+            return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)
+
+        hf_tools = None
+        if tools:
+            hf_tools = [
+                ChatCompletionInputTool(
+                    function=ChatCompletionInputFunctionDefinition(
+                        name=tool.name, description=tool.description, arguments=tool.parameters
+                    ),
+                    type="function",
+                )
+                for tool in tools
+            ]
+        return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
+
     def _run_streaming(
         self,
         messages: List[Dict[str, str]],
         generation_kwargs: Dict[str, Any],
@@ -359,3 +422,89 @@ class HuggingFaceAPIChatGenerator:
 
         message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
         return {"replies": [message]}
+
+    async def _run_streaming_async(
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        streaming_callback: Callable[[StreamingChunk], None],
+    ):
+        api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
+            messages, stream=True, **generation_kwargs
+        )
+
+        generated_text = ""
+        first_chunk_time = None
+
+        async for chunk in api_output:
+            choice = chunk.choices[0]
+
+            text = choice.delta.content or ""
+            generated_text += text
+
+            finish_reason = choice.finish_reason
+
+            meta: Dict[str, Any] = {}
+            if finish_reason:
+                meta["finish_reason"] = finish_reason
+
+            if first_chunk_time is None:
+                first_chunk_time = datetime.now().isoformat()
+
+            stream_chunk = StreamingChunk(text, meta)
+            await streaming_callback(stream_chunk)  # type: ignore
+
+        meta.update(
+            {
+                "model": self._async_client.model,
+                "finish_reason": finish_reason,
+                "index": 0,
+                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
+                "completion_start_time": first_chunk_time,
+            }
+        )
+
+        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
+        return {"replies": [message]}
+
+    async def _run_non_streaming_async(
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        tools: Optional[List["ChatCompletionInputTool"]] = None,
+    ) -> Dict[str, List[ChatMessage]]:
+        api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
+            messages=messages, tools=tools, **generation_kwargs
+        )
+
+        if len(api_chat_output.choices) == 0:
+            return {"replies": []}
+
+        choice = api_chat_output.choices[0]
+
+        text = choice.message.content
+        tool_calls = []
+
+        if hfapi_tool_calls := choice.message.tool_calls:
+            for hfapi_tc in hfapi_tool_calls:
+                tool_call = ToolCall(
+                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
+                )
+                tool_calls.append(tool_call)
+
+        meta: Dict[str, Any] = {
+            "model": self._async_client.model,
+            "finish_reason": choice.finish_reason,
+            "index": choice.index,
+        }
+
+        usage = {"prompt_tokens": 0, "completion_tokens": 0}
+        if api_chat_output.usage:
+            usage = {
+                "prompt_tokens": api_chat_output.usage.prompt_tokens,
+                "completion_tokens": api_chat_output.usage.completion_tokens,
+            }
+        meta["usage"] = usage
+
+        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
+        return {"replies": [message]}
diff --git a/pyproject.toml b/pyproject.toml
index f4175fd14..1f81d2101 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -278,6 +278,7 @@ markers = [
 ]
 log_cli = true
 asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "class"
 
 [tool.mypy]
 warn_return_any = false
diff --git a/releasenotes/notes/add-run-async-to-huggingfaceapichatgenerator-312439de9951bfdc.yaml b/releasenotes/notes/add-run-async-to-huggingfaceapichatgenerator-312439de9951bfdc.yaml
new file mode 100644
index 000000000..8a13bc835
--- /dev/null
+++ b/releasenotes/notes/add-run-async-to-huggingfaceapichatgenerator-312439de9951bfdc.yaml
@@ -0,0 +1,6 @@
+---
+features:
+  - |
+    Add a `run_async` method to `HuggingFaceAPIChatGenerator`. The method relies internally on the `AsyncInferenceClient`
+    from `huggingface_hub` to generate chat completions and supports the same parameters as the `run` method.
+    It returns a coroutine that can be awaited.
diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py
index 13fe00990..8a413d59b 100644
--- a/test/components/generators/chat/test_hugging_face_api.py
+++ b/test/components/generators/chat/test_hugging_face_api.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from datetime import datetime
 import os
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock, AsyncMock, patch
 
 import pytest
 
 from haystack import Pipeline
@@ -81,6 +81,31 @@ def mock_chat_completion():
     yield mock_chat_completion
 
 
+@pytest.fixture
+def mock_chat_completion_async():
+    with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion:
+        completion = ChatCompletionOutput(
+            choices=[
+                ChatCompletionOutputComplete(
+                    finish_reason="eos_token",
+                    index=0,
+                    message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"),
+                )
+            ],
+            id="some_id",
+            model="some_model",
+            system_fingerprint="some_fingerprint",
+            usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25),
+            created=1710498360,
+        )
+
+        # Use AsyncMock to properly mock the async method
+        mock_chat_completion.return_value = completion
+        mock_chat_completion.__call__ = AsyncMock(return_value=completion)
+
+        yield mock_chat_completion
+
+
 # used to test serialization of streaming_callback
 def streaming_callback_handler(x):
     return x
@@ -112,6 +137,10 @@ class TestHuggingFaceAPIChatGenerator:
         assert generator.streaming_callback == streaming_callback
         assert generator.tools is None
 
+        # check that client and async_client are initialized
+        assert generator._client.model == model
+        assert generator._async_client.model == model
+
     def test_init_serverless_with_tools(self, mock_check_valid_model, tools):
         model = "HuggingFaceH4/zephyr-7b-alpha"
         generation_kwargs = {"temperature": 0.6}
@@ -134,6 +163,9 @@ class TestHuggingFaceAPIChatGenerator:
         assert generator.streaming_callback == streaming_callback
         assert generator.tools == tools
 
+        assert generator._client.model == model
+        assert generator._async_client.model == model
+
     def test_init_serverless_invalid_model(self, mock_check_valid_model):
         mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
         with pytest.raises(RepositoryNotFoundError):
@@ -168,6 +200,9 @@ class TestHuggingFaceAPIChatGenerator:
         assert generator.streaming_callback == streaming_callback
         assert generator.tools is None
 
+        assert generator._client.model == url
+        assert generator._async_client.model == url
+
     def test_init_tgi_invalid_url(self):
         with pytest.raises(ValueError):
             HuggingFaceAPIChatGenerator(
@@ -623,3 +658,176 @@ class TestHuggingFaceAPIChatGenerator:
         assert not final_message.tool_calls
         assert len(final_message.text) > 0
         assert "paris" in final_message.text.lower()
+
+    @pytest.mark.asyncio
+    async def test_run_async(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
+            generation_kwargs={"temperature": 0.6},
+            stop_words=["stop", "words"],
+            streaming_callback=None,
+        )
+
+        response = await generator.run_async(messages=chat_messages)
+
+        # check kwargs passed to chat_completion
+        _, kwargs = mock_chat_completion_async.call_args
+        hf_messages = [
+            {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
+            {"role": "user", "content": "Tell me about Berlin"},
+        ]
+        assert kwargs == {
+            "temperature": 0.6,
+            "stop": ["stop", "words"],
+            "max_tokens": 512,
+            "tools": None,
+            "messages": hf_messages,
+        }
+
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+
+    @pytest.mark.asyncio
+    async def test_run_async_with_streaming(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
+        streaming_call_count = 0
+
+        async def streaming_callback_fn(chunk: StreamingChunk):
+            nonlocal streaming_call_count
+            streaming_call_count += 1
+            assert isinstance(chunk, StreamingChunk)
+
+        # Create a fake streamed response
+        async def mock_aiter(self):
+            yield ChatCompletionStreamOutput(
+                choices=[
+                    ChatCompletionStreamOutputChoice(
+                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
+                        index=0,
+                        finish_reason=None,
+                    )
+                ],
+                id="some_id",
+                model="some_model",
+                system_fingerprint="some_fingerprint",
+                created=1710498504,
+            )
+
+            yield ChatCompletionStreamOutput(
+                choices=[
+                    ChatCompletionStreamOutputChoice(
+                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
+                    )
+                ],
+                id="some_id",
+                model="some_model",
+                system_fingerprint="some_fingerprint",
+                created=1710498504,
+            )
+
+        mock_response = Mock(**{"__aiter__": mock_aiter})
+        mock_chat_completion_async.return_value = mock_response
+
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
+            streaming_callback=streaming_callback_fn,
+        )
+
+        response = await generator.run_async(messages=chat_messages)
+
+        # check kwargs passed to chat_completion
+        _, kwargs = mock_chat_completion_async.call_args
+        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
+
+        # Assert that the streaming callback was called twice
+        assert streaming_call_count == 2
+
+        # Assert that the response contains the generated replies
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) > 0
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+
+    @pytest.mark.asyncio
+    async def test_run_async_with_tools(self, tools, mock_check_valid_model):
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
+            tools=tools,
+        )
+
+        with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion_async:
+            completion = ChatCompletionOutput(
+                choices=[
+                    ChatCompletionOutputComplete(
+                        finish_reason="stop",
+                        index=0,
+                        message=ChatCompletionOutputMessage(
+                            role="assistant",
+                            content=None,
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        arguments={"city": "Paris"}, name="weather", description=None
+                                    ),
+                                    id="0",
+                                    type="function",
+                                )
+                            ],
+                        ),
+                        logprobs=None,
+                    )
+                ],
+                created=1729074760,
+                id="",
+                model="meta-llama/Llama-3.1-70B-Instruct",
+                system_fingerprint="2.3.2-dev0-sha-28bb7ae",
+                usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
+            )
+            mock_chat_completion_async.return_value = completion
+
+            messages = [ChatMessage.from_user("What is the weather in Paris?")]
+            response = await generator.run_async(messages=messages)
+
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+        assert response["replies"][0].tool_calls[0].tool_name == "weather"
+        assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
+        assert response["replies"][0].tool_calls[0].id == "0"
+        assert response["replies"][0].meta == {
+            "finish_reason": "stop",
+            "index": 0,
+            "model": "meta-llama/Llama-3.1-70B-Instruct",
+            "usage": {"completion_tokens": 30, "prompt_tokens": 426},
+        }
+
+    @pytest.mark.integration
+    @pytest.mark.skipif(
+        not os.environ.get("HF_API_TOKEN", None),
+        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+    )
+    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    async def test_live_run_async_serverless(self):
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+            generation_kwargs={"max_tokens": 20},
+        )
+
+        messages = [ChatMessage.from_user("What is the capital of France?")]
+        response = await generator.run_async(messages=messages)
+
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) > 0
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+        assert "usage" in response["replies"][0].meta
+        assert "prompt_tokens" in response["replies"][0].meta["usage"]
+        assert "completion_tokens" in response["replies"][0].meta["usage"]
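
Reviewer note (not part of the patch): a minimal usage sketch of the new `run_async` method. It assumes `haystack-ai` with `huggingface_hub[inference]` installed and HF_API_TOKEN exported; the model name and prompts are placeholders. The async streaming callback reflects the `requires_async=True` check introduced above, which expects a coroutine function when streaming through `run_async`.

# Usage sketch only; model name and prompts are placeholders.
import asyncio

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.hf import HFGenerationAPIType


async def main():
    # Assumes HF_API_TOKEN is set so the serverless Inference API can be reached.
    generator = HuggingFaceAPIChatGenerator(
        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
        generation_kwargs={"max_tokens": 50},
    )

    # run_async returns a coroutine with the same {"replies": [...]} shape as run.
    result = await generator.run_async(messages=[ChatMessage.from_user("What is the capital of France?")])
    print(result["replies"][0].text)

    # With requires_async=True, the per-call streaming callback must be a coroutine function.
    async def on_chunk(chunk: StreamingChunk):
        print(chunk.content, end="", flush=True)

    await generator.run_async(messages=[ChatMessage.from_user("Tell me about Berlin")], streaming_callback=on_chunk)


asyncio.run(main())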