feat: add run_async to HuggingfaceAPIChatGenerator (#8943)

* add run_async

* add release notes

* Add integration test
Amna Mubashar 2025-03-03 20:51:30 +05:00 committed by GitHub
parent 1b2053b358
commit 28db039bca
4 changed files with 368 additions and 4 deletions


@@ -3,10 +3,10 @@
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
@@ -15,6 +15,7 @@ from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
+        AsyncInferenceClient,
        ChatCompletionInputFunctionDefinition,
        ChatCompletionInputTool,
        ChatCompletionOutput,
@@ -181,6 +182,7 @@ class HuggingFaceAPIChatGenerator:
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
+        self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self.tools = tools

    def to_dict(self) -> Dict[str, Any]:
@@ -250,7 +252,11 @@ class HuggingFaceAPIChatGenerator:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

-        streaming_callback = streaming_callback or self.streaming_callback
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(
+            self.streaming_callback, streaming_callback, requires_async=False
+        )  # type: ignore

        if streaming_callback:
            return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)
@@ -267,6 +273,63 @@ class HuggingFaceAPIChatGenerator:
            ]

        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)

    @component.output_types(replies=List[ChatMessage])
    async def run_async(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Tool]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Asynchronously invokes the text generation inference based on the provided messages and generation parameters.

        This is the asynchronous version of the `run` method. It has the same parameters
        and return values but can be used with `await` in async code.

        :param messages:
            A list of ChatMessage objects representing the input messages.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
            during component initialization.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """
        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        tools = tools or self.tools
        if tools and self.streaming_callback:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        # validate and select the streaming callback
        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)  # type: ignore

        if streaming_callback:
            return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)

        hf_tools = None
        if tools:
            hf_tools = [
                ChatCompletionInputTool(
                    function=ChatCompletionInputFunctionDefinition(
                        name=tool.name, description=tool.description, arguments=tool.parameters
                    ),
                    type="function",
                )
                for tool in tools
            ]

        return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)

    def _run_streaming(
        self,
        messages: List[Dict[str, str]],
@@ -359,3 +422,89 @@ class HuggingFaceAPIChatGenerator:
        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}

    async def _run_streaming_async(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        streaming_callback: Callable[[StreamingChunk], None],
    ):
        api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        first_chunk_time = None

        async for chunk in api_output:
            choice = chunk.choices[0]

            text = choice.delta.content or ""
            generated_text += text

            finish_reason = choice.finish_reason

            meta: Dict[str, Any] = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            stream_chunk = StreamingChunk(text, meta)
            await streaming_callback(stream_chunk)  # type: ignore

        meta.update(
            {
                "model": self._async_client.model,
                "finish_reason": finish_reason,
                "index": 0,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
                "completion_start_time": first_chunk_time,
            }
        )

        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
        return {"replies": [message]}

    async def _run_non_streaming_async(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        choice = api_chat_output.choices[0]

        text = choice.message.content
        tool_calls = []

        if hfapi_tool_calls := choice.message.tool_calls:
            for hfapi_tc in hfapi_tool_calls:
                tool_call = ToolCall(
                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
                )
                tool_calls.append(tool_call)

        meta: Dict[str, Any] = {
            "model": self._async_client.model,
            "finish_reason": choice.finish_reason,
            "index": choice.index,
        }

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}
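
Editor's note: to make the control flow above concrete, here is a minimal usage sketch of the async streaming path. It is not part of this diff: the model id and the print_chunk helper are illustrative, and a valid HF_API_TOKEN is assumed in the environment. Because run_async selects the callback with requires_async=True, the callback must be an async callable.

import asyncio

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.hf import HFGenerationAPIType

# Async callback: _run_streaming_async awaits it once per streamed chunk.
async def print_chunk(chunk: StreamingChunk):
    print(chunk.content, end="", flush=True)

async def main():
    generator = HuggingFaceAPIChatGenerator(
        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # illustrative model id
        streaming_callback=print_chunk,
    )
    await generator.run_async(messages=[ChatMessage.from_user("Tell me about Berlin")])

asyncio.run(main())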


@@ -278,6 +278,7 @@ markers = [
]
log_cli = true
asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "class"

[tool.mypy]
warn_return_any = false
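
Editor's note on the pyproject change: with pytest-asyncio, asyncio_mode = "auto" collects plain async def tests without an explicit marker, and asyncio_default_fixture_loop_scope = "class" makes async fixtures default to one event loop per test class (recent pytest-asyncio versions warn when this scope is left unset). A minimal sketch of a test that relies on this configuration; the class and test names are illustrative:

import asyncio

class TestAutoAsyncio:
    # Collected and executed as a coroutine thanks to asyncio_mode = "auto";
    # no @pytest.mark.asyncio marker is needed.
    async def test_sleep_then_assert(self):
        await asyncio.sleep(0)
        assert True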


@@ -0,0 +1,6 @@
---
features:
  - |
    Add `run_async` method to `HuggingFaceAPIChatGenerator`. This method relies internally on the `AsyncInferenceClient`
    from Hugging Face to generate chat completions and supports the same parameters as the `run` method. It returns a
    coroutine that can be awaited.
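
Editor's note: a minimal sketch of what the release note describes, assuming a serverless model id and an HF_API_TOKEN in the environment (both illustrative, not part of this diff):

import asyncio

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils.hf import HFGenerationAPIType

async def main():
    generator = HuggingFaceAPIChatGenerator(
        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # illustrative model id
    )
    # Same parameters and return shape as run(); the coroutine must be awaited.
    result = await generator.run_async(messages=[ChatMessage.from_user("What is the capital of France?")])
    print(result["replies"][0].text)

asyncio.run(main())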


@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
import os
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock, AsyncMock, patch

import pytest

from haystack import Pipeline
@@ -81,6 +81,31 @@ def mock_chat_completion():
        yield mock_chat_completion

@pytest.fixture
def mock_chat_completion_async():
    with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion:
        completion = ChatCompletionOutput(
            choices=[
                ChatCompletionOutputComplete(
                    finish_reason="eos_token",
                    index=0,
                    message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"),
                )
            ],
            id="some_id",
            model="some_model",
            system_fingerprint="some_fingerprint",
            usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25),
            created=1710498360,
        )

        # Use AsyncMock to properly mock the async method
        mock_chat_completion.return_value = completion
        mock_chat_completion.__call__ = AsyncMock(return_value=completion)

        yield mock_chat_completion

# used to test serialization of streaming_callback
def streaming_callback_handler(x):
    return x
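
Editor's note on why the fixture above works: patching a coroutine function with autospec=True produces an AsyncMock, so setting return_value is enough for the await inside run_async to resolve. A standalone sketch of that behavior; the stub string stands in for a real ChatCompletionOutput:

import asyncio
from unittest.mock import patch

import huggingface_hub

async def demo():
    with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mocked:
        mocked.return_value = "stub-completion"  # illustrative stand-in
        client = huggingface_hub.AsyncInferenceClient("some-model")
        # chat_completion is a coroutine function, so autospec yields an
        # AsyncMock; awaiting the call returns the configured return_value.
        assert await client.chat_completion(messages=[]) == "stub-completion"

asyncio.run(demo())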
@@ -112,6 +137,10 @@ class TestHuggingFaceAPIChatGenerator:
        assert generator.streaming_callback == streaming_callback
        assert generator.tools is None

        # check that client and async_client are initialized
        assert generator._client.model == model
        assert generator._async_client.model == model

    def test_init_serverless_with_tools(self, mock_check_valid_model, tools):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"temperature": 0.6}
@@ -134,6 +163,9 @@ class TestHuggingFaceAPIChatGenerator:
        assert generator.streaming_callback == streaming_callback
        assert generator.tools == tools

        assert generator._client.model == model
        assert generator._async_client.model == model

    def test_init_serverless_invalid_model(self, mock_check_valid_model):
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
        with pytest.raises(RepositoryNotFoundError):
@@ -168,6 +200,9 @@ class TestHuggingFaceAPIChatGenerator:
        assert generator.streaming_callback == streaming_callback
        assert generator.tools is None

        assert generator._client.model == url
        assert generator._async_client.model == url

    def test_init_tgi_invalid_url(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
@@ -623,3 +658,176 @@ class TestHuggingFaceAPIChatGenerator:
        assert not final_message.tool_calls
        assert len(final_message.text) > 0
        assert "paris" in final_message.text.lower()

    @pytest.mark.asyncio
    async def test_run_async(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            streaming_callback=None,
        )

        response = await generator.run_async(messages=chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion_async.call_args
        hf_messages = [
            {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
            {"role": "user", "content": "Tell me about Berlin"},
        ]
        assert kwargs == {
            "temperature": 0.6,
            "stop": ["stop", "words"],
            "max_tokens": 512,
            "tools": None,
            "messages": hf_messages,
        }

        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.asyncio
    async def test_run_async_with_streaming(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
        streaming_call_count = 0

        async def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        # Create a fake streamed response
        async def mock_aiter(self):
            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
                        index=0,
                        finish_reason=None,
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )
            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

        mock_response = Mock(**{"__aiter__": mock_aiter})
        mock_chat_completion_async.return_value = mock_response

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            streaming_callback=streaming_callback_fn,
        )

        response = await generator.run_async(messages=chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion_async.call_args
        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.asyncio
    async def test_run_async_with_tools(self, tools, mock_check_valid_model):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
            tools=tools,
        )

        with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion_async:
            completion = ChatCompletionOutput(
                choices=[
                    ChatCompletionOutputComplete(
                        finish_reason="stop",
                        index=0,
                        message=ChatCompletionOutputMessage(
                            role="assistant",
                            content=None,
                            tool_calls=[
                                ChatCompletionOutputToolCall(
                                    function=ChatCompletionOutputFunctionDefinition(
                                        arguments={"city": "Paris"}, name="weather", description=None
                                    ),
                                    id="0",
                                    type="function",
                                )
                            ],
                        ),
                        logprobs=None,
                    )
                ],
                created=1729074760,
                id="",
                model="meta-llama/Llama-3.1-70B-Instruct",
                system_fingerprint="2.3.2-dev0-sha-28bb7ae",
                usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
            )
            mock_chat_completion_async.return_value = completion

            messages = [ChatMessage.from_user("What is the weather in Paris?")]
            response = await generator.run_async(messages=messages)

            assert isinstance(response, dict)
            assert "replies" in response
            assert isinstance(response["replies"], list)
            assert len(response["replies"]) == 1
            assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
            assert response["replies"][0].tool_calls[0].tool_name == "weather"
            assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
            assert response["replies"][0].tool_calls[0].id == "0"
            assert response["replies"][0].meta == {
                "finish_reason": "stop",
                "index": 0,
                "model": "meta-llama/Llama-3.1-70B-Instruct",
                "usage": {"completion_tokens": 30, "prompt_tokens": 426},
            }

    @pytest.mark.integration
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    @pytest.mark.flaky(reruns=3, reruns_delay=10)
    async def test_live_run_async_serverless(self):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            generation_kwargs={"max_tokens": 20},
        )

        messages = [ChatMessage.from_user("What is the capital of France?")]
        response = await generator.run_async(messages=messages)

        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
        assert "usage" in response["replies"][0].meta
        assert "prompt_tokens" in response["replies"][0].meta["usage"]
        assert "completion_tokens" in response["replies"][0].meta["usage"]