# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from datetime import datetime
from typing import Any, Dict
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from huggingface_hub import (
    ChatCompletionInputStreamOptions,
    ChatCompletionOutput,
    ChatCompletionOutputComplete,
    ChatCompletionOutputFunctionDefinition,
    ChatCompletionOutputMessage,
    ChatCompletionOutputToolCall,
    ChatCompletionOutputUsage,
    ChatCompletionStreamOutput,
    ChatCompletionStreamOutputChoice,
    ChatCompletionStreamOutputDelta,
    ChatCompletionStreamOutputUsage,
)
from huggingface_hub.errors import RepositoryNotFoundError

from haystack import Pipeline
from haystack.components.generators.chat.hugging_face_api import (
    HuggingFaceAPIChatGenerator,
    _convert_chat_completion_stream_output_to_streaming_chunk,
    _convert_hfapi_tool_calls,
    _convert_tools_to_hfapi_tools,
)
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.tools import Tool
from haystack.tools.toolset import Toolset
from haystack.utils.auth import Secret
from haystack.utils.hf import HFGenerationAPIType


@pytest.fixture
def chat_messages():
    return [
        ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"),
        ChatMessage.from_user("Tell me about Berlin"),
    ]


def get_weather(city: str) -> Dict[str, Any]:
    weather_info = {
        "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    return weather_info.get(city, {"weather": "unknown", "temperature": 0, "unit": "celsius"})
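

# For reference, hypothetical calls against the table above:
#   get_weather("Berlin") -> {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"}
#   get_weather("Oslo") -> {"weather": "unknown", "temperature": 0, "unit": "celsius"}  (unknown-city fallback)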


@pytest.fixture
def tools():
    weather_tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=get_weather,
    )
    return [weather_tool]


@pytest.fixture
def mock_check_valid_model():
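    """Patch check_valid_model so the generator's model validation never hits the Hugging Face Hub."""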
    with patch(
        "haystack.components.generators.chat.hugging_face_api.check_valid_model", MagicMock(return_value=None)
    ) as mock:
        yield mock


@pytest.fixture
def mock_chat_completion():
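    """Mock InferenceClient.chat_completion to return a canned, non-streaming completion."""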
    # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.example
    with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion:
        completion = ChatCompletionOutput(
            choices=[
                ChatCompletionOutputComplete(
                    finish_reason="eos_token",
                    index=0,
                    message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"),
                )
            ],
            id="some_id",
            model="some_model",
            system_fingerprint="some_fingerprint",
            usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25),
            created=1710498360,
        )

        mock_chat_completion.return_value = completion
        yield mock_chat_completion


@pytest.fixture
def mock_chat_completion_async():
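    """Mock AsyncInferenceClient.chat_completion to return the same canned completion from the async client."""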
    with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion:
        completion = ChatCompletionOutput(
            choices=[
                ChatCompletionOutputComplete(
                    finish_reason="eos_token",
                    index=0,
                    message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"),
                )
            ],
            id="some_id",
            model="some_model",
            system_fingerprint="some_fingerprint",
            usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25),
            created=1710498360,
        )

        # Use AsyncMock to properly mock the async method
        mock_chat_completion.return_value = completion
        mock_chat_completion.__call__ = AsyncMock(return_value=completion)

        yield mock_chat_completion


# used to test serialization of streaming_callback
def streaming_callback_handler(x):
    return x
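

# Note: streaming_callback_handler is deliberately a module-level, named function: serialization
# records the callback as a dotted import path, so a lambda or local function would not round-trip.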


class TestHuggingFaceAPIChatGenerator:
    def test_init_invalid_api_type(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={})

    def test_init_serverless(self, mock_check_valid_model):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"temperature": 0.6}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": model},
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
        assert generator.api_params == {"model": model}
        assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}}
        assert generator.streaming_callback == streaming_callback
        assert generator.tools is None

        # check that client and async_client are initialized
        assert generator._client.model == model
        assert generator._async_client.model == model

    def test_init_serverless_with_tools(self, mock_check_valid_model, tools):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"temperature": 0.6}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": model},
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
            tools=tools,
        )

        assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
        assert generator.api_params == {"model": model}
        assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}}
        assert generator.streaming_callback == streaming_callback
        assert generator.tools == tools

        assert generator._client.model == model
        assert generator._async_client.model == model

    def test_init_serverless_invalid_model(self, mock_check_valid_model):
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"}
            )

    def test_init_serverless_no_model(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"}
            )

    def test_init_tgi(self):
        url = "https://some_model.com"
        generation_kwargs = {"temperature": 0.6}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE,
            api_params={"url": url},
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.api_type == HFGenerationAPIType.TEXT_GENERATION_INFERENCE
        assert generator.api_params == {"url": url}
        assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}}
        assert generator.streaming_callback == streaming_callback
        assert generator.tools is None

        assert generator._client.model == url
        assert generator._async_client.model == url

    def test_init_tgi_invalid_url(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": "invalid_url"}
            )

    def test_init_tgi_no_url(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"}
            )

    def test_init_fail_with_duplicate_tool_names(self, mock_check_valid_model, tools):
        duplicate_tools = [tools[0], tools[0]]
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "irrelevant"},
                tools=duplicate_tools,
            )

    def test_init_fail_with_tools_and_streaming(self, mock_check_valid_model, tools):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "irrelevant"},
                tools=tools,
                streaming_callback=streaming_callback_handler,
            )

    def test_to_dict(self, mock_check_valid_model):
        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            tools=[tool],
        )

        result = generator.to_dict()
        init_params = result["init_parameters"]

        assert init_params["api_type"] == "serverless_inference_api"
        assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
        assert init_params["token"] == {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
        assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
        assert init_params["streaming_callback"] is None
        assert init_params["tools"] == [
            {
                "type": "haystack.tools.tool.Tool",
                "data": {
                    "description": "description",
                    "function": "builtins.print",
                    "inputs_from_state": None,
                    "name": "name",
                    "outputs_to_state": None,
                    "outputs_to_string": None,
                    "parameters": {"x": {"type": "string"}},
                },
            }
        ]

    def test_from_dict(self, mock_check_valid_model):
        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            token=Secret.from_env_var("ENV_VAR", strict=False),
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            tools=[tool],
        )
        result = generator.to_dict()

        # now deserialize, call from_dict
        generator_2 = HuggingFaceAPIChatGenerator.from_dict(result)
        assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
        assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"}
        assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False)
        assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
        assert generator_2.streaming_callback is None
        assert generator_2.tools == [tool]

    def test_serde_in_pipeline(self, mock_check_valid_model):
        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            token=Secret.from_env_var("ENV_VAR", strict=False),
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            tools=[tool],
        )

        pipeline = Pipeline()
        pipeline.add_component("generator", generator)

        pipeline_dict = pipeline.to_dict()
        assert pipeline_dict == {
            "metadata": {},
            "max_runs_per_component": 100,
            "connection_type_validation": True,
            "components": {
                "generator": {
                    "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator",
                    "init_parameters": {
                        "api_type": "serverless_inference_api",
                        "api_params": {"model": "HuggingFaceH4/zephyr-7b-beta"},
                        "token": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False},
                        "generation_kwargs": {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512},
                        "streaming_callback": None,
                        "tools": [
                            {
                                "type": "haystack.tools.tool.Tool",
                                "data": {
                                    "inputs_from_state": None,
                                    "name": "name",
                                    "outputs_to_state": None,
                                    "outputs_to_string": None,
                                    "description": "description",
                                    "parameters": {"x": {"type": "string"}},
                                    "function": "builtins.print",
                                },
                            }
                        ],
                    },
                }
            },
            "connections": [],
        }

        pipeline_yaml = pipeline.dumps()

        new_pipeline = Pipeline.loads(pipeline_yaml)
        assert new_pipeline == pipeline

    def test_run(self, mock_check_valid_model, mock_chat_completion, chat_messages):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            streaming_callback=None,
        )

        response = generator.run(messages=chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion.call_args
        hf_messages = [
            {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
            {"role": "user", "content": "Tell me about Berlin"},
        ]
        assert kwargs == {
            "temperature": 0.6,
            "stop": ["stop", "words"],
            "max_tokens": 512,
            "tools": None,
            "messages": hf_messages,
        }

        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_run_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages):
        streaming_call_count = 0

        # Define the streaming callback function
        def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            streaming_callback=streaming_callback_fn,
        )

        # Create a fake streamed response
        # self needed here, don't remove
        def mock_iter(self):
            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
                        index=0,
                        finish_reason=None,
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

        mock_response = Mock(**{"__iter__": mock_iter})
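        # Assigning __iter__ through the constructor works because Mock stores magic methods
        # on the mock's class, so iterating mock_response calls mock_iter(mock_response);
        # that is why mock_iter above is defined with a `self` parameter.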
        mock_chat_completion.return_value = mock_response

        # Generate text response with streaming callback
        response = generator.run(chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion.call_args
        assert kwargs == {
            "stop": [],
            "stream": True,
            "max_tokens": 512,
            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
        }

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_run_with_streaming_callback_in_run_method(
        self, mock_check_valid_model, mock_chat_completion, chat_messages
    ):
        streaming_call_count = 0

        # Define the streaming callback function
        def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
        )

        # Create a fake streamed response
        # self needed here, don't remove
        def mock_iter(self):
            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
                        index=0,
                        finish_reason=None,
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

        mock_response = Mock(**{"__iter__": mock_iter})
        mock_chat_completion.return_value = mock_response

        # Generate text response with streaming callback
        response = generator.run(chat_messages, streaming_callback=streaming_callback_fn)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion.call_args
        assert kwargs == {
            "stop": [],
            "stream": True,
            "max_tokens": 512,
            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
        }

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model):
        component = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            streaming_callback=streaming_callback_handler,
        )

        with pytest.raises(ValueError):
            message = ChatMessage.from_user("irrelevant")
            component.run([message], tools=tools)

    def test_run_with_tools(self, mock_check_valid_model, tools):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
            tools=tools,
        )

        with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion:
            completion = ChatCompletionOutput(
                choices=[
                    ChatCompletionOutputComplete(
                        finish_reason="stop",
                        index=0,
                        message=ChatCompletionOutputMessage(
                            role="assistant",
                            content=None,
                            tool_calls=[
                                ChatCompletionOutputToolCall(
                                    function=ChatCompletionOutputFunctionDefinition(
                                        arguments={"city": "Paris"}, name="weather", description=None
                                    ),
                                    id="0",
                                    type="function",
                                )
                            ],
                        ),
                        logprobs=None,
                    )
                ],
                created=1729074760,
                id="",
                model="meta-llama/Llama-3.1-70B-Instruct",
                system_fingerprint="2.3.2-dev0-sha-28bb7ae",
                usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
            )
            mock_chat_completion.return_value = completion

            messages = [ChatMessage.from_user("What is the weather in Paris?")]
            response = generator.run(messages=messages)

            assert isinstance(response, dict)
            assert "replies" in response
            assert isinstance(response["replies"], list)
            assert len(response["replies"]) == 1
            assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
            assert response["replies"][0].tool_calls[0].tool_name == "weather"
            assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
            assert response["replies"][0].tool_calls[0].id == "0"
            assert response["replies"][0].meta == {
                "finish_reason": "stop",
                "index": 0,
                "model": "meta-llama/Llama-3.1-70B-Instruct",
                "usage": {"completion_tokens": 30, "prompt_tokens": 426},
            }

    def test_convert_hfapi_tool_calls_empty(self):
        hfapi_tool_calls = None
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 0

        hfapi_tool_calls = []
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 0

    def test_convert_hfapi_tool_calls_dict_arguments(self):
        hfapi_tool_calls = [
            ChatCompletionOutputToolCall(
                function=ChatCompletionOutputFunctionDefinition(
                    arguments={"city": "Paris"}, name="weather", description=None
                ),
                id="0",
                type="function",
            )
        ]
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 1
        assert tool_calls[0].tool_name == "weather"
        assert tool_calls[0].arguments == {"city": "Paris"}
        assert tool_calls[0].id == "0"

    def test_convert_hfapi_tool_calls_str_arguments(self):
        hfapi_tool_calls = [
            ChatCompletionOutputToolCall(
                function=ChatCompletionOutputFunctionDefinition(
                    arguments='{"city": "Paris"}', name="weather", description=None
                ),
                id="0",
                type="function",
            )
        ]
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 1
        assert tool_calls[0].tool_name == "weather"
        assert tool_calls[0].arguments == {"city": "Paris"}
        assert tool_calls[0].id == "0"

    def test_convert_hfapi_tool_calls_invalid_str_arguments(self):
        hfapi_tool_calls = [
            ChatCompletionOutputToolCall(
                function=ChatCompletionOutputFunctionDefinition(
                    arguments="not a valid JSON string", name="weather", description=None
                ),
                id="0",
                type="function",
            )
        ]
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 0

    def test_convert_hfapi_tool_calls_invalid_type_arguments(self):
        hfapi_tool_calls = [
            ChatCompletionOutputToolCall(
                function=ChatCompletionOutputFunctionDefinition(
                    arguments=["this", "is", "a", "list"], name="weather", description=None
                ),
                id="0",
                type="function",
            )
        ]
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 0

    @pytest.mark.parametrize(
        "hf_stream_output, expected_stream_chunk, dummy_previous_chunks",
        [
            (
                ChatCompletionStreamOutput(
                    choices=[
                        ChatCompletionStreamOutputChoice(
                            delta=ChatCompletionStreamOutputDelta(role="assistant", content=" Paris"), index=0
                        )
                    ],
                    created=1748339326,
                    id="",
                    model="microsoft/Phi-3.5-mini-instruct",
                    system_fingerprint="3.2.1-sha-4d28897",
                ),
                StreamingChunk(
                    content=" Paris",
                    meta={
                        "received_at": "2025-05-27T12:14:28.228852",
                        "model": "microsoft/Phi-3.5-mini-instruct",
                        "finish_reason": None,
                    },
                    index=0,
                    start=True,
                ),
                [],
            ),
            (
                ChatCompletionStreamOutput(
                    choices=[
                        ChatCompletionStreamOutputChoice(
                            delta=ChatCompletionStreamOutputDelta(role="assistant", content=""),
                            index=0,
                            finish_reason="stop",
                        )
                    ],
                    created=1748339326,
                    id="",
                    model="microsoft/Phi-3.5-mini-instruct",
                    system_fingerprint="3.2.1-sha-4d28897",
                ),
                StreamingChunk(
                    content="",
                    meta={
                        "received_at": "2025-05-27T12:14:28.228852",
                        "model": "microsoft/Phi-3.5-mini-instruct",
                        "finish_reason": "stop",
                    },
                    finish_reason="stop",
                ),
                [0],
            ),
            (
                ChatCompletionStreamOutput(
                    choices=[],
                    created=1748339326,
                    id="",
                    model="microsoft/Phi-3.5-mini-instruct",
                    system_fingerprint="3.2.1-sha-4d28897",
                    usage=ChatCompletionStreamOutputUsage(completion_tokens=2, prompt_tokens=21, total_tokens=23),
                ),
                StreamingChunk(
                    content="",
                    meta={
                        "received_at": "2025-05-27T12:14:28.228852",
                        "model": "microsoft/Phi-3.5-mini-instruct",
                        "usage": {"completion_tokens": 2, "prompt_tokens": 21},
                    },
                ),
                [0, 1],
            ),
        ],
    )
    def test_convert_chat_completion_stream_output_to_streaming_chunk(
        self, hf_stream_output, expected_stream_chunk, dummy_previous_chunks
    ):
        converted_stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(
            chunk=hf_stream_output, previous_chunks=dummy_previous_chunks
        )
        # Remove timestamp from comparison since it's always the current time
        converted_stream_chunk.meta.pop("received_at", None)
        expected_stream_chunk.meta.pop("received_at", None)
        assert converted_stream_chunk == expected_stream_chunk

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    @pytest.mark.flaky(reruns=3, reruns_delay=10)
    def test_live_run_serverless(self):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "hf-inference"},
            generation_kwargs={"max_tokens": 20},
        )

        # No need for instruction tokens here since we use the chat_completion endpoint, which handles the chat
        # templating for us.
        messages = [
            ChatMessage.from_user("What is the capital of France? Be concise, only provide the capital, nothing else.")
        ]
        response = generator.run(messages=messages)

        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
        assert response["replies"][0].text is not None
        meta = response["replies"][0].meta
        assert "usage" in meta
        assert "prompt_tokens" in meta["usage"]
        assert meta["usage"]["prompt_tokens"] > 0
        assert "completion_tokens" in meta["usage"]
        assert meta["usage"]["completion_tokens"] > 0
        assert meta["model"] == "microsoft/Phi-3.5-mini-instruct"
        assert meta["finish_reason"] is not None

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    @pytest.mark.flaky(reruns=3, reruns_delay=10)
    def test_live_run_serverless_streaming(self):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "hf-inference"},
            generation_kwargs={"max_tokens": 20},
            streaming_callback=streaming_callback_handler,
        )

        # No need for instruction tokens here since we use the chat_completion endpoint, which handles the chat
        # templating for us.
        messages = [
            ChatMessage.from_user("What is the capital of France? Be concise, only provide the capital, nothing else.")
        ]
        response = generator.run(messages=messages)

        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
        assert response["replies"][0].text is not None

        response_meta = response["replies"][0].meta
        assert "completion_start_time" in response_meta
        assert datetime.fromisoformat(response_meta["completion_start_time"]) <= datetime.now()
        assert "usage" in response_meta
        assert "prompt_tokens" in response_meta["usage"]
        assert response_meta["usage"]["prompt_tokens"] > 0
        assert "completion_tokens" in response_meta["usage"]
        assert response_meta["usage"]["completion_tokens"] > 0
        assert response_meta["model"] == "microsoft/Phi-3.5-mini-instruct"
        assert response_meta["finish_reason"] is not None

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    def test_live_run_with_tools(self, tools):
        """
        We test the round trip: generate tool call, pass tool message, generate response.

        The model used here (Qwen/Qwen2.5-72B-Instruct) is not gated and kept in a warm state.
        """

        chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "Qwen/Qwen2.5-72B-Instruct", "provider": "hf-inference"},
            generation_kwargs={"temperature": 0.5},
        )

        results = generator.run(chat_messages, tools=tools)
        assert len(results["replies"]) == 1
        message = results["replies"][0]

        assert message.tool_calls
        tool_call = message.tool_call
        assert isinstance(tool_call, ToolCall)
        assert tool_call.tool_name == "weather"
        assert "city" in tool_call.arguments
        assert "Paris" in tool_call.arguments["city"]
        assert message.meta["finish_reason"] == "stop"

        new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)]

        # the model tends to make tool calls if provided with tools, so we don't pass them here
        results = generator.run(new_messages, generation_kwargs={"max_tokens": 50})

        assert len(results["replies"]) == 1
        final_message = results["replies"][0]
        assert not final_message.tool_calls
        assert len(final_message.text) > 0
        assert "paris" in final_message.text.lower() and "22" in final_message.text

    @pytest.mark.asyncio
    async def test_run_async(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            streaming_callback=None,
        )

        response = await generator.run_async(messages=chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion_async.call_args
        hf_messages = [
            {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
            {"role": "user", "content": "Tell me about Berlin"},
        ]
        assert kwargs == {
            "temperature": 0.6,
            "stop": ["stop", "words"],
            "max_tokens": 512,
            "tools": None,
            "messages": hf_messages,
        }

        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.asyncio
    async def test_run_async_with_streaming(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
        streaming_call_count = 0

        async def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        # Create a fake streamed response
        async def mock_aiter(self):
            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
                        index=0,
                        finish_reason=None,
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

        mock_response = Mock(**{"__aiter__": mock_aiter})
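        # As in the sync streaming tests above, __aiter__ set via the constructor is stored on
        # the mock's class, so `async for` over mock_response calls mock_aiter with the mock
        # instance as `self`.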
        mock_chat_completion_async.return_value = mock_response

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            streaming_callback=streaming_callback_fn,
        )

        response = await generator.run_async(messages=chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion_async.call_args
        assert kwargs == {
            "stop": [],
            "stream": True,
            "max_tokens": 512,
            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
        }

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.asyncio
    async def test_run_async_with_tools(self, tools, mock_check_valid_model):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
            tools=tools,
        )

        with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion_async:
            completion = ChatCompletionOutput(
                choices=[
                    ChatCompletionOutputComplete(
                        finish_reason="stop",
                        index=0,
                        message=ChatCompletionOutputMessage(
                            role="assistant",
                            content=None,
                            tool_calls=[
                                ChatCompletionOutputToolCall(
                                    function=ChatCompletionOutputFunctionDefinition(
                                        arguments={"city": "Paris"}, name="weather", description=None
                                    ),
                                    id="0",
                                    type="function",
                                )
                            ],
                        ),
                        logprobs=None,
                    )
                ],
                created=1729074760,
                id="",
                model="meta-llama/Llama-3.1-70B-Instruct",
                system_fingerprint="2.3.2-dev0-sha-28bb7ae",
                usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
            )
            mock_chat_completion_async.return_value = completion

            messages = [ChatMessage.from_user("What is the weather in Paris?")]
            response = await generator.run_async(messages=messages)

            assert isinstance(response, dict)
            assert "replies" in response
            assert isinstance(response["replies"], list)
            assert len(response["replies"]) == 1
            assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
            assert response["replies"][0].tool_calls[0].tool_name == "weather"
            assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
            assert response["replies"][0].tool_calls[0].id == "0"
            assert response["replies"][0].meta == {
                "finish_reason": "stop",
                "index": 0,
                "model": "meta-llama/Llama-3.1-70B-Instruct",
                "usage": {"completion_tokens": 30, "prompt_tokens": 426},
            }

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    @pytest.mark.flaky(reruns=3, reruns_delay=10)
    @pytest.mark.asyncio
    async def test_live_run_async_serverless(self):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "hf-inference"},
            generation_kwargs={"max_tokens": 20},
        )

        messages = [
            ChatMessage.from_user("What is the capital of France? Be concise, only provide the capital, nothing else.")
        ]
        try:
            response = await generator.run_async(messages=messages)

            assert "replies" in response
            assert isinstance(response["replies"], list)
            assert len(response["replies"]) > 0
            assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
            assert response["replies"][0].text is not None

            meta = response["replies"][0].meta
            assert "usage" in meta
            assert "prompt_tokens" in meta["usage"]
            assert meta["usage"]["prompt_tokens"] > 0
            assert "completion_tokens" in meta["usage"]
            assert meta["usage"]["completion_tokens"] > 0
            assert meta["model"] == "microsoft/Phi-3.5-mini-instruct"
            assert meta["finish_reason"] is not None
        finally:
            await generator._async_client.close()

    def test_hugging_face_api_generator_with_toolset_initialization(self, mock_check_valid_model, tools):
        """Test that the HuggingFaceAPIChatGenerator can be initialized with a Toolset."""
        toolset = Toolset(tools)
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
        )
        assert generator.tools == toolset

    def test_from_dict_with_toolset(self, mock_check_valid_model, tools):
        """Test that the HuggingFaceAPIChatGenerator can be deserialized from a dictionary with a Toolset."""
        toolset = Toolset(tools)
        component = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
        )
        data = component.to_dict()

        deserialized_component = HuggingFaceAPIChatGenerator.from_dict(data)

        assert isinstance(deserialized_component.tools, Toolset)
        assert len(deserialized_component.tools) == len(tools)
        assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)

    def test_to_dict_with_toolset(self, mock_check_valid_model, tools):
        """Test that the HuggingFaceAPIChatGenerator can be serialized to a dictionary with a Toolset."""
        toolset = Toolset(tools[:1])
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
        )
        data = generator.to_dict()

        expected_tools_data = {
            "type": "haystack.tools.toolset.Toolset",
            "data": {
                "tools": [
                    {
                        "type": "haystack.tools.tool.Tool",
                        "data": {
                            "name": "weather",
                            "description": "useful to determine the weather in a given location",
                            "parameters": {
                                "type": "object",
                                "properties": {"city": {"type": "string"}},
                                "required": ["city"],
                            },
                            "function": "generators.chat.test_hugging_face_api.get_weather",
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    }
                ]
            },
        }
        assert data["init_parameters"]["tools"] == expected_tools_data

    def test_convert_tools_to_hfapi_tools(self):
        assert _convert_tools_to_hfapi_tools(None) is None
        assert _convert_tools_to_hfapi_tools([]) is None

        tool = Tool(
            name="weather",
            description="useful to determine the weather in a given location",
            parameters={"city": {"type": "string"}},
            function=get_weather,
        )
        hf_tools = _convert_tools_to_hfapi_tools([tool])
        assert len(hf_tools) == 1
        assert hf_tools[0].type == "function"
        assert hf_tools[0].function.name == "weather"
        assert hf_tools[0].function.description == "useful to determine the weather in a given location"
        assert hf_tools[0].function.parameters == {"city": {"type": "string"}}

    def test_convert_tools_to_hfapi_tools_legacy(self):
        # a MagicMock satisfies the check hasattr(ChatCompletionInputFunctionDefinition, "arguments")
        mock_class = MagicMock()

        with patch(
            "haystack.components.generators.chat.hugging_face_api.ChatCompletionInputFunctionDefinition", mock_class
        ):
            tool = Tool(
                name="weather",
                description="useful to determine the weather in a given location",
                parameters={"city": {"type": "string"}},
                function=get_weather,
            )
            _convert_tools_to_hfapi_tools([tool])

        mock_class.assert_called_once_with(
            name="weather",
            arguments={"city": {"type": "string"}},
            description="useful to determine the weather in a given location",
        )