feat: support for tools in HuggingFaceAPIChatGenerator (#8661)

* message conversion function

* hfapi w tools

* right test file + hf_hub version

* release note

* feedback
Stefano Fiorucci 2024-12-19 15:04:37 +01:00 committed by GitHub
parent c306bee665
commit 2bc58d2987
11 changed files with 507 additions and 82 deletions

View File

@@ -14,7 +14,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace
 from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
 from haystack.utils.url_validation import is_valid_http_url
-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient
 logger = logging.getLogger(__name__)

View File

@@ -11,7 +11,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace
 from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
 from haystack.utils.url_validation import is_valid_http_url
-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient
 logger = logging.getLogger(__name__)

View File

@@ -5,30 +5,25 @@
 from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
-from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
+from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
 from haystack.utils.url_validation import is_valid_http_url
-with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import:
-    from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient
+with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
+    from huggingface_hub import (
+        ChatCompletionInputTool,
+        ChatCompletionOutput,
+        ChatCompletionStreamOutput,
+        InferenceClient,
+    )
 logger = logging.getLogger(__name__)
-def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]:
-    """
-    Convert a message to the format expected by Hugging Face APIs.
-    :returns: A dictionary with the following keys:
-        - `role`
-        - `content`
-    """
-    return {"role": message.role.value, "content": message.text or ""}
 @component
 class HuggingFaceAPIChatGenerator:
     """
@@ -107,6 +102,7 @@ class HuggingFaceAPIChatGenerator:
         generation_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        tools: Optional[List[Tool]] = None,
     ):
         """
         Initialize the HuggingFaceAPIChatGenerator instance.
@@ -121,14 +117,22 @@
             - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
             - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
               `TEXT_GENERATION_INFERENCE`.
-        :param token: The Hugging Face token to use as HTTP bearer authorization.
+        :param token:
+            The Hugging Face token to use as HTTP bearer authorization.
             Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
         :param generation_kwargs:
             A dictionary with keyword arguments to customize text generation.
             Some examples: `max_tokens`, `temperature`, `top_p`.
             For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
-        :param stop_words: An optional list of strings representing the stop words.
-        :param streaming_callback: An optional callable for handling streaming responses.
+        :param stop_words:
+            An optional list of strings representing the stop words.
+        :param streaming_callback:
+            An optional callable for handling streaming responses.
+        :param tools:
+            A list of tools for which the model can prepare calls.
+            The chosen model should support tool/function calling, according to the model card.
+            Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
+            unexpected behavior.
         """
         huggingface_hub_import.check()
@@ -159,6 +163,11 @@
             msg = f"Unknown api_type {api_type}"
             raise ValueError(msg)
+        if tools:
+            if streaming_callback is not None:
+                raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+            _check_duplicate_tool_names(tools)
         # handle generation kwargs setup
         generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
         generation_kwargs["stop"] = generation_kwargs.get("stop", [])
@@ -171,6 +180,7 @@
         self.generation_kwargs = generation_kwargs
         self.streaming_callback = streaming_callback
         self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
+        self.tools = tools
     def to_dict(self) -> Dict[str, Any]:
         """
@@ -180,6 +190,7 @@
             A dictionary containing the serialized component.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
+        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
         return default_to_dict(
             self,
             api_type=str(self.api_type),
@@ -187,6 +198,7 @@
             token=self.token.to_dict() if self.token else None,
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
+            tools=serialized_tools,
         )
     @classmethod
@@ -195,6 +207,7 @@
         Deserialize this component from a dictionary.
         """
         deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        deserialize_tools_inplace(data["init_parameters"], key="tools")
         init_params = data.get("init_parameters", {})
         serialized_callback_handler = init_params.get("streaming_callback")
         if serialized_callback_handler:
@@ -202,12 +215,22 @@
         return default_from_dict(cls, data)
     @component.output_types(replies=List[ChatMessage])
-    def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
+    ):
         """
         Invoke the text generation inference based on the provided messages and generation parameters.
-        :param messages: A list of ChatMessage objects representing the input messages.
-        :param generation_kwargs: Additional keyword arguments for text generation.
+        :param messages:
+            A list of ChatMessage objects representing the input messages.
+        :param generation_kwargs:
+            Additional keyword arguments for text generation.
+        :param tools:
+            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
+            during component initialization.
         :returns: A dictionary with the following keys:
             - `replies`: A list containing the generated responses as ChatMessage objects.
         """
@@ -215,12 +238,22 @@
         # update generation kwargs by merging with the default ones
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
-        formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages]
+        formatted_messages = [convert_message_to_hf_format(message) for message in messages]
+        tools = tools or self.tools
+        if tools:
+            if self.streaming_callback:
+                raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+            _check_duplicate_tool_names(tools)
         if self.streaming_callback:
             return self._run_streaming(formatted_messages, generation_kwargs)
-        return self._run_non_streaming(formatted_messages, generation_kwargs)
+        hf_tools = None
+        if tools:
+            hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]
+        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
     def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
         api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
@@ -229,11 +262,17 @@
         generated_text = ""
-        for chunk in api_output:  # pylint: disable=not-an-iterable
-            text = chunk.choices[0].delta.content
+        for chunk in api_output:
+            # n is unused, so the API always returns only one choice
+            # the argument is probably allowed for compatibility with OpenAI
+            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+            choice = chunk.choices[0]
+            text = choice.delta.content
             if text:
                 generated_text += text
-            finish_reason = chunk.choices[0].finish_reason
+            finish_reason = choice.finish_reason
             meta = {}
             if finish_reason:
@@ -242,8 +281,7 @@
             stream_chunk = StreamingChunk(text, meta)
             self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
-        message = ChatMessage.from_assistant(generated_text)
-        message.meta.update(
+        meta.update(
             {
                 "model": self._client.model,
                 "finish_reason": finish_reason,
@@ -251,24 +289,48 @@
                 "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
             }
         )
+        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
         return {"replies": [message]}
     def _run_non_streaming(
-        self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        tools: Optional[List["ChatCompletionInputTool"]] = None,
     ) -> Dict[str, List[ChatMessage]]:
-        chat_messages: List[ChatMessage] = []
+        api_chat_output: ChatCompletionOutput = self._client.chat_completion(
+            messages=messages, tools=tools, **generation_kwargs
+        )
-        api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs)
-        for choice in api_chat_output.choices:
-            message = ChatMessage.from_assistant(choice.message.content)
-            message.meta.update(
-                {
-                    "model": self._client.model,
-                    "finish_reason": choice.finish_reason,
-                    "index": choice.index,
-                    "usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0},
-                }
-            )
-            chat_messages.append(message)
+        if len(api_chat_output.choices) == 0:
+            return {"replies": []}
-        return {"replies": chat_messages}
+        # n is unused, so the API always returns only one choice
+        # the argument is probably allowed for compatibility with OpenAI
+        # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+        choice = api_chat_output.choices[0]
+        text = choice.message.content
+        tool_calls = []
+        if hfapi_tool_calls := choice.message.tool_calls:
+            for hfapi_tc in hfapi_tool_calls:
+                tool_call = ToolCall(
+                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
+                )
+                tool_calls.append(tool_call)
+        meta = {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}
+        usage = {"prompt_tokens": 0, "completion_tokens": 0}
+        if api_chat_output.usage:
+            usage = {
+                "prompt_tokens": api_chat_output.usage.prompt_tokens,
+                "completion_tokens": api_chat_output.usage.completion_tokens,
+            }
+        meta["usage"] = usage
+        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
+        return {"replies": [message]}
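Taken together, the changes above let the generator accept tools either at init time or per `run()` call (the `run`-level `tools` argument overrides the init-time list) and surface the returned tool calls as `ToolCall` objects on the reply. A minimal usage sketch, assuming an HF token is exported as `HF_API_TOKEN`; the model ID is only an illustrative tool-capable choice and `get_weather` is a placeholder function:

    from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage, Tool
    from haystack.utils.hf import HFGenerationAPIType

    # Placeholder implementation; any callable matching the JSON schema below works.
    def get_weather(city: str) -> str:
        return f"Sunny in {city}"

    weather_tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=get_weather,
    )

    generator = HuggingFaceAPIChatGenerator(
        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
        api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},  # illustrative model ID
        tools=[weather_tool],
    )

    reply = generator.run([ChatMessage.from_user("What's the weather in Paris?")])["replies"][0]
    for tool_call in reply.tool_calls:
        print(tool_call.tool_name, tool_call.arguments)  # e.g. weather {'city': 'Paris'}

Remember that tools and streaming are mutually exclusive: combining them raises a ValueError both at init and at run time.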

View File

@@ -12,7 +12,7 @@ from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace
 from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
 from haystack.utils.url_validation import is_valid_http_url
-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import (
         InferenceClient,
         TextGenerationOutput,

View File

@@ -4,7 +4,7 @@
 import inspect
 from dataclasses import asdict, dataclass
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional
 from pydantic import create_model
@@ -216,6 +216,19 @@ def _remove_title_from_schema(schema: Dict[str, Any]):
             del property_schema[key]
+def _check_duplicate_tool_names(tools: List[Tool]) -> None:
+    """
+    Check for duplicate tool names and raise a ValueError if any are found.
+    :param tools: The list of tools to check.
+    :raises ValueError: If duplicate tool names are found.
+    """
+    tool_names = [tool.name for tool in tools]
+    duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
+    if duplicate_tool_names:
+        raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")
 def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
     """
     Deserialize Tools in a dictionary inplace.
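A quick sketch of the new helper's behavior (the tool definitions are throwaway examples):

    from haystack.dataclasses import Tool
    from haystack.dataclasses.tool import _check_duplicate_tool_names

    parameters = {"type": "object", "properties": {"city": {"type": "string"}}}
    tool_a = Tool(name="weather", description="first", parameters=parameters, function=print)
    tool_b = Tool(name="weather", description="second", parameters=parameters, function=print)

    _check_duplicate_tool_names([tool_a])  # unique names: passes silently
    try:
        _check_duplicate_tool_names([tool_a, tool_b])
    except ValueError as e:
        print(e)  # Duplicate tool names found: {'weather'}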

View File

@@ -8,7 +8,7 @@ from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Union
 from haystack import logging
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import ChatMessage, StreamingChunk
 from haystack.lazy_imports import LazyImport
 from haystack.utils.auth import Secret
 from haystack.utils.device import ComponentDevice
@@ -16,7 +16,7 @@ from haystack.utils.device import ComponentDevice
 with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_import:
     import torch
-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import HfApi, InferenceClient, model_info
     from huggingface_hub.utils import RepositoryNotFoundError
@@ -270,6 +270,44 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None):
     )
+def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]:
+    """
+    Convert a message to the format expected by Hugging Face.
+    """
+    text_contents = message.texts
+    tool_calls = message.tool_calls
+    tool_call_results = message.tool_call_results
+    if not text_contents and not tool_calls and not tool_call_results:
+        raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.")
+    if len(text_contents) + len(tool_call_results) > 1:
+        raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")
+    # HF always expects a content field, even if it is empty
+    hf_msg: Dict[str, Any] = {"role": message._role.value, "content": ""}
+    if tool_call_results:
+        result = tool_call_results[0]
+        hf_msg["content"] = result.result
+        if tc_id := result.origin.id:
+            hf_msg["tool_call_id"] = tc_id
+        # HF does not provide a way to communicate errors in tool invocations, so we ignore the error field
+        return hf_msg
+    if text_contents:
+        hf_msg["content"] = text_contents[0]
+    if tool_calls:
+        hf_tool_calls = []
+        for tc in tool_calls:
+            hf_tool_call = {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}}
+            if tc.id is not None:
+                hf_tool_call["id"] = tc.id
+            hf_tool_calls.append(hf_tool_call)
+        hf_msg["tool_calls"] = hf_tool_calls
+    return hf_msg
 with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
     from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer
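As a quick illustration (mirroring the new tests further down; the id and arguments are made up), an assistant message carrying a `ToolCall` converts to the OpenAI-style dictionary that `chat_completion` expects:

    from haystack.dataclasses import ChatMessage, ToolCall
    from haystack.utils.hf import convert_message_to_hf_format

    message = ChatMessage.from_assistant(
        tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})]
    )
    print(convert_message_to_hf_format(message))
    # {'role': 'assistant', 'content': '',
    #  'tool_calls': [{'type': 'function',
    #                  'function': {'name': 'weather', 'arguments': {'city': 'Paris'}}, 'id': '123'}]}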

View File

@@ -85,7 +85,7 @@ extra-dependencies = [
   "numpy>=2",  # Haystack is compatible both with numpy 1.x and 2.x, but we test with 2.x
   "transformers[torch,sentencepiece]==4.44.2",  # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators...
-  "huggingface_hub>=0.23.0",  # Hugging Face API Generators and Embedders
+  "huggingface_hub>=0.27.0",  # Hugging Face API Generators and Embedders
   "sentence-transformers>=3.0.0",  # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
   "langdetect",  # TextLanguageRouter and DocumentLanguageClassifier
   "openai-whisper>=20231106",  # LocalWhisperTranscriber

View File

@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Add support for Tools in the Hugging Face API Chat Generator.

View File

@@ -5,23 +5,46 @@ import os
 from unittest.mock import MagicMock, Mock, patch
 import pytest
+from haystack import Pipeline
+from haystack.dataclasses import StreamingChunk
+from haystack.utils.auth import Secret
+from haystack.utils.hf import HFGenerationAPIType
 from huggingface_hub import (
     ChatCompletionOutput,
-    ChatCompletionStreamOutput,
     ChatCompletionOutputComplete,
-    ChatCompletionStreamOutputChoice,
+    ChatCompletionOutputFunctionDefinition,
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputUsage,
+    ChatCompletionStreamOutput,
+    ChatCompletionStreamOutputChoice,
     ChatCompletionStreamOutputDelta,
 )
 from huggingface_hub.utils import RepositoryNotFoundError
-from haystack.components.generators.chat.hugging_face_api import (
-    HuggingFaceAPIChatGenerator,
-    _convert_message_to_hfapi_format,
-)
-from haystack.dataclasses import ChatMessage, StreamingChunk
-from haystack.utils.auth import Secret
-from haystack.utils.hf import HFGenerationAPIType
+from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
+from haystack.dataclasses import ChatMessage, Tool, ToolCall
 @pytest.fixture
 def chat_messages():
     return [
         ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"),
         ChatMessage.from_user("Tell me about Berlin"),
     ]
+@pytest.fixture
+def tools():
+    tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
+    tool = Tool(
+        name="weather",
+        description="useful to determine the weather in a given location",
+        parameters=tool_parameters,
+        function=lambda x: x,
+    )
+    return [tool]
 @pytest.fixture
@@ -48,7 +71,7 @@ def mock_chat_completion():
             id="some_id",
             model="some_model",
             system_fingerprint="some_fingerprint",
-            usage={"completion_tokens": 10, "prompt_tokens": 5, "total_tokens": 15},
+            usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25),
             created=1710498360,
         )
@@ -61,15 +84,7 @@ def streaming_callback_handler(x):
     return x
-def test_convert_message_to_hfapi_format():
-    message = ChatMessage.from_system("You are good assistant")
-    assert _convert_message_to_hfapi_format(message) == {"role": "system", "content": "You are good assistant"}
-    message = ChatMessage.from_user("I have a question")
-    assert _convert_message_to_hfapi_format(message) == {"role": "user", "content": "I have a question"}
-class TestHuggingFaceAPIGenerator:
+class TestHuggingFaceAPIChatGenerator:
     def test_init_invalid_api_type(self):
         with pytest.raises(ValueError):
             HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={})
@@ -93,6 +108,29 @@
         assert generator.api_params == {"model": model}
         assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}}
         assert generator.streaming_callback == streaming_callback
+        assert generator.tools is None
+    def test_init_serverless_with_tools(self, mock_check_valid_model, tools):
+        model = "HuggingFaceH4/zephyr-7b-alpha"
+        generation_kwargs = {"temperature": 0.6}
+        stop_words = ["stop"]
+        streaming_callback = None
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": model},
+            token=None,
+            generation_kwargs=generation_kwargs,
+            stop_words=stop_words,
+            streaming_callback=streaming_callback,
+            tools=tools,
+        )
+        assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
+        assert generator.api_params == {"model": model}
+        assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}}
+        assert generator.streaming_callback == streaming_callback
+        assert generator.tools == tools
     def test_init_serverless_invalid_model(self, mock_check_valid_model):
         mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
@@ -126,6 +164,7 @@
         assert generator.api_params == {"url": url}
         assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}}
         assert generator.streaming_callback == streaming_callback
+        assert generator.tools is None
     def test_init_tgi_invalid_url(self):
         with pytest.raises(ValueError):
@@ -139,12 +178,33 @@
             api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"}
         )
+    def test_init_fail_with_duplicate_tool_names(self, mock_check_valid_model, tools):
+        duplicate_tools = [tools[0], tools[0]]
+        with pytest.raises(ValueError):
+            HuggingFaceAPIChatGenerator(
+                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "irrelevant"},
+                tools=duplicate_tools,
+            )
+    def test_init_fail_with_tools_and_streaming(self, mock_check_valid_model, tools):
+        with pytest.raises(ValueError):
+            HuggingFaceAPIChatGenerator(
+                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "irrelevant"},
+                tools=tools,
+                streaming_callback=streaming_callback_handler,
+            )
     def test_to_dict(self, mock_check_valid_model):
+        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
             api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             generation_kwargs={"temperature": 0.6},
             stop_words=["stop", "words"],
+            tools=[tool],
         )
         result = generator.to_dict()
@@ -154,15 +214,26 @@
         assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
         assert init_params["token"] == {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
         assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
         assert init_params["streaming_callback"] is None
+        assert init_params["tools"] == [
+            {
+                "description": "description",
+                "function": "builtins.print",
+                "name": "name",
+                "parameters": {"x": {"type": "string"}},
+            }
+        ]
     def test_from_dict(self, mock_check_valid_model):
+        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
             api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             token=Secret.from_env_var("ENV_VAR", strict=False),
             generation_kwargs={"temperature": 0.6},
             stop_words=["stop", "words"],
-            streaming_callback=streaming_callback_handler,
+            tools=[tool],
         )
         result = generator.to_dict()
@@ -172,11 +243,57 @@
         assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"}
         assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False)
         assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
-        assert generator_2.streaming_callback is streaming_callback_handler
+        assert generator_2.streaming_callback is None
+        assert generator_2.tools == [tool]
-    def test_generate_text_response_with_valid_prompt_and_generation_parameters(
-        self, mock_check_valid_model, mock_chat_completion, chat_messages
-    ):
+    def test_serde_in_pipeline(self, mock_check_valid_model):
+        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+            token=Secret.from_env_var("ENV_VAR", strict=False),
+            generation_kwargs={"temperature": 0.6},
+            stop_words=["stop", "words"],
+            tools=[tool],
+        )
+        pipeline = Pipeline()
+        pipeline.add_component("generator", generator)
+        pipeline_dict = pipeline.to_dict()
+        assert pipeline_dict == {
+            "metadata": {},
+            "max_runs_per_component": 100,
+            "components": {
+                "generator": {
+                    "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator",
+                    "init_parameters": {
+                        "api_type": "serverless_inference_api",
+                        "api_params": {"model": "HuggingFaceH4/zephyr-7b-beta"},
+                        "token": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False},
+                        "generation_kwargs": {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512},
+                        "streaming_callback": None,
+                        "tools": [
+                            {
+                                "name": "name",
+                                "description": "description",
+                                "parameters": {"x": {"type": "string"}},
+                                "function": "builtins.print",
+                            }
+                        ],
+                    },
+                }
+            },
+            "connections": [],
+        }
+        pipeline_yaml = pipeline.dumps()
+        new_pipeline = Pipeline.loads(pipeline_yaml)
+        assert new_pipeline == pipeline
+    def test_run(self, mock_check_valid_model, mock_chat_completion, chat_messages):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
             api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
@@ -187,9 +304,19 @@
         response = generator.run(messages=chat_messages)
-        # check kwargs passed to text_generation
+        # check kwargs passed to chat_completion
         _, kwargs = mock_chat_completion.call_args
-        assert kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
+        hf_messages = [
+            {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
+            {"role": "user", "content": "Tell me about Berlin"},
+        ]
+        assert kwargs == {
+            "temperature": 0.6,
+            "stop": ["stop", "words"],
+            "max_tokens": 512,
+            "tools": None,
+            "messages": hf_messages,
+        }
         assert isinstance(response, dict)
         assert "replies" in response
@@ -197,7 +324,7 @@
         assert len(response["replies"]) == 1
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
-    def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages):
+    def test_run_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages):
         streaming_call_count = 0
         # Define the streaming callback function
@@ -260,13 +387,78 @@
         assert len(response["replies"]) > 0
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
-    @pytest.mark.flaky(reruns=5, reruns_delay=5)
+    def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model):
+        component = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
+            streaming_callback=streaming_callback_handler,
+        )
+        with pytest.raises(ValueError):
+            message = ChatMessage.from_user("irrelevant")
+            component.run([message], tools=tools)
+    def test_run_with_tools(self, mock_check_valid_model, tools):
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
+            tools=tools,
+        )
+        with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion:
+            completion = ChatCompletionOutput(
+                choices=[
+                    ChatCompletionOutputComplete(
+                        finish_reason="stop",
+                        index=0,
+                        message=ChatCompletionOutputMessage(
+                            role="assistant",
+                            content=None,
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        arguments={"city": "Paris"}, name="weather", description=None
+                                    ),
+                                    id="0",
+                                    type="function",
+                                )
+                            ],
+                        ),
+                        logprobs=None,
+                    )
+                ],
+                created=1729074760,
+                id="",
+                model="meta-llama/Llama-3.1-70B-Instruct",
+                system_fingerprint="2.3.2-dev0-sha-28bb7ae",
+                usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
+            )
+            mock_chat_completion.return_value = completion
+            messages = [ChatMessage.from_user("What is the weather in Paris?")]
+            response = generator.run(messages=messages)
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+        assert response["replies"][0].tool_calls[0].tool_name == "weather"
+        assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
+        assert response["replies"][0].tool_calls[0].id == "0"
+        assert response["replies"][0].meta == {
+            "finish_reason": "stop",
+            "index": 0,
+            "model": "meta-llama/Llama-3.1-70B-Instruct",
+            "usage": {"completion_tokens": 30, "prompt_tokens": 426},
+        }
     @pytest.mark.integration
     @pytest.mark.skipif(
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
-    def test_run_serverless(self):
+    def test_live_run_serverless(self):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
             api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
@@ -284,13 +476,12 @@
         assert "prompt_tokens" in response["replies"][0].meta["usage"]
         assert "completion_tokens" in response["replies"][0].meta["usage"]
-    @pytest.mark.flaky(reruns=5, reruns_delay=5)
     @pytest.mark.integration
     @pytest.mark.skipif(
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
-    def test_run_serverless_streaming(self):
+    def test_live_run_serverless_streaming(self):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
             api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
@@ -308,3 +499,47 @@
         assert "usage" in response["replies"][0].meta
         assert "prompt_tokens" in response["replies"][0].meta["usage"]
         assert "completion_tokens" in response["replies"][0].meta["usage"]
+    @pytest.mark.integration
+    @pytest.mark.skipif(
+        not os.environ.get("HF_API_TOKEN", None),
+        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+    )
+    def test_live_run_with_tools(self, tools):
+        """
+        We test the round trip: generate tool call, pass tool message, generate response.
+        The model used here (zephyr-7b-beta) is always available and not gated.
+        Even though it does not officially support tools, TGI and the HF API make it work.
+        """
+        chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Munich?")]
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+            generation_kwargs={"temperature": 0.5},
+        )
+        results = generator.run(chat_messages, tools=tools)
+        assert len(results["replies"]) == 1
+        message = results["replies"][0]
+        assert message.tool_calls
+        tool_call = message.tool_call
+        assert isinstance(tool_call, ToolCall)
+        assert tool_call.tool_name == "weather"
+        assert "city" in tool_call.arguments
+        assert "Paris" in tool_call.arguments["city"]
+        assert message.meta["finish_reason"] == "stop"
+        new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)]
+        # the model tends to make tool calls if provided with tools, so we don't pass them here
+        results = generator.run(new_messages, generation_kwargs={"max_tokens": 50})
+        assert len(results["replies"]) == 1
+        final_message = results["replies"][0]
+        assert not final_message.tool_calls
+        assert len(final_message.text) > 0
+        assert "paris" in final_message.text.lower()

View File

@@ -12,6 +12,7 @@ from haystack.dataclasses.tool import (
     ToolInvocationError,
     _remove_title_from_schema,
     deserialize_tools_inplace,
+    _check_duplicate_tool_names,
 )
 try:
@@ -303,3 +304,18 @@ def test_remove_title_from_schema_handle_no_title_in_top_level():
         "properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}},
         "type": "object",
     }
+def test_check_duplicate_tool_names():
+    tools = [
+        Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report),
+        Tool(name="weather", description="A different description", parameters=parameters, function=get_weather_report),
+    ]
+    with pytest.raises(ValueError):
+        _check_duplicate_tool_names(tools)
+    tools = [
+        Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report),
+        Tool(name="weather2", description="Get weather report", parameters=parameters, function=get_weather_report),
+    ]
+    _check_duplicate_tool_names(tools)

View File

@@ -2,8 +2,12 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 import logging
-from haystack.utils.hf import resolve_hf_device_map
+import pytest
+from haystack.utils.hf import resolve_hf_device_map, convert_message_to_hf_format
 from haystack.utils.device import ComponentDevice
+from haystack.dataclasses import ChatMessage, ToolCall, ChatRole, TextContent
def test_resolve_hf_device_map_only_device():
@@ -23,3 +27,56 @@ def test_resolve_hf_device_map_device_and_device_map(caplog):
     )
     assert "The parameters `device` and `device_map` from `model_kwargs` are both provided." in caplog.text
     assert model_kwargs["device_map"] == "cuda:0"
+def test_convert_message_to_hf_format():
+    message = ChatMessage.from_system("You are good assistant")
+    assert convert_message_to_hf_format(message) == {"role": "system", "content": "You are good assistant"}
+    message = ChatMessage.from_user("I have a question")
+    assert convert_message_to_hf_format(message) == {"role": "user", "content": "I have a question"}
+    message = ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"})
+    assert convert_message_to_hf_format(message) == {"role": "assistant", "content": "I have an answer"}
+    message = ChatMessage.from_assistant(
+        tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})]
+    )
+    assert convert_message_to_hf_format(message) == {
+        "role": "assistant",
+        "content": "",
+        "tool_calls": [
+            {"id": "123", "type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}}
+        ],
+    }
+    message = ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather", arguments={"city": "Paris"})])
+    assert convert_message_to_hf_format(message) == {
+        "role": "assistant",
+        "content": "",
+        "tool_calls": [{"type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}}],
+    }
+    tool_result = {"weather": "sunny", "temperature": "25"}
+    message = ChatMessage.from_tool(
+        tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})
+    )
+    assert convert_message_to_hf_format(message) == {"role": "tool", "content": tool_result, "tool_call_id": "123"}
+    message = ChatMessage.from_tool(
+        tool_result=tool_result, origin=ToolCall(tool_name="weather", arguments={"city": "Paris"})
+    )
+    assert convert_message_to_hf_format(message) == {"role": "tool", "content": tool_result}
+def test_convert_message_to_hf_invalid():
+    message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[])
+    with pytest.raises(ValueError):
+        convert_message_to_hf_format(message)
+    message = ChatMessage(
+        _role=ChatRole.ASSISTANT,
+        _content=[TextContent(text="I have an answer"), TextContent(text="I have another answer")],
+    )
+    with pytest.raises(ValueError):
+        convert_message_to_hf_format(message)