feat: Add HuggingFaceLocalChatGenerator (#6751)
commit fea1428e84 · parent 8d65a8630b
haystack/components/generators/chat/__init__.py
@@ -1,5 +1,13 @@
+from haystack.components.generators.chat.hugging_face_local import HuggingFaceLocalChatGenerator
 from haystack.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator
 from haystack.components.generators.chat.openai import OpenAIChatGenerator, GPTChatGenerator
 from haystack.components.generators.chat.azure import AzureOpenAIChatGenerator

-__all__ = ["HuggingFaceTGIChatGenerator", "OpenAIChatGenerator", "GPTChatGenerator", "AzureOpenAIChatGenerator"]
+__all__ = [
+    "HuggingFaceLocalChatGenerator",
+    "HuggingFaceTGIChatGenerator",
+    "OpenAIChatGenerator",
+    "GPTChatGenerator",
+    "AzureOpenAIChatGenerator",
+]
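As a quick sanity check of the expanded public API, every chat generator is now importable from the same subpackage (a trivial sketch, not part of the commit):

```python
# Everything exported via __all__ above is importable from one place.
from haystack.components.generators.chat import (
    AzureOpenAIChatGenerator,
    GPTChatGenerator,
    HuggingFaceLocalChatGenerator,
    HuggingFaceTGIChatGenerator,
    OpenAIChatGenerator,
)
```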
haystack/components/generators/chat/hugging_face_local.py (new file, 327 lines)
@@ -0,0 +1,327 @@
import logging
import sys
from typing import Any, Dict, List, Literal, Optional, Union, Callable

from haystack.components.generators.hf_utils import PIPELINE_SUPPORTED_TASKS

from haystack import component, default_to_dict, default_from_dict
from haystack.components.generators.hf_utils import HFTokenStreamingHandler
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
    from huggingface_hub import model_info
    from transformers import StoppingCriteriaList, pipeline, PreTrainedTokenizer, PreTrainedTokenizerFast
    from haystack.components.generators.hf_utils import StopWordsCriteria  # pylint: disable=ungrouped-imports
    from haystack.utils.hf import serialize_hf_model_kwargs, deserialize_hf_model_kwargs


@component
class HuggingFaceLocalChatGenerator:
    """
    The `HuggingFaceLocalChatGenerator` class is a component designed for generating chat responses using models from
    Hugging Face's model hub. It is tailored for local runtime text generation tasks and provides a convenient
    interface for working with chat-based models, such as `HuggingFaceH4/zephyr-7b-beta` or
    `meta-llama/Llama-2-7b-chat-hf`.

    Usage example:
    ```python
    from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
    from haystack.dataclasses import ChatMessage

    generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")
    generator.warm_up()
    messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]
    print(generator.run(messages))

    # {'replies': [ChatMessage(content=' Natural Language Processing (NLP) is a subfield of artificial
    # intelligence that deals with the interaction between computers and human language. It enables computers
    # to understand, interpret, and generate human language in a valuable way. NLP involves various techniques
    # such as speech recognition, text analysis, sentiment analysis, and machine translation. The ultimate goal
    # is to make it easier for computers to process and derive meaning from human language, improving communication
    # between humans and machines.', role=<ChatRole.ASSISTANT: 'assistant'>, name=None,
    # meta={'finish_reason': 'stop', 'index': 0, 'model': 'mistralai/Mistral-7B-Instruct-v0.2',
    # 'usage': {'completion_tokens': 90, 'prompt_tokens': 19, 'total_tokens': 109}})]}
    ```
    """
    def __init__(
        self,
        model: str = "HuggingFaceH4/zephyr-7b-beta",
        task: Optional[Literal["text-generation", "text2text-generation"]] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Union[str, bool]] = None,
        chat_template: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        :param model: The name or path of a Hugging Face model for text generation,
            for example, mistralai/Mistral-7B-Instruct-v0.2, TheBloke/OpenHermes-2.5-Mistral-7B-16k-AWQ, etc.
            The model must be a chat model that supports the ChatML messaging format.
            If the model is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
        :param task: The task for the Hugging Face pipeline.
            Possible values are "text-generation" and "text2text-generation".
            Generally, decoder-only models like GPT support "text-generation",
            while encoder-decoder models like T5 support "text2text-generation".
            If the task is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
            If not specified, the component will attempt to infer the task from the model name,
            calling the Hugging Face Hub API.
        :param device: The device on which the model is loaded. If `None`, the default device is automatically
            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this
            parameter.
        :param token: The token to use as HTTP bearer authorization for remote files.
            If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
            If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
        :param chat_template: This optional parameter allows you to specify a Jinja template for formatting chat
            messages. While high-quality and well-supported chat models typically include their own chat templates
            accessible through their tokenizer, there are models that do not offer this feature. For such scenarios,
            or if you wish to use a custom template instead of the model's default, you can use this parameter to
            set your preferred chat template.
        :param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
            Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`, etc.
            See Hugging Face's documentation for more information:
            - https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation
            - https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig
            The only generation parameter set by default is `max_new_tokens`, which is set to 512 tokens.
        :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
            Hugging Face pipeline for text generation.
            These keyword arguments provide fine-grained control over the Hugging Face pipeline.
            In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
            See Hugging Face's [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task)
            for more information on the available kwargs.
            In this dictionary, you can also include `model_kwargs` to specify the kwargs
            for model initialization:
            https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained
        :param stop_words: A list of stop words. If any one of the stop words is generated, the generation is
            stopped. If you provide this parameter, you should not specify the `stopping_criteria` in
            `generation_kwargs`. For some chat models, the output includes both the new text and the original
            prompt. In these cases, it's important to make sure your prompt has no stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        """
        torch_and_transformers_import.check()

        huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
        generation_kwargs = generation_kwargs or {}

        # check if the huggingface_pipeline_kwargs contain the essential parameters
        # otherwise, populate them with values from other init parameters
        huggingface_pipeline_kwargs.setdefault("model", model)
        huggingface_pipeline_kwargs.setdefault("token", token)

        device = ComponentDevice.resolve_device(device)
        device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)

        # task identification and validation
        if task is None:
            if "task" in huggingface_pipeline_kwargs:
                task = huggingface_pipeline_kwargs["task"]
            elif isinstance(huggingface_pipeline_kwargs["model"], str):
                task = model_info(
                    huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
                ).pipeline_tag

        if task not in PIPELINE_SUPPORTED_TASKS:
            raise ValueError(
                f"Task '{task}' is not supported. The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}."
            )
        huggingface_pipeline_kwargs["task"] = task

        # if not specified, set return_full_text to False for text-generation
        # only generated text is returned (excluding prompt)
        if task == "text-generation":
            generation_kwargs.setdefault("return_full_text", False)

        if stop_words and "stopping_criteria" in generation_kwargs:
            raise ValueError(
                "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
                "Please specify only one of them."
            )
        generation_kwargs.setdefault("max_new_tokens", 512)
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])

        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.generation_kwargs = generation_kwargs
        self.chat_template = chat_template
        self.streaming_callback = streaming_callback
        self.pipeline = None
    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self):
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
        serialization_dict = default_to_dict(
            self,
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        # we don't want to serialize valid tokens
        if isinstance(huggingface_pipeline_kwargs["token"], str):
            serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"].pop("token")

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
        """
        Deserialize this component from a dictionary.
        """
        torch_and_transformers_import.check()  # leave this, cls method

        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)

        huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
        deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return default_from_dict(cls, data)
    @component.output_types(replies=List[ChatMessage])
    def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Invoke text generation inference based on the provided messages and generation parameters.

        :param messages: A list of ChatMessage instances representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :return: A list containing the generated responses as ChatMessage instances.
        """
        if self.pipeline is None:
            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

        tokenizer = self.pipeline.tokenizer

        # Check and update generation parameters
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # pipeline call doesn't support stop_sequences, so we need to pop it
        stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
        stop_words = self._validate_stop_words(stop_words)

        # Set up stop words criteria if stop words exist
        stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
        if stop_words_criteria:
            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

        if self.streaming_callback:
            num_responses = generation_kwargs.get("num_return_sequences", 1)
            if num_responses > 1:
                logger.warning(
                    "Streaming is enabled, but the number of responses is set to %d. "
                    "Streaming is only supported for single response generation. "
                    "Setting the number of responses to 1.",
                    num_responses,
                )
                generation_kwargs["num_return_sequences"] = 1
            # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
            generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)

        # Prepare the prompt for the model
        prepared_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
        )

        # Avoid some unnecessary warnings in the generation pipeline call
        generation_kwargs["pad_token_id"] = (
            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
        )

        # Generate responses
        output = self.pipeline(prepared_prompt, **generation_kwargs)
        replies = [o.get("generated_text", "") for o in output]

        # Remove stop words from replies if present
        for stop_word in stop_words or []:  # stop_words may be None if validation failed
            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

        # Create ChatMessage instances for each reply
        chat_messages = [
            self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs)
            for r_index, reply in enumerate(replies)
        ]
        return {"replies": chat_messages}
    def create_message(
        self,
        text: str,
        index: int,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        prompt: str,
        generation_kwargs: Dict[str, Any],
    ) -> ChatMessage:
        """
        Create a ChatMessage instance from the provided text, populated with metadata.

        :param text: The generated text.
        :param index: The index of the generated text.
        :param tokenizer: The tokenizer used for generation.
        :param prompt: The prompt used for generation.
        :param generation_kwargs: The generation parameters.
        :return: A ChatMessage instance.
        """
        completion_tokens = len(tokenizer.encode(text, add_special_tokens=False))
        prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
        total_tokens = prompt_token_count + completion_tokens

        # not the most sophisticated finish_reason detection, improve later to match
        # https://platform.openai.com/docs/guides/text-generation/chat-completions-response-format
        finish_reason = (
            "length" if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize) else "stop"
        )

        meta = {
            "finish_reason": finish_reason,
            "index": index,
            "model": self.huggingface_pipeline_kwargs["model"],
            "usage": {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_token_count,
                "total_tokens": total_tokens,
            },
        }

        return ChatMessage.from_assistant(text, meta=meta)
    def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]:
        """
        Validates the provided stop words.

        :param stop_words: A list of stop words to validate.
        :return: A sanitized list of stop words or None if validation fails.
        """
        if stop_words and not all(isinstance(word, str) for word in stop_words):
            logger.warning(
                "Invalid stop words provided. Stop words must be specified as a list of strings. "
                "Ignoring stop words: %s",
                stop_words,
            )
            return None

        # deduplicate stop words
        return list(set(stop_words or []))
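To make the init-time and run-time parameters above concrete, here is a small usage sketch; the stop word and the generation values are illustrative, not part of the commit:

```python
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage

# Init-time defaults; stop_words are folded into generation_kwargs["stop_sequences"]
# and enforced via a StopWordsCriteria at generation time.
generator = HuggingFaceLocalChatGenerator(
    model="HuggingFaceH4/zephyr-7b-beta",
    generation_kwargs={"max_new_tokens": 256, "temperature": 0.7},
    stop_words=["<|assistant|>"],  # illustrative stop word
)
generator.warm_up()

messages = [ChatMessage.from_user("Name three NLP tasks. Be brief.")]
# Per-call kwargs are merged on top of the init-time defaults (see run()).
result = generator.run(messages, generation_kwargs={"max_new_tokens": 64})
print(result["replies"][0].content)
```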
haystack/components/generators/hf_utils.py
@@ -1,12 +1,15 @@
 import inspect
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, Callable

+from haystack.dataclasses import StreamingChunk
 from haystack.lazy_imports import LazyImport

 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import InferenceClient, HfApi
     from huggingface_hub.utils import RepositoryNotFoundError

+PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
+

 def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None):
     """
@@ -59,7 +62,9 @@ def check_valid_model(model_id: str, token: Optional[str]) -> None:
 with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
     import torch
-    from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast
+    from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer

+    transformers_import.check()

     class StopWordsCriteria(StoppingCriteria):
         """
@@ -107,3 +112,19 @@ with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
             len_stop_id = stop_id.size(0)
             result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
             return result

+    class HFTokenStreamingHandler(TextStreamer):
+        def __init__(
+            self,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            stream_handler: Callable[[StreamingChunk], None],
+            stop_words: Optional[List[str]] = None,
+        ):
+            super().__init__(tokenizer=tokenizer, skip_prompt=True)  # type: ignore
+            self.token_handler = stream_handler
+            self.stop_words = stop_words or []
+
+        def on_finalized_text(self, word: str, stream_end: bool = False):
+            word_to_send = word + "\n" if stream_end else word
+            if word_to_send.strip() not in self.stop_words:
+                self.token_handler(StreamingChunk(content=word_to_send))
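`HFTokenStreamingHandler` adapts Hugging Face's `TextStreamer` to Haystack's streaming callback protocol: each finalized token is forwarded as a `StreamingChunk`, with stop words filtered out. A minimal standalone sketch of the adapter in use (the tokenizer name and stop word are illustrative; assumes `transformers[torch]` is installed):

```python
from transformers import AutoTokenizer

from haystack.components.generators.hf_utils import HFTokenStreamingHandler
from haystack.dataclasses import StreamingChunk

def print_chunk(chunk: StreamingChunk):
    # In HuggingFaceLocalChatGenerator, this role is played by the user-supplied streaming_callback.
    print(chunk.content, end="", flush=True)

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
streamer = HFTokenStreamingHandler(tokenizer, print_chunk, stop_words=["</s>"])
# The chat generator wires this object into Hugging Face generation via:
#   generation_kwargs["streamer"] = streamer
```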
releasenotes/notes (new release note file, 19 lines)
@@ -0,0 +1,19 @@
---
features:
  - |
    Introducing the HuggingFaceLocalChatGenerator, a new chat-based generator designed for leveraging chat models
    from Hugging Face's (HF) model hub. Users can now perform inference with chat-based models in a local runtime,
    utilizing familiar HF generation parameters, stop words, and even employing custom chat templates for custom
    message formatting. This component also supports streaming responses and is optimized for compatibility with
    a variety of devices.

    Here is an example of how to use the HuggingFaceLocalChatGenerator:

    ```python
    from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
    from haystack.dataclasses import ChatMessage

    generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")
    generator.warm_up()
    messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]
    print(generator.run(messages))
    ```
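Since the release note also advertises streaming, here is a hedged sketch of wiring a streaming callback; the callback itself is user-defined, and `StreamingChunk.content` carries the decoded text, as shown in `hf_utils.py` above:

```python
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

def print_streaming_chunk(chunk: StreamingChunk):
    print(chunk.content, end="", flush=True)  # print tokens as they arrive

generator = HuggingFaceLocalChatGenerator(
    model="HuggingFaceH4/zephyr-7b-beta",
    streaming_callback=print_streaming_chunk,
)
generator.warm_up()
generator.run([ChatMessage.from_user("What's Natural Language Processing? Be brief.")])
```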
test/components/generators/chat/test_hugging_face_local.py (new file, 207 lines)
@@ -0,0 +1,207 @@
from unittest.mock import patch, Mock

import pytest
from transformers import PreTrainedTokenizer

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.utils import ComponentDevice


# used to test serialization of streaming_callback
def streaming_callback_handler(x):
    return x


@pytest.fixture
def model_info_mock():
    with patch(
        "haystack.components.generators.chat.hugging_face_local.model_info",
        new=Mock(return_value=Mock(pipeline_tag="text2text-generation")),
    ) as mock:
        yield mock


@pytest.fixture
def mock_pipeline_tokenizer():
    # Mocking the pipeline
    mock_pipeline = Mock(return_value=[{"generated_text": "Berlin is cool"}])

    # Mocking the tokenizer
    mock_tokenizer = Mock(spec=PreTrainedTokenizer)
    mock_tokenizer.encode.return_value = ["Berlin", "is", "cool"]
    mock_pipeline.tokenizer = mock_tokenizer

    return mock_pipeline
class TestHuggingFaceLocalChatGenerator:
    def test_initialize_with_valid_model_and_generation_parameters(self, model_info_mock):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceLocalChatGenerator(
            model=model,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
        assert generator.streaming_callback == streaming_callback

    def test_init_custom_token(self):
        generator = HuggingFaceLocalChatGenerator(
            model="mistralai/Mistral-7B-Instruct-v0.2",
            task="text2text-generation",
            token="test-token",
            device=ComponentDevice.from_str("cpu"),
        )

        assert generator.huggingface_pipeline_kwargs == {
            "model": "mistralai/Mistral-7B-Instruct-v0.2",
            "task": "text2text-generation",
            "token": "test-token",
            "device": "cpu",
        }

    def test_init_custom_device(self):
        generator = HuggingFaceLocalChatGenerator(
            model="mistralai/Mistral-7B-Instruct-v0.2",
            task="text2text-generation",
            device=ComponentDevice.from_str("cpu"),
        )

        assert generator.huggingface_pipeline_kwargs == {
            "model": "mistralai/Mistral-7B-Instruct-v0.2",
            "task": "text2text-generation",
            "token": None,
            "device": "cpu",
        }

    def test_init_task_parameter(self):
        generator = HuggingFaceLocalChatGenerator(task="text2text-generation", device=ComponentDevice.from_str("cpu"))

        assert generator.huggingface_pipeline_kwargs == {
            "model": "HuggingFaceH4/zephyr-7b-beta",
            "task": "text2text-generation",
            "token": None,
            "device": "cpu",
        }

    def test_init_task_in_huggingface_pipeline_kwargs(self):
        generator = HuggingFaceLocalChatGenerator(
            huggingface_pipeline_kwargs={"task": "text2text-generation"}, device=ComponentDevice.from_str("cpu")
        )

        assert generator.huggingface_pipeline_kwargs == {
            "model": "HuggingFaceH4/zephyr-7b-beta",
            "task": "text2text-generation",
            "token": None,
            "device": "cpu",
        }

    def test_init_task_inferred_from_model_name(self, model_info_mock):
        generator = HuggingFaceLocalChatGenerator(
            model="mistralai/Mistral-7B-Instruct-v0.2", device=ComponentDevice.from_str("cpu")
        )

        assert generator.huggingface_pipeline_kwargs == {
            "model": "mistralai/Mistral-7B-Instruct-v0.2",
            "task": "text2text-generation",
            "token": None,
            "device": "cpu",
        }

    def test_init_invalid_task(self):
        with pytest.raises(ValueError, match="is not supported."):
            HuggingFaceLocalChatGenerator(task="text-classification")
    def test_to_dict(self, model_info_mock):
        generator = HuggingFaceLocalChatGenerator(
            model="NousResearch/Llama-2-7b-chat-hf",
            token="token",
            generation_kwargs={"n": 5},
            stop_words=["stop", "words"],
            streaming_callback=lambda x: x,
        )

        # Call the to_dict method
        result = generator.to_dict()
        init_params = result["init_parameters"]

        # Assert that the init_params dictionary contains the expected keys and values
        assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
        assert "token" not in init_params["huggingface_pipeline_kwargs"]
        assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}

    def test_from_dict(self, model_info_mock):
        generator = HuggingFaceLocalChatGenerator(
            model="NousResearch/Llama-2-7b-chat-hf",
            generation_kwargs={"n": 5},
            stop_words=["stop", "words"],
            streaming_callback=streaming_callback_handler,
        )
        # Call the to_dict method
        result = generator.to_dict()

        generator_2 = HuggingFaceLocalChatGenerator.from_dict(result)

        assert generator_2.generation_kwargs == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
        assert generator_2.streaming_callback is streaming_callback_handler
    @patch("haystack.components.generators.chat.hugging_face_local.pipeline")
    def test_warm_up(self, pipeline_mock):
        generator = HuggingFaceLocalChatGenerator(
            model="mistralai/Mistral-7B-Instruct-v0.2",
            task="text2text-generation",
            device=ComponentDevice.from_str("cpu"),
        )

        pipeline_mock.assert_not_called()

        generator.warm_up()

        pipeline_mock.assert_called_once_with(
            model="mistralai/Mistral-7B-Instruct-v0.2", task="text2text-generation", token=None, device="cpu"
        )

    def test_run(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
        generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")

        # Use the mocked pipeline from the fixture and simulate warm_up
        generator.pipeline = mock_pipeline_tokenizer

        results = generator.run(messages=chat_messages)

        assert "replies" in results
        assert isinstance(results["replies"][0], ChatMessage)
        chat_message = results["replies"][0]
        assert chat_message.is_from(ChatRole.ASSISTANT)
        assert chat_message.content == "Berlin is cool"
    def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
        generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")

        # Use the mocked pipeline from the fixture and simulate warm_up
        generator.pipeline = mock_pipeline_tokenizer

        generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
        results = generator.run(messages=chat_messages, generation_kwargs=generation_kwargs)

        # check the kwargs passed to the pipeline
        _, kwargs = generator.pipeline.call_args
        assert kwargs["max_new_tokens"] == 100
        assert kwargs["temperature"] == 0.8

        # replies are properly parsed and returned
        assert "replies" in results
        assert isinstance(results["replies"][0], ChatMessage)
        chat_message = results["replies"][0]
        assert chat_message.is_from(ChatRole.ASSISTANT)
        assert chat_message.content == "Berlin is cool"
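The `test_run*` methods depend on a `chat_messages` fixture that is not visible in this diff; a hypothetical minimal version, for context only:

```python
import pytest
from haystack.dataclasses import ChatMessage

@pytest.fixture
def chat_messages():
    # Hypothetical fixture (not part of the visible diff):
    # any short user conversation satisfies the mocked pipeline.
    return [ChatMessage.from_user("Tell me about Berlin")]
```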