feat: add run_async to HuggingfaceAPIChatGenerator (#8943)
* add run_async
* add release notes
* Add integration test
parent 1b2053b358
commit 28db039bca
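For orientation before the diff, here is a minimal usage sketch of the new method (not part of the commit): the import paths are assumed from the Haystack 2.x package layout, and the model name and prompt mirror the integration test added below. Calling the Serverless Inference API for real also requires a valid HF_API_TOKEN in the environment.

import asyncio

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils.hf import HFGenerationAPIType


async def main():
    # construct the generator exactly as for the synchronous run()
    generator = HuggingFaceAPIChatGenerator(
        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # model taken from the integration test below
        generation_kwargs={"max_tokens": 20},
    )

    # run_async mirrors run() but returns a coroutine, so it must be awaited
    response = await generator.run_async(messages=[ChatMessage.from_user("What is the capital of France?")])
    print(response["replies"][0].text)


asyncio.run(main())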
@@ -3,10 +3,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from datetime import datetime
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
@@ -15,6 +15,7 @@ from haystack.utils.url_validation import is_valid_http_url
 
 with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import (
+        AsyncInferenceClient,
         ChatCompletionInputFunctionDefinition,
         ChatCompletionInputTool,
         ChatCompletionOutput,
@@ -181,6 +182,7 @@ class HuggingFaceAPIChatGenerator:
         self.generation_kwargs = generation_kwargs
         self.streaming_callback = streaming_callback
         self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
+        self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
         self.tools = tools
 
     def to_dict(self) -> Dict[str, Any]:
@@ -250,7 +252,11 @@ class HuggingFaceAPIChatGenerator:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
         _check_duplicate_tool_names(tools)
 
-        streaming_callback = streaming_callback or self.streaming_callback
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(
+            self.streaming_callback, streaming_callback, requires_async=False
+        )  # type: ignore
+
         if streaming_callback:
             return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)
 
@@ -267,6 +273,63 @@ class HuggingFaceAPIChatGenerator:
             ]
         return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
 
+    @component.output_types(replies=List[ChatMessage])
+    async def run_async(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+    ):
+        """
+        Asynchronously invokes the text generation inference based on the provided messages and generation parameters.
+
+        This is the asynchronous version of the `run` method. It has the same parameters
+        and return values but can be used with `await` in an async code.
+
+        :param messages:
+            A list of ChatMessage objects representing the input messages.
+        :param generation_kwargs:
+            Additional keyword arguments for text generation.
+        :param tools:
+            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
+            during component initialization.
+        :param streaming_callback:
+            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
+            parameter set during component initialization.
+        :returns: A dictionary with the following keys:
+            - `replies`: A list containing the generated responses as ChatMessage objects.
+        """
+
+        # update generation kwargs by merging with the default ones
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        formatted_messages = [convert_message_to_hf_format(message) for message in messages]
+
+        tools = tools or self.tools
+        if tools and self.streaming_callback:
+            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+        _check_duplicate_tool_names(tools)
+
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)  # type: ignore
+
+        if streaming_callback:
+            return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)
+
+        hf_tools = None
+        if tools:
+            hf_tools = [
+                ChatCompletionInputTool(
+                    function=ChatCompletionInputFunctionDefinition(
+                        name=tool.name, description=tool.description, arguments=tool.parameters
+                    ),
+                    type="function",
+                )
+                for tool in tools
+            ]
+        return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
+
     def _run_streaming(
         self,
         messages: List[Dict[str, str]],
@@ -359,3 +422,89 @@ class HuggingFaceAPIChatGenerator:
 
         message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
         return {"replies": [message]}
+
+    async def _run_streaming_async(
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        streaming_callback: Callable[[StreamingChunk], None],
+    ):
+        api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
+            messages, stream=True, **generation_kwargs
+        )
+
+        generated_text = ""
+        first_chunk_time = None
+
+        async for chunk in api_output:
+            choice = chunk.choices[0]
+
+            text = choice.delta.content or ""
+            generated_text += text
+
+            finish_reason = choice.finish_reason
+
+            meta: Dict[str, Any] = {}
+            if finish_reason:
+                meta["finish_reason"] = finish_reason
+
+            if first_chunk_time is None:
+                first_chunk_time = datetime.now().isoformat()
+
+            stream_chunk = StreamingChunk(text, meta)
+            await streaming_callback(stream_chunk)  # type: ignore
+
+        meta.update(
+            {
+                "model": self._async_client.model,
+                "finish_reason": finish_reason,
+                "index": 0,
+                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
+                "completion_start_time": first_chunk_time,
+            }
+        )
+
+        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
+        return {"replies": [message]}
+
+    async def _run_non_streaming_async(
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        tools: Optional[List["ChatCompletionInputTool"]] = None,
+    ) -> Dict[str, List[ChatMessage]]:
+        api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
+            messages=messages, tools=tools, **generation_kwargs
+        )
+
+        if len(api_chat_output.choices) == 0:
+            return {"replies": []}
+
+        choice = api_chat_output.choices[0]
+
+        text = choice.message.content
+        tool_calls = []
+
+        if hfapi_tool_calls := choice.message.tool_calls:
+            for hfapi_tc in hfapi_tool_calls:
+                tool_call = ToolCall(
+                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
+                )
+                tool_calls.append(tool_call)
+
+        meta: Dict[str, Any] = {
+            "model": self._async_client.model,
+            "finish_reason": choice.finish_reason,
+            "index": choice.index,
+        }
+
+        usage = {"prompt_tokens": 0, "completion_tokens": 0}
+        if api_chat_output.usage:
+            usage = {
+                "prompt_tokens": api_chat_output.usage.prompt_tokens,
+                "completion_tokens": api_chat_output.usage.completion_tokens,
+            }
+        meta["usage"] = usage
+
+        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
+        return {"replies": [message]}
@@ -278,6 +278,7 @@ markers = [
 ]
 log_cli = true
 asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "class"
 
 [tool.mypy]
 warn_return_any = false
@@ -0,0 +1,6 @@
+---
+features:
+  - |
+    Add `run_async` method to HuggingFaceAPIChatGenerator. This method relies internally on the `AsyncInferenceClient` from huggingface
+    to generate chat completions and supports the same parameters as the `run` method. It returns a coroutine
+    that can be awaited.
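The release note above says `run_async` accepts the same parameters as `run`; a hedged sketch of the streaming variant follows (not part of the commit, import paths assumed from the Haystack 2.x layout). Because `run_async` selects the callback with `requires_async=True`, the streaming callback itself must be a coroutine function.

import asyncio

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.hf import HFGenerationAPIType


async def print_chunk(chunk: StreamingChunk) -> None:
    # awaited by _run_streaming_async for every chunk received from the async client
    print(chunk.content, end="", flush=True)


async def main():
    generator = HuggingFaceAPIChatGenerator(
        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # illustrative model
        streaming_callback=print_chunk,
    )
    response = await generator.run_async(messages=[ChatMessage.from_user("What is the capital of France?")])
    # the assembled reply also carries meta such as finish_reason and completion_start_time
    print()
    print(response["replies"][0].meta)


asyncio.run(main())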
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from datetime import datetime
 import os
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock, AsyncMock, patch
 
 import pytest
 from haystack import Pipeline
@@ -81,6 +81,31 @@ def mock_chat_completion():
     yield mock_chat_completion
 
 
+@pytest.fixture
+def mock_chat_completion_async():
+    with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion:
+        completion = ChatCompletionOutput(
+            choices=[
+                ChatCompletionOutputComplete(
+                    finish_reason="eos_token",
+                    index=0,
+                    message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"),
+                )
+            ],
+            id="some_id",
+            model="some_model",
+            system_fingerprint="some_fingerprint",
+            usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25),
+            created=1710498360,
+        )
+
+        # Use AsyncMock to properly mock the async method
+        mock_chat_completion.return_value = completion
+        mock_chat_completion.__call__ = AsyncMock(return_value=completion)
+
+        yield mock_chat_completion
+
+
 # used to test serialization of streaming_callback
 def streaming_callback_handler(x):
     return x
@@ -112,6 +137,10 @@ class TestHuggingFaceAPIChatGenerator:
         assert generator.streaming_callback == streaming_callback
         assert generator.tools is None
 
+        # check that client and async_client are initialized
+        assert generator._client.model == model
+        assert generator._async_client.model == model
+
     def test_init_serverless_with_tools(self, mock_check_valid_model, tools):
         model = "HuggingFaceH4/zephyr-7b-alpha"
         generation_kwargs = {"temperature": 0.6}
@@ -134,6 +163,9 @@ class TestHuggingFaceAPIChatGenerator:
         assert generator.streaming_callback == streaming_callback
         assert generator.tools == tools
 
+        assert generator._client.model == model
+        assert generator._async_client.model == model
+
     def test_init_serverless_invalid_model(self, mock_check_valid_model):
         mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
         with pytest.raises(RepositoryNotFoundError):
@@ -168,6 +200,9 @@ class TestHuggingFaceAPIChatGenerator:
         assert generator.streaming_callback == streaming_callback
         assert generator.tools is None
 
+        assert generator._client.model == url
+        assert generator._async_client.model == url
+
     def test_init_tgi_invalid_url(self):
         with pytest.raises(ValueError):
             HuggingFaceAPIChatGenerator(
@@ -623,3 +658,176 @@ class TestHuggingFaceAPIChatGenerator:
         assert not final_message.tool_calls
         assert len(final_message.text) > 0
         assert "paris" in final_message.text.lower()
+
+    @pytest.mark.asyncio
+    async def test_run_async(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
+            generation_kwargs={"temperature": 0.6},
+            stop_words=["stop", "words"],
+            streaming_callback=None,
+        )
+
+        response = await generator.run_async(messages=chat_messages)
+
+        # check kwargs passed to chat_completion
+        _, kwargs = mock_chat_completion_async.call_args
+        hf_messages = [
+            {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
+            {"role": "user", "content": "Tell me about Berlin"},
+        ]
+        assert kwargs == {
+            "temperature": 0.6,
+            "stop": ["stop", "words"],
+            "max_tokens": 512,
+            "tools": None,
+            "messages": hf_messages,
+        }
+
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+
+    @pytest.mark.asyncio
+    async def test_run_async_with_streaming(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
+        streaming_call_count = 0
+
+        async def streaming_callback_fn(chunk: StreamingChunk):
+            nonlocal streaming_call_count
+            streaming_call_count += 1
+            assert isinstance(chunk, StreamingChunk)
+
+        # Create a fake streamed response
+        async def mock_aiter(self):
+            yield ChatCompletionStreamOutput(
+                choices=[
+                    ChatCompletionStreamOutputChoice(
+                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
+                        index=0,
+                        finish_reason=None,
+                    )
+                ],
+                id="some_id",
+                model="some_model",
+                system_fingerprint="some_fingerprint",
+                created=1710498504,
+            )
+
+            yield ChatCompletionStreamOutput(
+                choices=[
+                    ChatCompletionStreamOutputChoice(
+                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
+                    )
+                ],
+                id="some_id",
+                model="some_model",
+                system_fingerprint="some_fingerprint",
+                created=1710498504,
+            )
+
+        mock_response = Mock(**{"__aiter__": mock_aiter})
+        mock_chat_completion_async.return_value = mock_response
+
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
+            streaming_callback=streaming_callback_fn,
+        )
+
+        response = await generator.run_async(messages=chat_messages)
+
+        # check kwargs passed to chat_completion
+        _, kwargs = mock_chat_completion_async.call_args
+        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
+
+        # Assert that the streaming callback was called twice
+        assert streaming_call_count == 2
+
+        # Assert that the response contains the generated replies
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) > 0
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+
+    @pytest.mark.asyncio
+    async def test_run_async_with_tools(self, tools, mock_check_valid_model):
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
+            tools=tools,
+        )
+
+        with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion_async:
+            completion = ChatCompletionOutput(
+                choices=[
+                    ChatCompletionOutputComplete(
+                        finish_reason="stop",
+                        index=0,
+                        message=ChatCompletionOutputMessage(
+                            role="assistant",
+                            content=None,
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        arguments={"city": "Paris"}, name="weather", description=None
+                                    ),
+                                    id="0",
+                                    type="function",
+                                )
+                            ],
+                        ),
+                        logprobs=None,
+                    )
+                ],
+                created=1729074760,
+                id="",
+                model="meta-llama/Llama-3.1-70B-Instruct",
+                system_fingerprint="2.3.2-dev0-sha-28bb7ae",
+                usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
+            )
+            mock_chat_completion_async.return_value = completion
+
+            messages = [ChatMessage.from_user("What is the weather in Paris?")]
+            response = await generator.run_async(messages=messages)
+
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+        assert response["replies"][0].tool_calls[0].tool_name == "weather"
+        assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
+        assert response["replies"][0].tool_calls[0].id == "0"
+        assert response["replies"][0].meta == {
+            "finish_reason": "stop",
+            "index": 0,
+            "model": "meta-llama/Llama-3.1-70B-Instruct",
+            "usage": {"completion_tokens": 30, "prompt_tokens": 426},
+        }
+
+    @pytest.mark.integration
+    @pytest.mark.skipif(
+        not os.environ.get("HF_API_TOKEN", None),
+        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+    )
+    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    async def test_live_run_async_serverless(self):
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+            generation_kwargs={"max_tokens": 20},
+        )
+
+        messages = [ChatMessage.from_user("What is the capital of France?")]
+        response = await generator.run_async(messages=messages)
+
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) > 0
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+        assert "usage" in response["replies"][0].meta
+        assert "prompt_tokens" in response["replies"][0].meta["usage"]
+        assert "completion_tokens" in response["replies"][0].meta["usage"]