diff --git a/README.md b/README.md index 594e39e3b..688b38a96 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ First, run the minimal Haystack installation: pip install farm-haystack ``` -Then, index your data to the DocumentStore, build a RAG pipeline, and ask a question on your data: +Then, index your data to the DocumentStore, build a RAG pipeline, and ask a question on your data: ```python from haystack.document_stores import InMemoryDocumentStore diff --git a/haystack/preview/components/generators/chat/__init__.py b/haystack/preview/components/generators/chat/__init__.py new file mode 100644 index 000000000..b9885a6cf --- /dev/null +++ b/haystack/preview/components/generators/chat/__init__.py @@ -0,0 +1,3 @@ +from haystack.preview.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator + +__all__ = ["HuggingFaceTGIChatGenerator"] diff --git a/haystack/preview/components/generators/chat/hugging_face_tgi.py b/haystack/preview/components/generators/chat/hugging_face_tgi.py new file mode 100644 index 000000000..0e129915f --- /dev/null +++ b/haystack/preview/components/generators/chat/hugging_face_tgi.py @@ -0,0 +1,276 @@ +import logging +from dataclasses import asdict +from typing import Any, Dict, List, Optional, Iterable, Callable +from urllib.parse import urlparse + +from huggingface_hub import InferenceClient +from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token +from transformers import AutoTokenizer + +from haystack.preview import component, default_to_dict, default_from_dict +from haystack.preview.components.generators.hf_utils import check_valid_model, check_generation_params +from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler +from haystack.preview.dataclasses import ChatMessage, StreamingChunk + +logger = logging.getLogger(__name__) + + +class HuggingFaceTGIChatGenerator: + """ + Enables text generation using HuggingFace Hub hosted chat-based LLMs. This component is designed to seamlessly + inference chat-based models deployed on the Text Generation Inference (TGI) backend. 
+ + You can use this component for chat LLMs hosted on Hugging Face inference endpoints, the rate-limited + Inference API tier: + + ```python + from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator + from haystack.preview.dataclasses import ChatMessage + + messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"), + ChatMessage.from_user("What's Natural Language Processing?")] + + + client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", token="") + client.warm_up() + response = client.run(messages, generation_kwargs={"max_new_tokens": 120}) + print(response) + ``` + + For chat LLMs hosted on paid https://huggingface.co/inference-endpoints endpoint and/or your own custom TGI + endpoint, you'll need to provide the URL of the endpoint as well as a valid token: + + ```python + from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator + from haystack.preview.dataclasses import ChatMessage + + messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"), + ChatMessage.from_user("What's Natural Language Processing?")] + + client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", + url="", + token="") + client.warm_up() + response = client.run(messages, generation_kwargs={"max_new_tokens": 120}) + print(response) + ``` + + Key Features and Compatibility: + - **Primary Compatibility**: Designed to work seamlessly with any chat-based model deployed using the TGI + framework. For more information on TGI, visit https://github.com/huggingface/text-generation-inference. + - **Hugging Face Inference Endpoints**: Supports inference of TGI chat LLMs deployed on Hugging Face + inference endpoints. For more details, refer to https://huggingface.co/inference-endpoints. + - **Inference API Support**: Supports inference of TGI chat LLMs hosted on the rate-limited Inference + API tier. Learn more about the Inference API at https://huggingface.co/inference-api. + Discover available chat models using the following command: + ``` + wget -qO- https://api-inference.huggingface.co/framework/text-generation-inference | grep chat + ``` + and simply use the model ID as the model parameter for this component. You'll also need to provide a valid + Hugging Face API token as the token parameter. + - **Custom TGI Endpoints**: Supports inference of TGI chat LLMs deployed on custom TGI endpoints. Anyone can + deploy their own TGI endpoint using the TGI framework. For more details, refer + to https://huggingface.co/inference-endpoints. + + Input and Output Format: + - **ChatMessage Format**: This component uses the ChatMessage format to structure both input and output, + ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on the + ChatMessage format can be found at https://github.com/openai/openai-python/blob/main/chatml.md. + + """ + + def __init__( + self, + model: str = "meta-llama/Llama-2-13b-chat-hf", + url: Optional[str] = None, + token: Optional[str] = None, + chat_template: Optional[str] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + stop_words: Optional[List[str]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): + """ + Initialize the HuggingFaceTGIChatGenerator instance. + + :param model: A string representing the model path or URL. Default is "meta-llama/Llama-2-13b-chat-hf". + :param url: An optional string representing the URL of the TGI endpoint. 
+ :param chat_template: This optional parameter allows you to specify a Jinja template for formatting chat + messages. While high-quality and well-supported chat models typically include their own chat templates + accessible through their tokenizer, there are models that do not offer this feature. For such scenarios, + or if you wish to use a custom template instead of the model's default, you can use this parameter to + set your preferred chat template. + :param token: The Hugging Face token for HTTP bearer authorization. + You can find your HF token at https://huggingface.co/settings/tokens. + :param generation_kwargs: A dictionary containing keyword arguments to customize text generation. + Some examples: `max_new_tokens`, `temperature`, `top_k`, `top_p`,... + See Hugging Face's [documentation](https://huggingface.co/docs/huggingface_hub/v0.18.0.rc0/en/package_reference/inference_client#huggingface_hub.inference._text_generation.TextGenerationParameters) + for more information. + :param stop_words: An optional list of strings representing the stop words. + :param streaming_callback: An optional callable for handling streaming responses. + """ + if url: + r = urlparse(url) + is_valid_url = all([r.scheme in ["http", "https"], r.netloc]) + if not is_valid_url: + raise ValueError(f"Invalid TGI endpoint URL provided: {url}") + + check_valid_model(model, token) + + # handle generation kwargs setup + generation_kwargs = generation_kwargs.copy() if generation_kwargs else {} + check_generation_params(generation_kwargs, ["n"]) + generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", []) + generation_kwargs["stop_sequences"].extend(stop_words or []) + + self.model = model + self.url = url + self.chat_template = chat_template + self.token = token + self.generation_kwargs = generation_kwargs + self.client = InferenceClient(url or model, token=token) + self.streaming_callback = streaming_callback + self.tokenizer = None + + def warm_up(self) -> None: + """ + Load the tokenizer. + """ + self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token) + # mypy can't infer that chat_template attribute exists on the object returned by AutoTokenizer.from_pretrained + chat_template = getattr(self.tokenizer, "chat_template", None) + if not chat_template and not self.chat_template: + logger.warning( + "The model '%s' doesn't have a default chat_template, and no chat_template was supplied during " + "this component's initialization. It’s possible that the model doesn't support ChatML inference " + "format, potentially leading to unexpected behavior.", + self.model, + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :return: A dictionary containing the serialized component. + """ + callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + model=self.model, + url=self.url, + chat_template=self.chat_template, + token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens + generation_kwargs=self.generation_kwargs, + streaming_callback=callback_name, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTGIChatGenerator": + """ + Deserialize this component from a dictionary. 
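+
+        :param data: The dictionary representation of this component, as produced by `to_dict`.
+        :return: The deserialized component instance.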
+ """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) + return default_from_dict(cls, data) + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + # Don't send URL as it is sensitive information + return {"model": self.model} + + @component.output_types(replies=List[ChatMessage]) + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Invoke the text generation inference based on the provided messages and generation parameters. + + :param messages: A list of ChatMessage instances representing the input messages. + :param generation_kwargs: Additional keyword arguments for text generation. + :return: A list containing the generated responses as ChatMessage instances. + """ + + # check generation kwargs given as parameters to override the default ones + additional_params = ["n", "stop_words"] + check_generation_params(generation_kwargs, additional_params) + + # update generation kwargs by merging with the default ones + generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + num_responses = generation_kwargs.pop("n", 1) + + # merge stop_words and stop_sequences into a single list + generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", []) + generation_kwargs["stop_sequences"].extend(generation_kwargs.pop("stop_words", [])) + + if self.tokenizer is None: + raise RuntimeError("Please call warm_up() before running LLM inference.") + + # apply either model's chat template or the user-provided one + prepared_prompt: str = self.tokenizer.apply_chat_template( + conversation=messages, chat_template=self.chat_template, tokenize=False + ) + prompt_token_count: int = len(self.tokenizer.encode(prepared_prompt, add_special_tokens=False)) + + if self.streaming_callback: + if num_responses > 1: + raise ValueError("Cannot stream multiple responses, please set n=1.") + + return self._run_streaming(prepared_prompt, prompt_token_count, generation_kwargs) + + return self._run_non_streaming(prepared_prompt, prompt_token_count, num_responses, generation_kwargs) + + def _run_streaming( + self, prepared_prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any] + ) -> Dict[str, List[ChatMessage]]: + res: Iterable[TextGenerationStreamResponse] = self.client.text_generation( + prepared_prompt, stream=True, details=True, **generation_kwargs + ) + chunk = None + # pylint: disable=not-an-iterable + for chunk in res: + token: Token = chunk.token + if token.special: + continue + chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})} + stream_chunk = StreamingChunk(token.text, chunk_metadata) + self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method) + + message = ChatMessage.from_assistant(chunk.generated_text) + message.metadata.update( + { + "finish_reason": chunk.details.finish_reason.value, + "index": 0, + "model": self.client.model, + "usage": { + "completion_tokens": chunk.details.generated_tokens, + "prompt_tokens": prompt_token_count, + "total_tokens": prompt_token_count + chunk.details.generated_tokens, + }, + } + ) + return {"replies": [message]} + + def _run_non_streaming( + self, prepared_prompt: str, prompt_token_count: int, num_responses: 
int, generation_kwargs: Dict[str, Any]
+    ) -> Dict[str, List[ChatMessage]]:
+        chat_messages: List[ChatMessage] = []
+        for _i in range(num_responses):
+            tgr: TextGenerationResponse = self.client.text_generation(
+                prepared_prompt, details=True, **generation_kwargs
+            )
+            message = ChatMessage.from_assistant(tgr.generated_text)
+            message.metadata.update(
+                {
+                    "finish_reason": tgr.details.finish_reason.value,
+                    "index": _i,
+                    "model": self.client.model,
+                    "usage": {
+                        "completion_tokens": len(tgr.details.tokens),
+                        "prompt_tokens": prompt_token_count,
+                        "total_tokens": prompt_token_count + len(tgr.details.tokens),
+                    },
+                }
+            )
+            chat_messages.append(message)
+        return {"replies": chat_messages}
diff --git a/releasenotes/notes/add-huggingface-tgi-chat-c63f4879a5d81342.yaml b/releasenotes/notes/add-huggingface-tgi-chat-c63f4879a5d81342.yaml
new file mode 100644
index 000000000..d5b177828
--- /dev/null
+++ b/releasenotes/notes/add-huggingface-tgi-chat-c63f4879a5d81342.yaml
@@ -0,0 +1,5 @@
+---
+preview:
+  - |
+    Adds `HuggingFaceTGIChatGenerator` for text and chat generation. This component supports remote inference for
+    Hugging Face LLMs via the text-generation-inference (TGI) protocol.
diff --git a/test/preview/components/generators/chat/__init__.py b/test/preview/components/generators/chat/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/preview/components/generators/chat/conftest.py b/test/preview/components/generators/chat/conftest.py
new file mode 100644
index 000000000..7a6e7a0fb
--- /dev/null
+++ b/test/preview/components/generators/chat/conftest.py
@@ -0,0 +1,11 @@
+import pytest
+
+from haystack.preview.dataclasses import ChatMessage
+
+
+@pytest.fixture
+def chat_messages():
+    return [
+        ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"),
+        ChatMessage.from_user("Tell me about Berlin"),
+    ]
diff --git a/test/preview/components/generators/chat/test_hugging_face_tgi.py b/test/preview/components/generators/chat/test_hugging_face_tgi.py
new file mode 100644
index 000000000..35de29417
--- /dev/null
+++ b/test/preview/components/generators/chat/test_hugging_face_tgi.py
@@ -0,0 +1,317 @@
+from unittest.mock import patch, MagicMock, Mock
+
+import pytest
+from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
+from huggingface_hub.utils import RepositoryNotFoundError
+
+from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator
+
+from haystack.preview.dataclasses import StreamingChunk, ChatMessage
+
+
+@pytest.fixture
+def mock_check_valid_model():
+    with patch(
+        "haystack.preview.components.generators.chat.hugging_face_tgi.check_valid_model", MagicMock(return_value=None)
+    ) as mock:
+        yield mock
+
+
+@pytest.fixture
+def mock_text_generation():
+    with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation:
+        mock_response = Mock()
+        mock_response.generated_text = "I'm fine, thanks."
+ details = Mock() + details.finish_reason = MagicMock(field1="value") + details.tokens = [1, 2, 3] + mock_response.details = details + mock_text_generation.return_value = mock_response + yield mock_text_generation + + +# used to test serialization of streaming_callback +def streaming_callback_handler(x): + return x + + +class TestHuggingFaceTGIChatGenerator: + @pytest.mark.unit + def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model, mock_auto_tokenizer): + model = "HuggingFaceH4/zephyr-7b-alpha" + generation_kwargs = {"n": 1} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceTGIChatGenerator( + model=model, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + generator.warm_up() + + assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}} + assert generator.tokenizer is not None + assert generator.client is not None + assert generator.streaming_callback == streaming_callback + + @pytest.mark.unit + def test_to_dict(self, mock_check_valid_model): + # Initialize the HuggingFaceTGIChatGenerator object with valid parameters + generator = HuggingFaceTGIChatGenerator( + model="NousResearch/Llama-2-7b-chat-hf", + token="token", + generation_kwargs={"n": 5}, + stop_words=["stop", "words"], + streaming_callback=lambda x: x, + ) + + # Call the to_dict method + result = generator.to_dict() + init_params = result["init_parameters"] + + # Assert that the init_params dictionary contains the expected keys and values + assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf" + assert init_params["token"] is None + assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]} + + @pytest.mark.unit + def test_from_dict(self, mock_check_valid_model): + generator = HuggingFaceTGIChatGenerator( + model="NousResearch/Llama-2-7b-chat-hf", + generation_kwargs={"n": 5}, + stop_words=["stop", "words"], + streaming_callback=streaming_callback_handler, + ) + # Call the to_dict method + result = generator.to_dict() + + generator_2 = HuggingFaceTGIChatGenerator.from_dict(result) + assert generator_2.model == "NousResearch/Llama-2-7b-chat-hf" + assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]} + assert generator_2.streaming_callback is streaming_callback_handler + + @pytest.mark.unit + def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer): + generator = HuggingFaceTGIChatGenerator() + generator.warm_up() + + # Assert that the tokenizer is now initialized + assert generator.tokenizer is not None + + @pytest.mark.unit + def test_warm_up_no_chat_template(self, mock_check_valid_model, mock_auto_tokenizer, caplog): + generator = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-13b-chat-hf") + + # Set chat_template to None for this specific test + mock_auto_tokenizer.chat_template = None + generator.warm_up() + + # warning message should be logged + assert "The model 'meta-llama/Llama-2-13b-chat-hf' doesn't have a default chat_template" in caplog.text + + @pytest.mark.unit + def test_custom_chat_template( + self, chat_messages, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation + ): + custom_chat_template = "Here goes some Jinja template" + + # mocked method to check if we called apply_chat_template with the custom template + mock_auto_tokenizer.apply_chat_template = MagicMock(return_value="some_value") + + generator = 
HuggingFaceTGIChatGenerator(chat_template=custom_chat_template) + generator.warm_up() + + assert generator.chat_template == custom_chat_template + + generator.run(messages=chat_messages) + assert mock_auto_tokenizer.apply_chat_template.call_count == 1 + + # and we indeed called apply_chat_template with the custom template + _, kwargs = mock_auto_tokenizer.apply_chat_template.call_args + assert kwargs["chat_template"] == custom_chat_template + + @pytest.mark.unit + def test_initialize_with_invalid_model_path_or_url(self, mock_check_valid_model): + model = "invalid_model" + generation_kwargs = {"n": 1} + stop_words = ["stop"] + streaming_callback = None + + mock_check_valid_model.side_effect = ValueError("Invalid model path or url") + + with pytest.raises(ValueError): + HuggingFaceTGIChatGenerator( + model=model, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + + @pytest.mark.unit + def test_initialize_with_invalid_url(self, mock_check_valid_model): + with pytest.raises(ValueError): + HuggingFaceTGIChatGenerator(model="NousResearch/Llama-2-7b-chat-hf", url="invalid_url") + + @pytest.mark.unit + def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model): + # When custom TGI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id + mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") + with pytest.raises(RepositoryNotFoundError): + HuggingFaceTGIChatGenerator(model="invalid_model_id", url="https://some_chat_model.com") + + @pytest.mark.unit + def test_generate_text_response_with_valid_prompt_and_generation_parameters( + self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages + ): + model = "meta-llama/Llama-2-13b-chat-hf" + generation_kwargs = {"n": 1} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceTGIChatGenerator( + model=model, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + generator.warm_up() + + response = generator.run(messages=chat_messages) + + # check kwargs passed to text_generation + # note how n because it is not text generation parameter was not passed to text_generation + _, kwargs = mock_text_generation.call_args + assert kwargs == {"details": True, "stop_sequences": ["stop"]} + + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.unit + def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters( + self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages + ): + model = "meta-llama/Llama-2-13b-chat-hf" + token = None + generation_kwargs = {"n": 3} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceTGIChatGenerator( + model=model, + token=token, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + generator.warm_up() + + response = generator.run(chat_messages) + + # check kwargs passed to text_generation + _, kwargs = mock_text_generation.call_args + assert kwargs == {"details": True, "stop_sequences": ["stop"]} + + # note how n caused n replies to be generated + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert 
len(response["replies"]) == 3 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.unit + def test_generate_text_with_stop_words( + self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages + ): + generator = HuggingFaceTGIChatGenerator() + generator.warm_up() + + stop_words = ["stop", "words"] + + # Generate text response with stop words + response = generator.run(chat_messages, generation_kwargs={"stop_words": stop_words}) + + # check kwargs passed to text_generation + # we translate stop_words to stop_sequences + _, kwargs = mock_text_generation.call_args + assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]} + + # Assert that the response contains the generated replies + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.unit + def test_generate_text_with_custom_generation_parameters( + self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages + ): + # Create an instance of HuggingFaceRemoteGenerator with no generation parameters + generator = HuggingFaceTGIChatGenerator() + generator.warm_up() + + # but then we pass them in run + generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100} + response = generator.run(chat_messages, generation_kwargs=generation_kwargs) + + # again check kwargs passed to text_generation + _, kwargs = mock_text_generation.call_args + assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8} + + # Assert that the response contains the generated replies and the right response + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert response["replies"][0].content == "I'm fine, thanks." 
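+
+    @pytest.mark.unit
+    def test_run_fails_without_warm_up(self, mock_check_valid_model, chat_messages):
+        # Illustrative sketch of an extra check: run() depends on the tokenizer loaded in warm_up(),
+        # so invoking it on a generator that was never warmed up is expected to raise RuntimeError.
+        generator = HuggingFaceTGIChatGenerator()
+
+        with pytest.raises(RuntimeError):
+            generator.run(messages=chat_messages)
+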
+ + @pytest.mark.unit + def test_generate_text_with_streaming_callback( + self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages + ): + streaming_call_count = 0 + + # Define the streaming callback function + def streaming_callback_fn(chunk: StreamingChunk): + nonlocal streaming_call_count + streaming_call_count += 1 + assert isinstance(chunk, StreamingChunk) + + # Create an instance of HuggingFaceRemoteGenerator + generator = HuggingFaceTGIChatGenerator(streaming_callback=streaming_callback_fn) + generator.warm_up() + + # Create a fake streamed response + # self needed here, don't remove + def mock_iter(self): + yield TextGenerationStreamResponse( + generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False) + ) + yield TextGenerationStreamResponse( + generated_text=None, + token=Token(id=1, text="Ok bye", logprob=0.0, special=False), + details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5), + ) + + mock_response = Mock(**{"__iter__": mock_iter}) + mock_text_generation.return_value = mock_response + + # Generate text response with streaming callback + response = generator.run(chat_messages) + + # check kwargs passed to text_generation + _, kwargs = mock_text_generation.call_args + assert kwargs == {"details": True, "stop_sequences": [], "stream": True} + + # Assert that the streaming callback was called twice + assert streaming_call_count == 2 + + # Assert that the response contains the generated replies + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] diff --git a/test/preview/components/generators/conftest.py b/test/preview/components/generators/conftest.py new file mode 100644 index 000000000..435b36ea0 --- /dev/null +++ b/test/preview/components/generators/conftest.py @@ -0,0 +1,21 @@ +from unittest.mock import patch, MagicMock + +import pytest + + +@pytest.fixture +def mock_auto_tokenizer(): + """ + In the original mock_auto_tokenizer fixture, we were mocking the transformers.AutoTokenizer.from_pretrained + method directly, but we were not providing a return value for this method. Therefore, when from_pretrained + was called within HuggingFaceTGIChatGenerator, it returned None because that's the default behavior of a + MagicMock object when a return value isn't specified. + + We will update the mock_auto_tokenizer fixture to return a MagicMock object when from_pretrained is called + in another PR. For now, we will use this fixture to mock the AutoTokenizer.from_pretrained method. 
+ """ + + with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained: + mock_tokenizer = MagicMock() + mock_from_pretrained.return_value = mock_tokenizer + yield mock_tokenizer diff --git a/test/preview/dataclasses/test_chat_message.py b/test/preview/dataclasses/test_chat_message.py index 285d38453..1c0ec71cf 100644 --- a/test/preview/dataclasses/test_chat_message.py +++ b/test/preview/dataclasses/test_chat_message.py @@ -1,4 +1,5 @@ import pytest +from transformers import AutoTokenizer from haystack.preview.dataclasses import ChatMessage, ChatRole @@ -29,19 +30,41 @@ def test_from_system_with_valid_content(): @pytest.mark.unit def test_with_empty_content(): - message = ChatMessage("", ChatRole.USER, None) + message = ChatMessage.from_user("") assert message.content == "" -@pytest.mark.unit -def test_with_invalid_role(): - with pytest.raises(TypeError): - ChatMessage("Invalid role", "invalid_role") - - @pytest.mark.unit def test_from_function_with_empty_name(): content = "Function call" message = ChatMessage.from_function(content, "") assert message.content == content assert message.name == "" + + +@pytest.mark.integration +def test_apply_chat_templating_on_chat_message(): + messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")] + tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") + tokenized_messages = tokenizer.apply_chat_template(messages, tokenize=False) + assert tokenized_messages == "<|system|>\nYou are good assistant\n<|user|>\nI have a question\n" + + +@pytest.mark.integration +def test_apply_custom_chat_templating_on_chat_message(): + anthropic_template = ( + "{%- for message in messages %}" + "{%- if message.role == 'user' %}\n\nHuman: {{ message.content.strip() }}" + "{%- elif message.role == 'assistant' %}\n\nAssistant: {{ message.content.strip() }}" + "{%- elif message.role == 'function' %}{{ raise('anthropic does not support function calls.') }}" + "{%- elif message.role == 'system' and loop.index == 1 %}{{ message.content }}" + "{%- else %}{{ raise('Invalid message role: ' + message.role) }}" + "{%- endif %}" + "{%- endfor %}" + "\n\nAssistant:" + ) + messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")] + # could be any tokenizer, let's use the one we already likely have in cache + tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") + tokenized_messages = tokenizer.apply_chat_template(messages, chat_template=anthropic_template, tokenize=False) + assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:"
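
Putting the pieces above together, here is a minimal end-to-end sketch of how the new component is meant to be used against the rate-limited Inference API. The token value is a placeholder, `print_chunk` is a helper defined only for this example, and it assumes the preview `StreamingChunk` dataclass exposes the streamed token text as `content`; everything else follows the API introduced in this PR.

```python
from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.preview.dataclasses import ChatMessage, StreamingChunk


def print_chunk(chunk: StreamingChunk):
    # Print each streamed token as it arrives; assumes the token text is in `chunk.content`.
    print(chunk.content, end="", flush=True)


messages = [
    ChatMessage.from_system("You are a helpful, respectful and honest assistant"),
    ChatMessage.from_user("What's Natural Language Processing?"),
]

generator = HuggingFaceTGIChatGenerator(
    model="meta-llama/Llama-2-70b-chat-hf",     # any chat model served via TGI
    token="<your-hf-api-token>",                # placeholder, not a real token
    generation_kwargs={"max_new_tokens": 120},
    streaming_callback=print_chunk,             # streaming requires n=1, which is the default
)
generator.warm_up()  # loads the tokenizer so the model's chat template can be applied

response = generator.run(messages)
print(response["replies"][0].content)
```

The same sketch works against a paid Inference Endpoint or a self-hosted TGI server by additionally passing `url=...`, as the class docstring shows.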