diff --git a/haystack/components/generators/chat/__init__.py b/haystack/components/generators/chat/__init__.py index 028389a56..5400e158a 100644 --- a/haystack/components/generators/chat/__init__.py +++ b/haystack/components/generators/chat/__init__.py @@ -1,5 +1,13 @@ +from haystack.components.generators.chat.hugging_face_local import HuggingFaceLocalChatGenerator from haystack.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator from haystack.components.generators.chat.openai import OpenAIChatGenerator, GPTChatGenerator from haystack.components.generators.chat.azure import AzureOpenAIChatGenerator -__all__ = ["HuggingFaceTGIChatGenerator", "OpenAIChatGenerator", "GPTChatGenerator", "AzureOpenAIChatGenerator"] + +__all__ = [ + "HuggingFaceLocalChatGenerator", + "HuggingFaceTGIChatGenerator", + "OpenAIChatGenerator", + "GPTChatGenerator", + "AzureOpenAIChatGenerator", +] diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py new file mode 100644 index 000000000..4291e1988 --- /dev/null +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -0,0 +1,327 @@ +import logging +import sys +from typing import Any, Dict, List, Literal, Optional, Union, Callable + +from haystack.components.generators.hf_utils import PIPELINE_SUPPORTED_TASKS + +from haystack import component, default_to_dict, default_from_dict +from haystack.components.generators.hf_utils import HFTokenStreamingHandler +from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler +from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.lazy_imports import LazyImport +from haystack.utils import ComponentDevice + +logger = logging.getLogger(__name__) + +with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import: + from huggingface_hub import model_info + from transformers import StoppingCriteriaList, pipeline, PreTrainedTokenizer, PreTrainedTokenizerFast + from haystack.components.generators.hf_utils import StopWordsCriteria # pylint: disable=ungrouped-imports + from haystack.utils.hf import serialize_hf_model_kwargs, deserialize_hf_model_kwargs + + +@component +class HuggingFaceLocalChatGenerator: + """ + + The `HuggingFaceLocalChatGenerator` class is a component designed for generating chat responses using models from + Hugging Face's model hub. It is tailored for local runtime text generation tasks and provides a convenient interface + for working with chat-based models, such as `HuggingFaceH4/zephyr-7b-beta` or `meta-llama/Llama-2-7b-chat-hf` + etc. + + Usage example: + ```python + from haystack.components.generators.chat import HuggingFaceLocalChatGenerator + from haystack.dataclasses import ChatMessage + + generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta") + generator.warm_up() + messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")] + print(generator.run(messages)) + + # {'replies': [ChatMessage(content=' Natural Language Processing (NLP) is a subfield of artificial + intelligence that deals with the interaction between computers and human language. It enables computers + to understand, interpret, and generate human language in a valuable way. NLP involves various techniques + such as speech recognition, text analysis, sentiment analysis, and machine translation. 
The ultimate goal + is to make it easier for computers to process and derive meaning from human language, improving communication + between humans and machines.', role=<ChatRole.ASSISTANT: 'assistant'>, name=None, + meta={'finish_reason': 'stop', 'index': 0, 'model': 'HuggingFaceH4/zephyr-7b-beta', + 'usage': {'completion_tokens': 90, 'prompt_tokens': 19, 'total_tokens': 109}})]} + ``` + """ + + def __init__( + self, + model: str = "HuggingFaceH4/zephyr-7b-beta", + task: Optional[Literal["text-generation", "text2text-generation"]] = None, + device: Optional[ComponentDevice] = None, + token: Optional[Union[str, bool]] = None, + chat_template: Optional[str] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None, + stop_words: Optional[List[str]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): + """ + :param model: The name or path of a Hugging Face model for text generation, + for example, mistralai/Mistral-7B-Instruct-v0.2, TheBloke/OpenHermes-2.5-Mistral-7B-16k-AWQ, etc. + The model should be a chat model that supports the ChatML messaging + format. + If the model is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. + :param task: The task for the Hugging Face pipeline. + Possible values are "text-generation" and "text2text-generation". + Generally, decoder-only models like GPT support "text-generation", + while encoder-decoder models like T5 support "text2text-generation". + If the task is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. + If not specified, the component will attempt to infer the task from the model name, + calling the Hugging Face Hub API. + :param device: The device on which the model is loaded. If `None`, the default device is automatically + selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter. + :param token: The token to use as HTTP bearer authorization for remote files. + If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface). + If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. + :param chat_template: This optional parameter allows you to specify a Jinja template for formatting chat + messages. While high-quality and well-supported chat models typically include their own chat templates + accessible through their tokenizer, there are models that do not offer this feature. For such scenarios, + or if you wish to use a custom template instead of the model's default, you can use this parameter to + set your preferred chat template. + :param generation_kwargs: A dictionary containing keyword arguments to customize text generation. + Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`,... + See Hugging Face's documentation for more information: + - https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation + - https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig + - The only generation_kwargs we set by default is max_new_tokens, which is set to 512 tokens. + :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the + Hugging Face pipeline for text generation. + These keyword arguments provide fine-grained control over the Hugging Face pipeline.
+ In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters. + See Hugging Face's [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task) + for more information on the available kwargs. + In this dictionary, you can also include `model_kwargs` to specify the kwargs + for model initialization: + https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained + :param stop_words: A list of stop words. If any one of the stop words is generated, the generation is stopped. + If you provide this parameter, you should not specify the `stopping_criteria` in `generation_kwargs`. + For some chat models, the output includes both the new text and the original prompt. + In these cases, it's important to make sure your prompt has no stop words. + :param streaming_callback: An optional callable for handling streaming responses. + """ + torch_and_transformers_import.check() + + huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {} + generation_kwargs = generation_kwargs or {} + + # check if the huggingface_pipeline_kwargs contain the essential parameters + # otherwise, populate them with values from other init parameters + huggingface_pipeline_kwargs.setdefault("model", model) + huggingface_pipeline_kwargs.setdefault("token", token) + + device = ComponentDevice.resolve_device(device) + device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False) + + # task identification and validation + if task is None: + if "task" in huggingface_pipeline_kwargs: + task = huggingface_pipeline_kwargs["task"] + elif isinstance(huggingface_pipeline_kwargs["model"], str): + task = model_info( + huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"] + ).pipeline_tag + + if task not in PIPELINE_SUPPORTED_TASKS: + raise ValueError( + f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}." + ) + huggingface_pipeline_kwargs["task"] = task + + # if not specified, set return_full_text to False for text-generation + # only generated text is returned (excluding prompt) + if task == "text-generation": + generation_kwargs.setdefault("return_full_text", False) + + if stop_words and "stopping_criteria" in generation_kwargs: + raise ValueError( + "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. " + "Please specify only one of them." + ) + generation_kwargs.setdefault("max_new_tokens", 512) + generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", []) + generation_kwargs["stop_sequences"].extend(stop_words or []) + + self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs + self.generation_kwargs = generation_kwargs + self.chat_template = chat_template + self.streaming_callback = streaming_callback + self.pipeline = None + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + if isinstance(self.huggingface_pipeline_kwargs["model"], str): + return {"model": self.huggingface_pipeline_kwargs["model"]} + return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"} + + def warm_up(self): + if self.pipeline is None: + self.pipeline = pipeline(**self.huggingface_pipeline_kwargs) + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. 
+ """ + callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None + serialization_dict = default_to_dict( + self, + huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs, + generation_kwargs=self.generation_kwargs, + streaming_callback=callback_name, + ) + + huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"] + # we don't want to serialize valid tokens + if isinstance(huggingface_pipeline_kwargs["token"], str): + serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"].pop("token") + + serialize_hf_model_kwargs(huggingface_pipeline_kwargs) + return serialization_dict + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator": + """ + Deserialize this component from a dictionary. + """ + torch_and_transformers_import.check() # leave this, cls method + + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) + + huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {}) + deserialize_hf_model_kwargs(huggingface_pipeline_kwargs) + return default_from_dict(cls, data) + + @component.output_types(replies=List[ChatMessage]) + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Invoke text generation inference based on the provided messages and generation parameters. + + :param messages: A list of ChatMessage instances representing the input messages. + :param generation_kwargs: Additional keyword arguments for text generation. + :return: A list containing the generated responses as ChatMessage instances. + """ + if self.pipeline is None: + raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.") + + tokenizer = self.pipeline.tokenizer + + # Check and update generation parameters + generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", []) + # pipeline call doesn't support stop_sequences, so we need to pop it + stop_words = self._validate_stop_words(stop_words) + + # Set up stop words criteria if stop words exist + stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None + if stop_words_criteria: + generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria]) + + if self.streaming_callback: + num_responses = generation_kwargs.get("num_return_sequences", 1) + if num_responses > 1: + logger.warning( + "Streaming is enabled, but the number of responses is set to %d. " + "Streaming is only supported for single response generation. 
" + "Setting the number of responses to 1.", + num_responses, + ) + generation_kwargs["num_return_sequences"] = 1 + # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming + generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words) + + # Prepare the prompt for the model + prepared_prompt = tokenizer.apply_chat_template( + messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True + ) + + # Avoid some unnecessary warnings in the generation pipeline call + generation_kwargs["pad_token_id"] = ( + generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id + ) + + # Generate responses + output = self.pipeline(prepared_prompt, **generation_kwargs) + replies = [o.get("generated_text", "") for o in output] + + # Remove stop words from replies if present + for stop_word in stop_words: + replies = [reply.replace(stop_word, "").rstrip() for reply in replies] + + # Create ChatMessage instances for each reply + chat_messages = [ + self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs) + for r_index, reply in enumerate(replies) + ] + return {"replies": chat_messages} + + def create_message( + self, + text: str, + index: int, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + prompt: str, + generation_kwargs: Dict[str, Any], + ) -> ChatMessage: + """ + Create a ChatMessage instance from the provided text, populated with metadata. + + :param text: The generated text. + :param index: The index of the generated text. + :param tokenizer: The tokenizer used for generation. + :param prompt: The prompt used for generation. + :param generation_kwargs: The generation parameters. + :return: A ChatMessage instance. + """ + completion_tokens = len(tokenizer.encode(text, add_special_tokens=False)) + prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False)) + total_tokens = prompt_token_count + completion_tokens + + # not the most sophisticated finish_reason detection, improve later to match + # https://platform.openai.com/docs/guides/text-generation/chat-completions-response-format + finish_reason = ( + "length" if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize) else "stop" + ) + + meta = { + "finish_reason": finish_reason, + "index": index, + "model": self.huggingface_pipeline_kwargs["model"], + "usage": { + "completion_tokens": completion_tokens, + "prompt_tokens": prompt_token_count, + "total_tokens": total_tokens, + }, + } + + return ChatMessage.from_assistant(text, meta=meta) + + def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]: + """ + Validates the provided stop words. + + :param stop_words: A list of stop words to validate. + :return: A sanitized list of stop words or None if validation fails. + """ + if stop_words and not all(isinstance(word, str) for word in stop_words): + logger.warning( + "Invalid stop words provided. Stop words must be specified as a list of strings. 
" + "Ignoring stop words: %s", + stop_words, + ) + return None + + # deduplicate stop words + stop_words = list(set(stop_words or [])) + return stop_words diff --git a/haystack/components/generators/hf_utils.py b/haystack/components/generators/hf_utils.py index 93dd2d750..651ef0fc0 100644 --- a/haystack/components/generators/hf_utils.py +++ b/haystack/components/generators/hf_utils.py @@ -1,12 +1,15 @@ import inspect -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Callable +from haystack.dataclasses import StreamingChunk from haystack.lazy_imports import LazyImport with LazyImport(message="Run 'pip install transformers'") as transformers_import: from huggingface_hub import InferenceClient, HfApi from huggingface_hub.utils import RepositoryNotFoundError +PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"] + def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None): """ @@ -59,7 +62,9 @@ def check_valid_model(model_id: str, token: Optional[str]) -> None: with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import: import torch - from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast + from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer + + transformers_import.check() class StopWordsCriteria(StoppingCriteria): """ @@ -107,3 +112,19 @@ with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_tr len_stop_id = stop_id.size(0) result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id)) return result + + class HFTokenStreamingHandler(TextStreamer): + def __init__( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + stream_handler: Callable[[StreamingChunk], None], + stop_words: Optional[List[str]] = None, + ): + super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore + self.token_handler = stream_handler + self.stop_words = stop_words or [] + + def on_finalized_text(self, word: str, stream_end: bool = False): + word_to_send = word + "\n" if stream_end else word + if word_to_send.strip() not in self.stop_words: + self.token_handler(StreamingChunk(content=word_to_send)) diff --git a/releasenotes/notes/add-hugging-face-chat-local-5fe7a88e24fde11b.yaml b/releasenotes/notes/add-hugging-face-chat-local-5fe7a88e24fde11b.yaml new file mode 100644 index 000000000..472281442 --- /dev/null +++ b/releasenotes/notes/add-hugging-face-chat-local-5fe7a88e24fde11b.yaml @@ -0,0 +1,19 @@ +--- +features: + - | + Introducing the HuggingFaceLocalChatGenerator, a new chat-based generator designed for leveraging chat models from + Hugging Face's (HF) model hub. Users can now perform inference with chat-based models in a local runtime, utilizing + familiar HF generation parameters, stop words, and even employing custom chat templates for custom message formatting. + This component also supports streaming responses and is optimized for compatibility with a variety of devices. + + Here is an example of how to use the HuggingFaceLocalChatGenerator: + + ```python + from haystack.components.generators.chat import HuggingFaceLocalChatGenerator + from haystack.dataclasses import ChatMessage + + generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta") + generator.warm_up() + messages = [ChatMessage.from_user("What's Natural Language Processing? 
Be brief.")] + print(generator.run(messages)) + ``` diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py new file mode 100644 index 000000000..d69fc60b7 --- /dev/null +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -0,0 +1,207 @@ +from unittest.mock import patch, Mock + +import pytest +from transformers import PreTrainedTokenizer + +from haystack.components.generators.chat import HuggingFaceLocalChatGenerator +from haystack.dataclasses import ChatMessage, ChatRole +from haystack.utils import ComponentDevice + + +# used to test serialization of streaming_callback +def streaming_callback_handler(x): + return x + + +@pytest.fixture +def model_info_mock(): + with patch( + "haystack.components.generators.chat.hugging_face_local.model_info", + new=Mock(return_value=Mock(pipeline_tag="text2text-generation")), + ) as mock: + yield mock + + +@pytest.fixture +def mock_pipeline_tokenizer(): + # Mocking the pipeline + mock_pipeline = Mock(return_value=[{"generated_text": "Berlin is cool"}]) + + # Mocking the tokenizer + mock_tokenizer = Mock(spec=PreTrainedTokenizer) + mock_tokenizer.encode.return_value = ["Berlin", "is", "cool"] + mock_pipeline.tokenizer = mock_tokenizer + + return mock_pipeline + + +class TestHuggingFaceLocalChatGenerator: + def test_initialize_with_valid_model_and_generation_parameters(self, model_info_mock): + model = "HuggingFaceH4/zephyr-7b-alpha" + generation_kwargs = {"n": 1} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceLocalChatGenerator( + model=model, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + + assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}} + assert generator.streaming_callback == streaming_callback + + def test_init_custom_token(self): + generator = HuggingFaceLocalChatGenerator( + model="mistralai/Mistral-7B-Instruct-v0.2", + task="text2text-generation", + token="test-token", + device=ComponentDevice.from_str("cpu"), + ) + + assert generator.huggingface_pipeline_kwargs == { + "model": "mistralai/Mistral-7B-Instruct-v0.2", + "task": "text2text-generation", + "token": "test-token", + "device": "cpu", + } + + def test_init_custom_device(self): + generator = HuggingFaceLocalChatGenerator( + model="mistralai/Mistral-7B-Instruct-v0.2", + task="text2text-generation", + device=ComponentDevice.from_str("cpu"), + ) + + assert generator.huggingface_pipeline_kwargs == { + "model": "mistralai/Mistral-7B-Instruct-v0.2", + "task": "text2text-generation", + "token": None, + "device": "cpu", + } + + def test_init_task_parameter(self): + generator = HuggingFaceLocalChatGenerator(task="text2text-generation", device=ComponentDevice.from_str("cpu")) + + assert generator.huggingface_pipeline_kwargs == { + "model": "HuggingFaceH4/zephyr-7b-beta", + "task": "text2text-generation", + "token": None, + "device": "cpu", + } + + def test_init_task_in_huggingface_pipeline_kwargs(self): + generator = HuggingFaceLocalChatGenerator( + huggingface_pipeline_kwargs={"task": "text2text-generation"}, device=ComponentDevice.from_str("cpu") + ) + + assert generator.huggingface_pipeline_kwargs == { + "model": "HuggingFaceH4/zephyr-7b-beta", + "task": "text2text-generation", + "token": None, + "device": "cpu", + } + + def test_init_task_inferred_from_model_name(self, model_info_mock): + generator = HuggingFaceLocalChatGenerator( + 
model="mistralai/Mistral-7B-Instruct-v0.2", device=ComponentDevice.from_str("cpu") + ) + + assert generator.huggingface_pipeline_kwargs == { + "model": "mistralai/Mistral-7B-Instruct-v0.2", + "task": "text2text-generation", + "token": None, + "device": "cpu", + } + + def test_init_invalid_task(self): + with pytest.raises(ValueError, match="is not supported."): + HuggingFaceLocalChatGenerator(task="text-classification") + + def test_to_dict(self, model_info_mock): + generator = HuggingFaceLocalChatGenerator( + model="NousResearch/Llama-2-7b-chat-hf", + token="token", + generation_kwargs={"n": 5}, + stop_words=["stop", "words"], + streaming_callback=lambda x: x, + ) + + # Call the to_dict method + result = generator.to_dict() + init_params = result["init_parameters"] + + # Assert that the init_params dictionary contains the expected keys and values + assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf" + assert "token" not in init_params["huggingface_pipeline_kwargs"] + assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]} + + def test_from_dict(self, model_info_mock): + generator = HuggingFaceLocalChatGenerator( + model="NousResearch/Llama-2-7b-chat-hf", + generation_kwargs={"n": 5}, + stop_words=["stop", "words"], + streaming_callback=streaming_callback_handler, + ) + # Call the to_dict method + result = generator.to_dict() + + generator_2 = HuggingFaceLocalChatGenerator.from_dict(result) + + assert generator_2.generation_kwargs == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]} + assert generator_2.streaming_callback is streaming_callback_handler + + @patch("haystack.components.generators.chat.hugging_face_local.pipeline") + def test_warm_up(self, pipeline_mock): + generator = HuggingFaceLocalChatGenerator( + model="mistralai/Mistral-7B-Instruct-v0.2", + task="text2text-generation", + device=ComponentDevice.from_str("cpu"), + ) + + pipeline_mock.assert_not_called() + + generator.warm_up() + + pipeline_mock.assert_called_once_with( + model="mistralai/Mistral-7B-Instruct-v0.2", task="text2text-generation", token=None, device="cpu" + ) + + def test_run(self, model_info_mock, mock_pipeline_tokenizer, chat_messages): + generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf") + + # Use the mocked pipeline from the fixture and simulate warm_up + generator.pipeline = mock_pipeline_tokenizer + + results = generator.run(messages=chat_messages) + + assert "replies" in results + assert isinstance(results["replies"][0], ChatMessage) + chat_message = results["replies"][0] + assert chat_message.is_from(ChatRole.ASSISTANT) + assert chat_message.content == "Berlin is cool" + + def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipeline_tokenizer, chat_messages): + generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf") + + # Use the mocked pipeline from the fixture and simulate warm_up + generator.pipeline = mock_pipeline_tokenizer + + generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100} + + # Use the mocked pipeline from the fixture and simulate warm_up + generator.pipeline = mock_pipeline_tokenizer + results = generator.run(messages=chat_messages, generation_kwargs=generation_kwargs) + + # check kwargs passed pipeline + _, kwargs = generator.pipeline.call_args + assert kwargs["max_new_tokens"] == 100 + assert kwargs["temperature"] == 0.8 + + # replies are properly parsed and returned + assert 
"replies" in results + assert isinstance(results["replies"][0], ChatMessage) + chat_message = results["replies"][0] + assert chat_message.is_from(ChatRole.ASSISTANT) + assert chat_message.content == "Berlin is cool"