mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-13 15:57:24 +00:00
feat: add run_async to HuggingfaceAPIChatGenerator (#8943)
* add run_async * add release notes * Add integration test
This commit is contained in:
parent
1b2053b358
commit
28db039bca
@ -3,10 +3,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
|
||||
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
@ -15,6 +15,7 @@ from haystack.utils.url_validation import is_valid_http_url
|
||||
|
||||
with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
|
||||
from huggingface_hub import (
|
||||
AsyncInferenceClient,
|
||||
ChatCompletionInputFunctionDefinition,
|
||||
ChatCompletionInputTool,
|
||||
ChatCompletionOutput,
|
||||
@ -181,6 +182,7 @@ class HuggingFaceAPIChatGenerator:
|
||||
self.generation_kwargs = generation_kwargs
|
||||
self.streaming_callback = streaming_callback
|
||||
self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
|
||||
self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
|
||||
self.tools = tools
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
@ -250,7 +252,11 @@ class HuggingFaceAPIChatGenerator:
|
||||
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
||||
_check_duplicate_tool_names(tools)
|
||||
|
||||
streaming_callback = streaming_callback or self.streaming_callback
|
||||
# validate and select the streaming callback
|
||||
streaming_callback = select_streaming_callback(
|
||||
self.streaming_callback, streaming_callback, requires_async=False
|
||||
) # type: ignore
|
||||
|
||||
if streaming_callback:
|
||||
return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)
|
||||
|
||||
@ -267,6 +273,63 @@ class HuggingFaceAPIChatGenerator:
|
||||
]
|
||||
return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
|
||||
|
||||
@component.output_types(replies=List[ChatMessage])
|
||||
async def run_async(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
):
|
||||
"""
|
||||
Asynchronously invokes the text generation inference based on the provided messages and generation parameters.
|
||||
|
||||
This is the asynchronous version of the `run` method. It has the same parameters
|
||||
and return values but can be used with `await` in an async code.
|
||||
|
||||
:param messages:
|
||||
A list of ChatMessage objects representing the input messages.
|
||||
:param generation_kwargs:
|
||||
Additional keyword arguments for text generation.
|
||||
:param tools:
|
||||
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
|
||||
during component initialization.
|
||||
:param streaming_callback:
|
||||
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
|
||||
parameter set during component initialization.
|
||||
:returns: A dictionary with the following keys:
|
||||
- `replies`: A list containing the generated responses as ChatMessage objects.
|
||||
"""
|
||||
|
||||
# update generation kwargs by merging with the default ones
|
||||
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
|
||||
|
||||
formatted_messages = [convert_message_to_hf_format(message) for message in messages]
|
||||
|
||||
tools = tools or self.tools
|
||||
if tools and self.streaming_callback:
|
||||
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
|
||||
_check_duplicate_tool_names(tools)
|
||||
|
||||
# validate and select the streaming callback
|
||||
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True) # type: ignore
|
||||
|
||||
if streaming_callback:
|
||||
return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)
|
||||
|
||||
hf_tools = None
|
||||
if tools:
|
||||
hf_tools = [
|
||||
ChatCompletionInputTool(
|
||||
function=ChatCompletionInputFunctionDefinition(
|
||||
name=tool.name, description=tool.description, arguments=tool.parameters
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
for tool in tools
|
||||
]
|
||||
return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
|
||||
|
||||
def _run_streaming(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
@ -359,3 +422,89 @@ class HuggingFaceAPIChatGenerator:
|
||||
|
||||
message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
|
||||
return {"replies": [message]}
|
||||
|
||||
async def _run_streaming_async(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
generation_kwargs: Dict[str, Any],
|
||||
streaming_callback: Callable[[StreamingChunk], None],
|
||||
):
|
||||
api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
|
||||
messages, stream=True, **generation_kwargs
|
||||
)
|
||||
|
||||
generated_text = ""
|
||||
first_chunk_time = None
|
||||
|
||||
async for chunk in api_output:
|
||||
choice = chunk.choices[0]
|
||||
|
||||
text = choice.delta.content or ""
|
||||
generated_text += text
|
||||
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
meta: Dict[str, Any] = {}
|
||||
if finish_reason:
|
||||
meta["finish_reason"] = finish_reason
|
||||
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = datetime.now().isoformat()
|
||||
|
||||
stream_chunk = StreamingChunk(text, meta)
|
||||
await streaming_callback(stream_chunk) # type: ignore
|
||||
|
||||
meta.update(
|
||||
{
|
||||
"model": self._async_client.model,
|
||||
"finish_reason": finish_reason,
|
||||
"index": 0,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
"completion_start_time": first_chunk_time,
|
||||
}
|
||||
)
|
||||
|
||||
message = ChatMessage.from_assistant(text=generated_text, meta=meta)
|
||||
return {"replies": [message]}
|
||||
|
||||
async def _run_non_streaming_async(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
generation_kwargs: Dict[str, Any],
|
||||
tools: Optional[List["ChatCompletionInputTool"]] = None,
|
||||
) -> Dict[str, List[ChatMessage]]:
|
||||
api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
|
||||
messages=messages, tools=tools, **generation_kwargs
|
||||
)
|
||||
|
||||
if len(api_chat_output.choices) == 0:
|
||||
return {"replies": []}
|
||||
|
||||
choice = api_chat_output.choices[0]
|
||||
|
||||
text = choice.message.content
|
||||
tool_calls = []
|
||||
|
||||
if hfapi_tool_calls := choice.message.tool_calls:
|
||||
for hfapi_tc in hfapi_tool_calls:
|
||||
tool_call = ToolCall(
|
||||
tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
meta: Dict[str, Any] = {
|
||||
"model": self._async_client.model,
|
||||
"finish_reason": choice.finish_reason,
|
||||
"index": choice.index,
|
||||
}
|
||||
|
||||
usage = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
if api_chat_output.usage:
|
||||
usage = {
|
||||
"prompt_tokens": api_chat_output.usage.prompt_tokens,
|
||||
"completion_tokens": api_chat_output.usage.completion_tokens,
|
||||
}
|
||||
meta["usage"] = usage
|
||||
|
||||
message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
|
||||
return {"replies": [message]}
|
||||
|
||||
@ -278,6 +278,7 @@ markers = [
|
||||
]
|
||||
log_cli = true
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "class"
|
||||
|
||||
[tool.mypy]
|
||||
warn_return_any = false
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Add `run_async` method to HuggingFaceAPIChatGenerator. This method relies internally on the `AsyncInferenceClient` from huggingface
|
||||
to generate chat completions and supports the same parameters as the `run` method. It returns a coroutine
|
||||
that can be awaited.
|
||||
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from datetime import datetime
|
||||
import os
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from haystack import Pipeline
|
||||
@ -81,6 +81,31 @@ def mock_chat_completion():
|
||||
yield mock_chat_completion
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chat_completion_async():
|
||||
with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion:
|
||||
completion = ChatCompletionOutput(
|
||||
choices=[
|
||||
ChatCompletionOutputComplete(
|
||||
finish_reason="eos_token",
|
||||
index=0,
|
||||
message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"),
|
||||
)
|
||||
],
|
||||
id="some_id",
|
||||
model="some_model",
|
||||
system_fingerprint="some_fingerprint",
|
||||
usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25),
|
||||
created=1710498360,
|
||||
)
|
||||
|
||||
# Use AsyncMock to properly mock the async method
|
||||
mock_chat_completion.return_value = completion
|
||||
mock_chat_completion.__call__ = AsyncMock(return_value=completion)
|
||||
|
||||
yield mock_chat_completion
|
||||
|
||||
|
||||
# used to test serialization of streaming_callback
|
||||
def streaming_callback_handler(x):
|
||||
return x
|
||||
@ -112,6 +137,10 @@ class TestHuggingFaceAPIChatGenerator:
|
||||
assert generator.streaming_callback == streaming_callback
|
||||
assert generator.tools is None
|
||||
|
||||
# check that client and async_client are initialized
|
||||
assert generator._client.model == model
|
||||
assert generator._async_client.model == model
|
||||
|
||||
def test_init_serverless_with_tools(self, mock_check_valid_model, tools):
|
||||
model = "HuggingFaceH4/zephyr-7b-alpha"
|
||||
generation_kwargs = {"temperature": 0.6}
|
||||
@ -134,6 +163,9 @@ class TestHuggingFaceAPIChatGenerator:
|
||||
assert generator.streaming_callback == streaming_callback
|
||||
assert generator.tools == tools
|
||||
|
||||
assert generator._client.model == model
|
||||
assert generator._async_client.model == model
|
||||
|
||||
def test_init_serverless_invalid_model(self, mock_check_valid_model):
|
||||
mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
|
||||
with pytest.raises(RepositoryNotFoundError):
|
||||
@ -168,6 +200,9 @@ class TestHuggingFaceAPIChatGenerator:
|
||||
assert generator.streaming_callback == streaming_callback
|
||||
assert generator.tools is None
|
||||
|
||||
assert generator._client.model == url
|
||||
assert generator._async_client.model == url
|
||||
|
||||
def test_init_tgi_invalid_url(self):
|
||||
with pytest.raises(ValueError):
|
||||
HuggingFaceAPIChatGenerator(
|
||||
@ -623,3 +658,176 @@ class TestHuggingFaceAPIChatGenerator:
|
||||
assert not final_message.tool_calls
|
||||
assert len(final_message.text) > 0
|
||||
assert "paris" in final_message.text.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
|
||||
generator = HuggingFaceAPIChatGenerator(
|
||||
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
|
||||
api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
|
||||
generation_kwargs={"temperature": 0.6},
|
||||
stop_words=["stop", "words"],
|
||||
streaming_callback=None,
|
||||
)
|
||||
|
||||
response = await generator.run_async(messages=chat_messages)
|
||||
|
||||
# check kwargs passed to chat_completion
|
||||
_, kwargs = mock_chat_completion_async.call_args
|
||||
hf_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
|
||||
{"role": "user", "content": "Tell me about Berlin"},
|
||||
]
|
||||
assert kwargs == {
|
||||
"temperature": 0.6,
|
||||
"stop": ["stop", "words"],
|
||||
"max_tokens": 512,
|
||||
"tools": None,
|
||||
"messages": hf_messages,
|
||||
}
|
||||
|
||||
assert isinstance(response, dict)
|
||||
assert "replies" in response
|
||||
assert isinstance(response["replies"], list)
|
||||
assert len(response["replies"]) == 1
|
||||
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_streaming(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
|
||||
streaming_call_count = 0
|
||||
|
||||
async def streaming_callback_fn(chunk: StreamingChunk):
|
||||
nonlocal streaming_call_count
|
||||
streaming_call_count += 1
|
||||
assert isinstance(chunk, StreamingChunk)
|
||||
|
||||
# Create a fake streamed response
|
||||
async def mock_aiter(self):
|
||||
yield ChatCompletionStreamOutput(
|
||||
choices=[
|
||||
ChatCompletionStreamOutputChoice(
|
||||
delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
|
||||
index=0,
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
id="some_id",
|
||||
model="some_model",
|
||||
system_fingerprint="some_fingerprint",
|
||||
created=1710498504,
|
||||
)
|
||||
|
||||
yield ChatCompletionStreamOutput(
|
||||
choices=[
|
||||
ChatCompletionStreamOutputChoice(
|
||||
delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
|
||||
)
|
||||
],
|
||||
id="some_id",
|
||||
model="some_model",
|
||||
system_fingerprint="some_fingerprint",
|
||||
created=1710498504,
|
||||
)
|
||||
|
||||
mock_response = Mock(**{"__aiter__": mock_aiter})
|
||||
mock_chat_completion_async.return_value = mock_response
|
||||
|
||||
generator = HuggingFaceAPIChatGenerator(
|
||||
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
|
||||
api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
|
||||
streaming_callback=streaming_callback_fn,
|
||||
)
|
||||
|
||||
response = await generator.run_async(messages=chat_messages)
|
||||
|
||||
# check kwargs passed to chat_completion
|
||||
_, kwargs = mock_chat_completion_async.call_args
|
||||
assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
|
||||
|
||||
# Assert that the streaming callback was called twice
|
||||
assert streaming_call_count == 2
|
||||
|
||||
# Assert that the response contains the generated replies
|
||||
assert "replies" in response
|
||||
assert isinstance(response["replies"], list)
|
||||
assert len(response["replies"]) > 0
|
||||
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_tools(self, tools, mock_check_valid_model):
|
||||
generator = HuggingFaceAPIChatGenerator(
|
||||
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
|
||||
api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion_async:
|
||||
completion = ChatCompletionOutput(
|
||||
choices=[
|
||||
ChatCompletionOutputComplete(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
arguments={"city": "Paris"}, name="weather", description=None
|
||||
),
|
||||
id="0",
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
created=1729074760,
|
||||
id="",
|
||||
model="meta-llama/Llama-3.1-70B-Instruct",
|
||||
system_fingerprint="2.3.2-dev0-sha-28bb7ae",
|
||||
usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
|
||||
)
|
||||
mock_chat_completion_async.return_value = completion
|
||||
|
||||
messages = [ChatMessage.from_user("What is the weather in Paris?")]
|
||||
response = await generator.run_async(messages=messages)
|
||||
|
||||
assert isinstance(response, dict)
|
||||
assert "replies" in response
|
||||
assert isinstance(response["replies"], list)
|
||||
assert len(response["replies"]) == 1
|
||||
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
|
||||
assert response["replies"][0].tool_calls[0].tool_name == "weather"
|
||||
assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
|
||||
assert response["replies"][0].tool_calls[0].id == "0"
|
||||
assert response["replies"][0].meta == {
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"model": "meta-llama/Llama-3.1-70B-Instruct",
|
||||
"usage": {"completion_tokens": 30, "prompt_tokens": 426},
|
||||
}
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("HF_API_TOKEN", None),
|
||||
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
|
||||
)
|
||||
@pytest.mark.flaky(reruns=3, reruns_delay=10)
|
||||
async def test_live_run_async_serverless(self):
|
||||
generator = HuggingFaceAPIChatGenerator(
|
||||
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
|
||||
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
|
||||
generation_kwargs={"max_tokens": 20},
|
||||
)
|
||||
|
||||
messages = [ChatMessage.from_user("What is the capital of France?")]
|
||||
response = await generator.run_async(messages=messages)
|
||||
|
||||
assert "replies" in response
|
||||
assert isinstance(response["replies"], list)
|
||||
assert len(response["replies"]) > 0
|
||||
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
|
||||
assert "usage" in response["replies"][0].meta
|
||||
assert "prompt_tokens" in response["replies"][0].meta["usage"]
|
||||
assert "completion_tokens" in response["replies"][0].meta["usage"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user