feat: add run_async to HuggingfaceAPIChatGenerator (#8943)

* add run_async

* add release notes

* Add integration test
This commit is contained in:
Amna Mubashar 2025-03-03 20:51:30 +05:00 committed by GitHub
parent 1b2053b358
commit 28db039bca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 368 additions and 4 deletions

View File

@ -3,10 +3,10 @@
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, Union
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
@ -15,6 +15,7 @@ from haystack.utils.url_validation import is_valid_http_url
with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
from huggingface_hub import (
AsyncInferenceClient,
ChatCompletionInputFunctionDefinition,
ChatCompletionInputTool,
ChatCompletionOutput,
@ -181,6 +182,7 @@ class HuggingFaceAPIChatGenerator:
self.generation_kwargs = generation_kwargs
self.streaming_callback = streaming_callback
self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
self.tools = tools
def to_dict(self) -> Dict[str, Any]:
@ -250,7 +252,11 @@ class HuggingFaceAPIChatGenerator:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)
streaming_callback = streaming_callback or self.streaming_callback
# validate and select the streaming callback
streaming_callback = select_streaming_callback(
self.streaming_callback, streaming_callback, requires_async=False
) # type: ignore
if streaming_callback:
return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)
@ -267,6 +273,63 @@ class HuggingFaceAPIChatGenerator:
]
return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
@component.output_types(replies=List[ChatMessage])
async def run_async(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Asynchronously invokes the text generation inference based on the provided messages and generation parameters.
This is the asynchronous version of the `run` method. It has the same parameters
and return values but can be used with `await` in an async code.
:param messages:
A list of ChatMessage objects representing the input messages.
:param generation_kwargs:
Additional keyword arguments for text generation.
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
:param streaming_callback:
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
parameter set during component initialization.
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage objects.
"""
# update generation kwargs by merging with the default ones
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
formatted_messages = [convert_message_to_hf_format(message) for message in messages]
tools = tools or self.tools
if tools and self.streaming_callback:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)
# validate and select the streaming callback
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True) # type: ignore
if streaming_callback:
return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)
hf_tools = None
if tools:
hf_tools = [
ChatCompletionInputTool(
function=ChatCompletionInputFunctionDefinition(
name=tool.name, description=tool.description, arguments=tool.parameters
),
type="function",
)
for tool in tools
]
return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
def _run_streaming(
self,
messages: List[Dict[str, str]],
@ -359,3 +422,89 @@ class HuggingFaceAPIChatGenerator:
message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
return {"replies": [message]}
async def _run_streaming_async(
self,
messages: List[Dict[str, str]],
generation_kwargs: Dict[str, Any],
streaming_callback: Callable[[StreamingChunk], None],
):
api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
messages, stream=True, **generation_kwargs
)
generated_text = ""
first_chunk_time = None
async for chunk in api_output:
choice = chunk.choices[0]
text = choice.delta.content or ""
generated_text += text
finish_reason = choice.finish_reason
meta: Dict[str, Any] = {}
if finish_reason:
meta["finish_reason"] = finish_reason
if first_chunk_time is None:
first_chunk_time = datetime.now().isoformat()
stream_chunk = StreamingChunk(text, meta)
await streaming_callback(stream_chunk) # type: ignore
meta.update(
{
"model": self._async_client.model,
"finish_reason": finish_reason,
"index": 0,
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
"completion_start_time": first_chunk_time,
}
)
message = ChatMessage.from_assistant(text=generated_text, meta=meta)
return {"replies": [message]}
async def _run_non_streaming_async(
self,
messages: List[Dict[str, str]],
generation_kwargs: Dict[str, Any],
tools: Optional[List["ChatCompletionInputTool"]] = None,
) -> Dict[str, List[ChatMessage]]:
api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
messages=messages, tools=tools, **generation_kwargs
)
if len(api_chat_output.choices) == 0:
return {"replies": []}
choice = api_chat_output.choices[0]
text = choice.message.content
tool_calls = []
if hfapi_tool_calls := choice.message.tool_calls:
for hfapi_tc in hfapi_tool_calls:
tool_call = ToolCall(
tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
)
tool_calls.append(tool_call)
meta: Dict[str, Any] = {
"model": self._async_client.model,
"finish_reason": choice.finish_reason,
"index": choice.index,
}
usage = {"prompt_tokens": 0, "completion_tokens": 0}
if api_chat_output.usage:
usage = {
"prompt_tokens": api_chat_output.usage.prompt_tokens,
"completion_tokens": api_chat_output.usage.completion_tokens,
}
meta["usage"] = usage
message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
return {"replies": [message]}

View File

@ -278,6 +278,7 @@ markers = [
]
log_cli = true
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "class"
[tool.mypy]
warn_return_any = false

View File

@ -0,0 +1,6 @@
---
features:
- |
Add `run_async` method to HuggingFaceAPIChatGenerator. This method relies internally on the `AsyncInferenceClient` from huggingface
to generate chat completions and supports the same parameters as the `run` method. It returns a coroutine
that can be awaited.

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
import os
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, Mock, AsyncMock, patch
import pytest
from haystack import Pipeline
@ -81,6 +81,31 @@ def mock_chat_completion():
yield mock_chat_completion
@pytest.fixture
def mock_chat_completion_async():
with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion:
completion = ChatCompletionOutput(
choices=[
ChatCompletionOutputComplete(
finish_reason="eos_token",
index=0,
message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"),
)
],
id="some_id",
model="some_model",
system_fingerprint="some_fingerprint",
usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25),
created=1710498360,
)
# Use AsyncMock to properly mock the async method
mock_chat_completion.return_value = completion
mock_chat_completion.__call__ = AsyncMock(return_value=completion)
yield mock_chat_completion
# used to test serialization of streaming_callback
def streaming_callback_handler(x):
return x
@ -112,6 +137,10 @@ class TestHuggingFaceAPIChatGenerator:
assert generator.streaming_callback == streaming_callback
assert generator.tools is None
# check that client and async_client are initialized
assert generator._client.model == model
assert generator._async_client.model == model
def test_init_serverless_with_tools(self, mock_check_valid_model, tools):
model = "HuggingFaceH4/zephyr-7b-alpha"
generation_kwargs = {"temperature": 0.6}
@ -134,6 +163,9 @@ class TestHuggingFaceAPIChatGenerator:
assert generator.streaming_callback == streaming_callback
assert generator.tools == tools
assert generator._client.model == model
assert generator._async_client.model == model
def test_init_serverless_invalid_model(self, mock_check_valid_model):
mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
with pytest.raises(RepositoryNotFoundError):
@ -168,6 +200,9 @@ class TestHuggingFaceAPIChatGenerator:
assert generator.streaming_callback == streaming_callback
assert generator.tools is None
assert generator._client.model == url
assert generator._async_client.model == url
def test_init_tgi_invalid_url(self):
with pytest.raises(ValueError):
HuggingFaceAPIChatGenerator(
@ -623,3 +658,176 @@ class TestHuggingFaceAPIChatGenerator:
assert not final_message.tool_calls
assert len(final_message.text) > 0
assert "paris" in final_message.text.lower()
@pytest.mark.asyncio
async def test_run_async(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
generator = HuggingFaceAPIChatGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
generation_kwargs={"temperature": 0.6},
stop_words=["stop", "words"],
streaming_callback=None,
)
response = await generator.run_async(messages=chat_messages)
# check kwargs passed to chat_completion
_, kwargs = mock_chat_completion_async.call_args
hf_messages = [
{"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
{"role": "user", "content": "Tell me about Berlin"},
]
assert kwargs == {
"temperature": 0.6,
"stop": ["stop", "words"],
"max_tokens": 512,
"tools": None,
"messages": hf_messages,
}
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
@pytest.mark.asyncio
async def test_run_async_with_streaming(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
streaming_call_count = 0
async def streaming_callback_fn(chunk: StreamingChunk):
nonlocal streaming_call_count
streaming_call_count += 1
assert isinstance(chunk, StreamingChunk)
# Create a fake streamed response
async def mock_aiter(self):
yield ChatCompletionStreamOutput(
choices=[
ChatCompletionStreamOutputChoice(
delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
index=0,
finish_reason=None,
)
],
id="some_id",
model="some_model",
system_fingerprint="some_fingerprint",
created=1710498504,
)
yield ChatCompletionStreamOutput(
choices=[
ChatCompletionStreamOutputChoice(
delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
)
],
id="some_id",
model="some_model",
system_fingerprint="some_fingerprint",
created=1710498504,
)
mock_response = Mock(**{"__aiter__": mock_aiter})
mock_chat_completion_async.return_value = mock_response
generator = HuggingFaceAPIChatGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
streaming_callback=streaming_callback_fn,
)
response = await generator.run_async(messages=chat_messages)
# check kwargs passed to chat_completion
_, kwargs = mock_chat_completion_async.call_args
assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
# Assert that the streaming callback was called twice
assert streaming_call_count == 2
# Assert that the response contains the generated replies
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) > 0
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
@pytest.mark.asyncio
async def test_run_async_with_tools(self, tools, mock_check_valid_model):
generator = HuggingFaceAPIChatGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
tools=tools,
)
with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion_async:
completion = ChatCompletionOutput(
choices=[
ChatCompletionOutputComplete(
finish_reason="stop",
index=0,
message=ChatCompletionOutputMessage(
role="assistant",
content=None,
tool_calls=[
ChatCompletionOutputToolCall(
function=ChatCompletionOutputFunctionDefinition(
arguments={"city": "Paris"}, name="weather", description=None
),
id="0",
type="function",
)
],
),
logprobs=None,
)
],
created=1729074760,
id="",
model="meta-llama/Llama-3.1-70B-Instruct",
system_fingerprint="2.3.2-dev0-sha-28bb7ae",
usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
)
mock_chat_completion_async.return_value = completion
messages = [ChatMessage.from_user("What is the weather in Paris?")]
response = await generator.run_async(messages=messages)
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert response["replies"][0].tool_calls[0].tool_name == "weather"
assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
assert response["replies"][0].tool_calls[0].id == "0"
assert response["replies"][0].meta == {
"finish_reason": "stop",
"index": 0,
"model": "meta-llama/Llama-3.1-70B-Instruct",
"usage": {"completion_tokens": 30, "prompt_tokens": 426},
}
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("HF_API_TOKEN", None),
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
)
@pytest.mark.flaky(reruns=3, reruns_delay=10)
async def test_live_run_async_serverless(self):
generator = HuggingFaceAPIChatGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
generation_kwargs={"max_tokens": 20},
)
messages = [ChatMessage.from_user("What is the capital of France?")]
response = await generator.run_async(messages=messages)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) > 0
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "usage" in response["replies"][0].meta
assert "prompt_tokens" in response["replies"][0].meta["usage"]
assert "completion_tokens" in response["replies"][0].meta["usage"]