# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock, patch

import pytest
from transformers import PreTrainedTokenizer

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.utils import ComponentDevice
from haystack.utils.auth import Secret


# used to test serialization of streaming_callback
def streaming_callback_handler(x):
    return x


@pytest.fixture
def model_info_mock():
    with patch(
        "haystack.components.generators.chat.hugging_face_local.model_info",
        new=Mock(return_value=Mock(pipeline_tag="text2text-generation")),
    ) as mock:
        yield mock

@pytest.fixture
def mock_pipeline_tokenizer():
    # Mocking the pipeline
    mock_pipeline = Mock(return_value=[{"generated_text": "Berlin is cool"}])

    # Mocking the tokenizer
    mock_tokenizer = Mock(spec=PreTrainedTokenizer)
    mock_tokenizer.encode.return_value = ["Berlin", "is", "cool"]
    mock_pipeline.tokenizer = mock_tokenizer

    return mock_pipeline

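# The run tests below depend on a `chat_messages` fixture that is not defined in
# this module (it presumably lives in a shared conftest). A minimal sketch of what
# it likely provides is given here; the exact message contents are an assumption.
@pytest.fixture
def chat_messages():
    return [
        ChatMessage.from_system("You are a helpful assistant"),
        ChatMessage.from_user("Tell me about Berlin"),
    ]
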
class TestHuggingFaceLocalChatGenerator:
    def test_initialize_with_valid_model_and_generation_parameters(self, model_info_mock):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceLocalChatGenerator(
            model=model,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
        assert generator.streaming_callback == streaming_callback

    def test_init_custom_token(self):
        generator = HuggingFaceLocalChatGenerator(
            model="mistralai/Mistral-7B-Instruct-v0.2",
            task="text2text-generation",
            token=Secret.from_token("test-token"),
            device=ComponentDevice.from_str("cpu"),
        )

        assert generator.huggingface_pipeline_kwargs == {
            "model": "mistralai/Mistral-7B-Instruct-v0.2",
            "task": "text2text-generation",
            "token": "test-token",
            "device": "cpu",
        }

    def test_init_custom_device(self):
        generator = HuggingFaceLocalChatGenerator(
            model="mistralai/Mistral-7B-Instruct-v0.2",
            task="text2text-generation",
            device=ComponentDevice.from_str("cpu"),
            token=None,
        )

        assert generator.huggingface_pipeline_kwargs == {
            "model": "mistralai/Mistral-7B-Instruct-v0.2",
            "task": "text2text-generation",
            "token": None,
            "device": "cpu",
        }

    def test_init_task_parameter(self):
        generator = HuggingFaceLocalChatGenerator(
            task="text2text-generation", device=ComponentDevice.from_str("cpu"), token=None
        )

        assert generator.huggingface_pipeline_kwargs == {
            "model": "HuggingFaceH4/zephyr-7b-beta",
            "task": "text2text-generation",
            "token": None,
            "device": "cpu",
        }

    def test_init_task_in_huggingface_pipeline_kwargs(self):
        generator = HuggingFaceLocalChatGenerator(
            huggingface_pipeline_kwargs={"task": "text2text-generation"},
            device=ComponentDevice.from_str("cpu"),
            token=None,
        )

        assert generator.huggingface_pipeline_kwargs == {
            "model": "HuggingFaceH4/zephyr-7b-beta",
            "task": "text2text-generation",
            "token": None,
            "device": "cpu",
        }

    def test_init_task_inferred_from_model_name(self, model_info_mock):
        generator = HuggingFaceLocalChatGenerator(
            model="mistralai/Mistral-7B-Instruct-v0.2", device=ComponentDevice.from_str("cpu"), token=None
        )

        assert generator.huggingface_pipeline_kwargs == {
            "model": "mistralai/Mistral-7B-Instruct-v0.2",
            "task": "text2text-generation",
            "token": None,
            "device": "cpu",
        }

    def test_init_invalid_task(self):
        with pytest.raises(ValueError, match="is not supported."):
            HuggingFaceLocalChatGenerator(task="text-classification")

    def test_to_dict(self, model_info_mock):
        generator = HuggingFaceLocalChatGenerator(
            model="NousResearch/Llama-2-7b-chat-hf",
            token=Secret.from_env_var("ENV_VAR", strict=False),
            generation_kwargs={"n": 5},
            stop_words=["stop", "words"],
            streaming_callback=lambda x: x,
        )

        # Call the to_dict method
        result = generator.to_dict()
        init_params = result["init_parameters"]

        # Assert that the init_params dictionary contains the expected keys and values
        assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
        assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
        assert "token" not in init_params["huggingface_pipeline_kwargs"]
        assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}

    def test_from_dict(self, model_info_mock):
        generator = HuggingFaceLocalChatGenerator(
            model="NousResearch/Llama-2-7b-chat-hf",
            generation_kwargs={"n": 5},
            stop_words=["stop", "words"],
            streaming_callback=streaming_callback_handler,
        )
        # Call the to_dict method
        result = generator.to_dict()

        generator_2 = HuggingFaceLocalChatGenerator.from_dict(result)

        assert generator_2.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
        assert generator_2.generation_kwargs == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
        assert generator_2.streaming_callback is streaming_callback_handler

@patch("haystack.components.generators.chat.hugging_face_local.pipeline")
|
2024-04-05 18:05:43 +02:00
|
|
|
def test_warm_up(self, pipeline_mock, monkeypatch):
|
|
|
|
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
2024-12-13 09:50:23 +01:00
|
|
|
monkeypatch.delenv("HF_TOKEN", raising=False)
|
2024-01-18 15:53:12 +01:00
|
|
|
generator = HuggingFaceLocalChatGenerator(
|
|
|
|
model="mistralai/Mistral-7B-Instruct-v0.2",
|
|
|
|
task="text2text-generation",
|
|
|
|
device=ComponentDevice.from_str("cpu"),
|
|
|
|
)
|
|
|
|
|
|
|
|
pipeline_mock.assert_not_called()
|
|
|
|
|
|
|
|
generator.warm_up()
|
|
|
|
|
|
|
|
pipeline_mock.assert_called_once_with(
|
|
|
|
model="mistralai/Mistral-7B-Instruct-v0.2", task="text2text-generation", token=None, device="cpu"
|
|
|
|
)
|
|
|
|
|
|
|
|
    def test_run(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
        generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")

        # Use the mocked pipeline from the fixture and simulate warm_up
        generator.pipeline = mock_pipeline_tokenizer

        results = generator.run(messages=chat_messages)

        assert "replies" in results
        assert isinstance(results["replies"][0], ChatMessage)
        chat_message = results["replies"][0]
        assert chat_message.is_from(ChatRole.ASSISTANT)
        assert chat_message.text == "Berlin is cool"

    def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
        generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")

        # Use the mocked pipeline from the fixture and simulate warm_up
        generator.pipeline = mock_pipeline_tokenizer

        generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
        results = generator.run(messages=chat_messages, generation_kwargs=generation_kwargs)

        # check the kwargs passed to the pipeline
        _, kwargs = generator.pipeline.call_args
        assert kwargs["max_new_tokens"] == 100
        assert kwargs["temperature"] == 0.8

        # replies are properly parsed and returned
        assert "replies" in results
        assert isinstance(results["replies"][0], ChatMessage)
        chat_message = results["replies"][0]
        assert chat_message.is_from(ChatRole.ASSISTANT)
        assert chat_message.text == "Berlin is cool"