# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
|
2024-04-09 17:47:13 +02:00
|
|
|
import os
|
2024-04-05 18:48:34 +02:00
|
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
from huggingface_hub import (
|
|
|
|
ChatCompletionOutput,
|
|
|
|
ChatCompletionStreamOutput,
|
2024-05-03 10:14:54 +02:00
|
|
|
ChatCompletionOutputComplete,
|
2024-04-05 18:48:34 +02:00
|
|
|
ChatCompletionStreamOutputChoice,
|
2024-05-03 10:14:54 +02:00
|
|
|
ChatCompletionOutputMessage,
|
2024-04-05 18:48:34 +02:00
|
|
|
ChatCompletionStreamOutputDelta,
|
|
|
|
)
|
|
|
|
from huggingface_hub.utils import RepositoryNotFoundError
|
|
|
|
|
|
|
|
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
|
|
|
|
from haystack.dataclasses import ChatMessage, StreamingChunk
|
|
|
|
from haystack.utils.auth import Secret
|
|
|
|
from haystack.utils.hf import HFGenerationAPIType
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_check_valid_model():
    """
    Swap `check_valid_model` in the chat generator module for a `MagicMock` returning `None`.

    The mock object itself is yielded, so a test can reconfigure it (for example by setting
    `side_effect`). The patch is undone once the test using the fixture finishes.
    """
    target = "haystack.components.generators.chat.hugging_face_api.check_valid_model"
    replacement = MagicMock(return_value=None)
    with patch(target, replacement) as patched_check:
        yield patched_check
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_chat_completion():
    """
    Patch `InferenceClient.chat_completion` (autospec'd) to return one canned, non-streaming completion.

    The patched mock is yielded so tests can inspect `call_args` or override `return_value`.
    """
    # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.example
    assistant_message = ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant")
    only_choice = ChatCompletionOutputComplete(finish_reason="eos_token", index=0, message=assistant_message)
    canned_completion = ChatCompletionOutput(
        choices=[only_choice],
        id="some_id",
        model="some_model",
        object="some_object",
        system_fingerprint="some_fingerprint",
        usage={"completion_tokens": 10, "prompt_tokens": 5, "total_tokens": 15},
        created=1710498360,
    )

    with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as patched_chat_completion:
        patched_chat_completion.return_value = canned_completion
        yield patched_chat_completion
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level callable used by the tests that serialize / deserialize `streaming_callback`.
def streaming_callback_handler(x):
    """Return the received argument unchanged."""
    return x
|
|
|
|
|
|
|
|
|
|
|
|
class TestHuggingFaceAPIGenerator:
    """
    Tests for `HuggingFaceAPIChatGenerator`: init validation, (de)serialization and `run`.

    Unit tests rely on the `mock_check_valid_model` and `mock_chat_completion` fixtures of this module.
    NOTE(review): the `chat_messages` fixture is not defined in this part of the file - presumably it
    comes from a shared conftest; confirm.
    """

    def test_init_invalid_api_type(self):
        """An unknown `api_type` string raises `ValueError`."""
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={})

    def test_init_serverless(self, mock_check_valid_model):
        """Serverless init stores the params and merges stop words and `max_tokens` into `generation_kwargs`."""
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"temperature": 0.6}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": model},
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
        assert generator.api_params == {"model": model}
        assert generator.generation_kwargs == {**generation_kwargs, "stop": ["stop"], "max_tokens": 512}
        assert generator.streaming_callback == streaming_callback

    def test_init_serverless_invalid_model(self, mock_check_valid_model):
        """An error raised by `check_valid_model` propagates out of the constructor."""
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"}
            )

    def test_init_serverless_no_model(self):
        """Serverless init without a `model` entry in `api_params` raises `ValueError`."""
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"}
            )

    def test_init_tgi(self):
        """TGI init stores the params and merges stop words and `max_tokens` into `generation_kwargs`."""
        url = "https://some_model.com"
        generation_kwargs = {"temperature": 0.6}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE,
            api_params={"url": url},
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.api_type == HFGenerationAPIType.TEXT_GENERATION_INFERENCE
        assert generator.api_params == {"url": url}
        assert generator.generation_kwargs == {**generation_kwargs, "stop": ["stop"], "max_tokens": 512}
        assert generator.streaming_callback == streaming_callback

    def test_init_tgi_invalid_url(self):
        """TGI init with a malformed `url` raises `ValueError`."""
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": "invalid_url"}
            )

    def test_init_tgi_no_url(self):
        """TGI init without a `url` entry in `api_params` raises `ValueError`."""
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"}
            )

    def test_to_dict(self, mock_check_valid_model):
        """`to_dict` serializes the API type, API params, token and the merged generation kwargs."""
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            token=Secret.from_env_var("ENV_VAR", strict=False),
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
        )

        result = generator.to_dict()
        init_params = result["init_parameters"]

        assert init_params["api_type"] == "serverless_inference_api"
        assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
        assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
        assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}

    def test_from_dict(self, mock_check_valid_model):
        """A `to_dict` / `from_dict` round trip restores an equivalent generator, including the callback."""
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            token=Secret.from_env_var("ENV_VAR", strict=False),
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            streaming_callback=streaming_callback_handler,
        )
        result = generator.to_dict()

        # now deserialize, call from_dict
        generator_2 = HuggingFaceAPIChatGenerator.from_dict(result)
        assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
        assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"}
        assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False)
        assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
        assert generator_2.streaming_callback is streaming_callback_handler

    def test_generate_text_response_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_chat_completion, chat_messages
    ):
        """A non-streaming `run` forwards the generation kwargs and returns exactly one `ChatMessage`."""
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            streaming_callback=None,
        )

        response = generator.run(messages=chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion.call_args
        assert kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}

        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        # `all(...)` is required: asserting the bare list would pass for any non-empty list
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages):
        """A streaming `run` requests `stream=True` and invokes the callback once per streamed chunk."""
        streaming_call_count = 0

        # Define the streaming callback function
        def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            streaming_callback=streaming_callback_fn,
        )

        # Create a fake streamed response
        # self needed here, don't remove
        def mock_iter(self):
            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
                        index=0,
                        finish_reason=None,
                    )
                ],
                id="some_id",
                model="some_model",
                object="some_object",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
                    )
                ],
                id="some_id",
                model="some_model",
                object="some_object",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

        mock_response = Mock(**{"__iter__": mock_iter})
        mock_chat_completion.return_value = mock_response

        # Generate text response with streaming callback
        response = generator.run(chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion.call_args
        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        # `all(...)` is required: asserting the bare list would pass for any non-empty list
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.flaky(reruns=5, reruns_delay=5)
    @pytest.mark.integration
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN"),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    def test_run_serverless(self):
        """Integration test against the real serverless API; skipped unless `HF_API_TOKEN` is set."""
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            generation_kwargs={"max_tokens": 20},
        )

        messages = [ChatMessage.from_user("What is the capital of France?")]
        response = generator.run(messages=messages)

        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        # `all(...)` is required: asserting the bare list would pass for any non-empty list
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
|