from unittest.mock import patch, MagicMock, Mock

import pytest
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
from huggingface_hub.utils import RepositoryNotFoundError

from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.dataclasses import StreamingChunk, ChatMessage
from haystack.utils.auth import Secret
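
# NOTE: the `mock_auto_tokenizer` and `chat_messages` fixtures used below are assumed to be
# provided by the test suite's shared conftest.py; they are not defined in this file.
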
@pytest.fixture
def mock_list_inference_deployed_models():
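    """Patch list_inference_deployed_models so the free-tier availability check sees a fixed set of deployed models."""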
with patch(
"haystack.components.generators.chat.hugging_face_tgi.list_inference_deployed_models",
MagicMock(
return_value=[
"HuggingFaceH4/zephyr-7b-alpha",
"HuggingFaceH4/zephyr-7b-beta",
"mistralai/Mistral-7B-v0.1",
"meta-llama/Llama-2-13b-chat-hf",
]
),
) as mock:
yield mock
@pytest.fixture
def mock_check_valid_model():
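    """Patch check_valid_model so any model id passed to the generator is treated as valid."""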
with patch(
"haystack.components.generators.chat.hugging_face_tgi.check_valid_model", MagicMock(return_value=None)
) as mock:
yield mock
@pytest.fixture
def mock_text_generation():
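    """Patch InferenceClient.text_generation to return a canned, non-streaming response."""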
with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation:
mock_response = Mock()
mock_response.generated_text = "I'm fine, thanks."
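        # minimal stand-in for the `details` attribute of a TextGenerationResponse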
details = Mock()
details.finish_reason = MagicMock(field1="value")
details.tokens = [1, 2, 3]
mock_response.details = details
mock_text_generation.return_value = mock_response
yield mock_text_generation
# used to test serialization of streaming_callback
def streaming_callback_handler(x):
return x
class TestHuggingFaceTGIChatGenerator:
def test_initialize_with_valid_model_and_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models
):
model = "HuggingFaceH4/zephyr-7b-alpha"
generation_kwargs = {"n": 1}
stop_words = ["stop"]
streaming_callback = None
generator = HuggingFaceTGIChatGenerator(
model=model,
generation_kwargs=generation_kwargs,
stop_words=stop_words,
streaming_callback=streaming_callback,
)
generator.warm_up()
assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
assert generator.tokenizer is not None
assert generator.client is not None
assert generator.streaming_callback == streaming_callback
def test_to_dict(self, mock_check_valid_model):
# Initialize the HuggingFaceTGIChatGenerator object with valid parameters
generator = HuggingFaceTGIChatGenerator(
model="NousResearch/Llama-2-7b-chat-hf",
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=lambda x: x,
)
# Call the to_dict method
result = generator.to_dict()
init_params = result["init_parameters"]
# Assert that the init_params dictionary contains the expected keys and values
assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf"
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
def test_from_dict(self, mock_check_valid_model):
generator = HuggingFaceTGIChatGenerator(
model="NousResearch/Llama-2-7b-chat-hf",
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=streaming_callback_handler,
)
        # Serialize with to_dict, then rebuild the generator with from_dict
result = generator.to_dict()
generator_2 = HuggingFaceTGIChatGenerator.from_dict(result)
assert generator_2.model == "NousResearch/Llama-2-7b-chat-hf"
assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
assert generator_2.streaming_callback is streaming_callback_handler
def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models):
generator = HuggingFaceTGIChatGenerator()
generator.warm_up()
# Assert that the tokenizer is now initialized
assert generator.tokenizer is not None
def test_warm_up_no_chat_template(
self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models, caplog
):
generator = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
# Set chat_template to None for this specific test
mock_auto_tokenizer.chat_template = None
generator.warm_up()
# warning message should be logged
assert "The model 'meta-llama/Llama-2-13b-chat-hf' doesn't have a default chat_template" in caplog.text
def test_custom_chat_template(
self,
chat_messages,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
mock_list_inference_deployed_models,
):
custom_chat_template = "Here goes some Jinja template"
# mocked method to check if we called apply_chat_template with the custom template
mock_auto_tokenizer.apply_chat_template = MagicMock(return_value="some_value")
generator = HuggingFaceTGIChatGenerator(chat_template=custom_chat_template)
generator.warm_up()
assert generator.chat_template == custom_chat_template
generator.run(messages=chat_messages)
assert mock_auto_tokenizer.apply_chat_template.call_count == 1
# and we indeed called apply_chat_template with the custom template
_, kwargs = mock_auto_tokenizer.apply_chat_template.call_args
assert kwargs["chat_template"] == custom_chat_template
def test_initialize_with_invalid_model_path_or_url(self, mock_check_valid_model):
model = "invalid_model"
generation_kwargs = {"n": 1}
stop_words = ["stop"]
streaming_callback = None
mock_check_valid_model.side_effect = ValueError("Invalid model path or url")
with pytest.raises(ValueError):
HuggingFaceTGIChatGenerator(
model=model,
generation_kwargs=generation_kwargs,
stop_words=stop_words,
streaming_callback=streaming_callback,
)
def test_initialize_with_invalid_url(self, mock_check_valid_model):
with pytest.raises(ValueError):
HuggingFaceTGIChatGenerator(model="NousResearch/Llama-2-7b-chat-hf", url="invalid_url")
def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
        # When a custom TGI endpoint is used via URL, the model must still be a valid Hugging Face Hub model id
mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
with pytest.raises(RepositoryNotFoundError):
HuggingFaceTGIChatGenerator(model="invalid_model_id", url="https://some_chat_model.com")
def test_generate_text_response_with_valid_prompt_and_generation_parameters(
self,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
chat_messages,
mock_list_inference_deployed_models,
):
model = "meta-llama/Llama-2-13b-chat-hf"
generation_kwargs = {"n": 1}
stop_words = ["stop"]
streaming_callback = None
generator = HuggingFaceTGIChatGenerator(
model=model,
generation_kwargs=generation_kwargs,
stop_words=stop_words,
streaming_callback=streaming_callback,
)
generator.warm_up()
response = generator.run(messages=chat_messages)
# check kwargs passed to text_generation
        # note that `n` is not a text generation parameter, so it was not forwarded to text_generation
_, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop"]}
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
self,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
chat_messages,
mock_list_inference_deployed_models,
):
model = "meta-llama/Llama-2-13b-chat-hf"
token = None
generation_kwargs = {"n": 3}
stop_words = ["stop"]
streaming_callback = None
generator = HuggingFaceTGIChatGenerator(
model=model,
token=token,
generation_kwargs=generation_kwargs,
stop_words=stop_words,
streaming_callback=streaming_callback,
)
generator.warm_up()
response = generator.run(chat_messages)
# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop"]}
# note how n caused n replies to be generated
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 3
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
def test_generate_text_with_stop_words(
self,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
chat_messages,
mock_list_inference_deployed_models,
):
generator = HuggingFaceTGIChatGenerator()
generator.warm_up()
stop_words = ["stop", "words"]
# Generate text response with stop words
response = generator.run(chat_messages, generation_kwargs={"stop_words": stop_words})
# check kwargs passed to text_generation
# we translate stop_words to stop_sequences
_, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}
# Assert that the response contains the generated replies
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
def test_generate_text_with_custom_generation_parameters(
self,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
chat_messages,
mock_list_inference_deployed_models,
):
        # Create an instance of HuggingFaceTGIChatGenerator with no generation parameters
generator = HuggingFaceTGIChatGenerator()
generator.warm_up()
# but then we pass them in run
generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
response = generator.run(chat_messages, generation_kwargs=generation_kwargs)
# again check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8}
# Assert that the response contains the generated replies and the right response
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
assert response["replies"][0].content == "I'm fine, thanks."
def test_generate_text_with_streaming_callback(
self,
mock_check_valid_model,
mock_auto_tokenizer,
mock_text_generation,
chat_messages,
mock_list_inference_deployed_models,
):
streaming_call_count = 0
# Define the streaming callback function
def streaming_callback_fn(chunk: StreamingChunk):
nonlocal streaming_call_count
streaming_call_count += 1
assert isinstance(chunk, StreamingChunk)
        # Create an instance of HuggingFaceTGIChatGenerator with the streaming callback
generator = HuggingFaceTGIChatGenerator(streaming_callback=streaming_callback_fn)
generator.warm_up()
        # Create a fake streamed response.
        # `self` is needed here, don't remove it: mock installs dunder methods such as __iter__
        # on the mock's type, so the function is called with the mock instance as its first argument.
def mock_iter(self):
yield TextGenerationStreamResponse(
generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False)
)
yield TextGenerationStreamResponse(
generated_text=None,
token=Token(id=1, text="Ok bye", logprob=0.0, special=False),
details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5),
)
mock_response = Mock(**{"__iter__": mock_iter})
mock_text_generation.return_value = mock_response
# Generate text response with streaming callback
response = generator.run(chat_messages)
# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": [], "stream": True}
# Assert that the streaming callback was called twice
assert streaming_call_count == 2
# Assert that the response contains the generated replies
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])