feat: Add HuggingFaceLocalChatGenerator (#6751)

Vladimir Blagojevic 2024-01-18 15:53:12 +01:00 committed by GitHub
parent 8d65a8630b
commit fea1428e84
5 changed files with 585 additions and 3 deletions

View File: haystack/components/generators/chat/__init__.py

@@ -1,5 +1,13 @@
from haystack.components.generators.chat.hugging_face_local import HuggingFaceLocalChatGenerator
from haystack.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator
from haystack.components.generators.chat.openai import OpenAIChatGenerator, GPTChatGenerator
from haystack.components.generators.chat.azure import AzureOpenAIChatGenerator
__all__ = ["HuggingFaceTGIChatGenerator", "OpenAIChatGenerator", "GPTChatGenerator", "AzureOpenAIChatGenerator"]
__all__ = [
"HuggingFaceLocalChatGenerator",
"HuggingFaceTGIChatGenerator",
"OpenAIChatGenerator",
"GPTChatGenerator",
"AzureOpenAIChatGenerator",
]

View File: haystack/components/generators/chat/hugging_face_local.py

@@ -0,0 +1,327 @@
import logging
import sys
from typing import Any, Dict, List, Literal, Optional, Union, Callable
from haystack.components.generators.hf_utils import PIPELINE_SUPPORTED_TASKS
from haystack import component, default_to_dict, default_from_dict
from haystack.components.generators.hf_utils import HFTokenStreamingHandler
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice
logger = logging.getLogger(__name__)
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
from huggingface_hub import model_info
from transformers import StoppingCriteriaList, pipeline, PreTrainedTokenizer, PreTrainedTokenizerFast
from haystack.components.generators.hf_utils import StopWordsCriteria # pylint: disable=ungrouped-imports
from haystack.utils.hf import serialize_hf_model_kwargs, deserialize_hf_model_kwargs
@component
class HuggingFaceLocalChatGenerator:
"""
The `HuggingFaceLocalChatGenerator` component generates chat responses using models from
Hugging Face's model hub. It is tailored for local runtime text generation and provides a convenient
interface for chat-based models such as `HuggingFaceH4/zephyr-7b-beta` or `meta-llama/Llama-2-7b-chat-hf`.
Usage example:
```python
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage
generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")
generator.warm_up()
messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]
print(generator.run(messages))
# {'replies': [ChatMessage(content=' Natural Language Processing (NLP) is a subfield of artificial
intelligence that deals with the interaction between computers and human language. It enables computers
to understand, interpret, and generate human language in a valuable way. NLP involves various techniques
such as speech recognition, text analysis, sentiment analysis, and machine translation. The ultimate goal
is to make it easier for computers to process and derive meaning from human language, improving communication
between humans and machines.', role=<ChatRole.ASSISTANT: 'assistant'>, name=None,
meta={'finish_reason': 'stop', 'index': 0, 'model': 'HuggingFaceH4/zephyr-7b-beta',
'usage': {'completion_tokens': 90, 'prompt_tokens': 19, 'total_tokens': 109}})]}
```
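The component also supports streaming: pass a callback that receives each `StreamingChunk` as it
is produced. A minimal sketch, assuming you simply want to print tokens as they arrive (any
callable accepting a `StreamingChunk` works):
```python
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage

# print every streamed chunk immediately instead of waiting for the full reply
generator = HuggingFaceLocalChatGenerator(
    model="HuggingFaceH4/zephyr-7b-beta",
    streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
)
generator.warm_up()
generator.run([ChatMessage.from_user("What's Natural Language Processing? Be brief.")])
```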
"""
def __init__(
self,
model: str = "HuggingFaceH4/zephyr-7b-beta",
task: Optional[Literal["text-generation", "text2text-generation"]] = None,
device: Optional[ComponentDevice] = None,
token: Optional[Union[str, bool]] = None,
chat_template: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
:param model: The name or path of a Hugging Face model for text generation,
for example, mistralai/Mistral-7B-Instruct-v0.2, TheBloke/OpenHermes-2.5-Mistral-7B-16k-AWQ, etc.
The model must be a chat model that supports the ChatML messaging format.
If the model is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
:param task: The task for the Hugging Face pipeline.
Possible values are "text-generation" and "text2text-generation".
Generally, decoder-only models like GPT support "text-generation",
while encoder-decoder models like T5 support "text2text-generation".
If the task is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
If not specified, the component will attempt to infer the task from the model name,
calling the Hugging Face Hub API.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
:param token: The token to use as HTTP bearer authorization for remote files.
If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
:param chat_template: An optional Jinja template for formatting chat messages. Most well-supported
chat models ship their own chat template via their tokenizer, but some models do not. Use this
parameter for such models, or to override the model's default template with a custom one.
:param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`,...
See Hugging Face's documentation for more information:
- https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation
- https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig
- The only generation parameter set by default is `max_new_tokens`, which defaults to 512.
:param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
Hugging Face pipeline for text generation.
These keyword arguments provide fine-grained control over the Hugging Face pipeline.
In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
See Hugging Face's [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task)
for more information on the available kwargs.
In this dictionary, you can also include `model_kwargs` to specify the kwargs
for model initialization:
https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained
:param stop_words: A list of stop words. If any one of the stop words is generated, the generation is stopped.
If you provide this parameter, you should not specify the `stopping_criteria` in `generation_kwargs`.
For some chat models, the output includes both the new text and the original prompt.
In these cases, it's important to make sure your prompt has no stop words.
:param streaming_callback: An optional callable for handling streaming responses.
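For instance, a minimal sketch combining stop words with custom generation parameters (the model
name and values are illustrative; a custom `chat_template` Jinja string can be passed the same way):
```python
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator

generator = HuggingFaceLocalChatGenerator(
    model="HuggingFaceH4/zephyr-7b-beta",
    stop_words=["<|user|>"],  # generation halts once this string is produced
    generation_kwargs={"max_new_tokens": 256, "temperature": 0.7, "top_p": 0.9},
)
generator.warm_up()
```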
"""
torch_and_transformers_import.check()
huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
generation_kwargs = generation_kwargs or {}
# check if the huggingface_pipeline_kwargs contain the essential parameters
# otherwise, populate them with values from other init parameters
huggingface_pipeline_kwargs.setdefault("model", model)
huggingface_pipeline_kwargs.setdefault("token", token)
device = ComponentDevice.resolve_device(device)
device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)
# task identification and validation
if task is None:
if "task" in huggingface_pipeline_kwargs:
task = huggingface_pipeline_kwargs["task"]
elif isinstance(huggingface_pipeline_kwargs["model"], str):
task = model_info(
huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
).pipeline_tag
if task not in PIPELINE_SUPPORTED_TASKS:
raise ValueError(
f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}."
)
huggingface_pipeline_kwargs["task"] = task
# if not specified, set return_full_text to False for text-generation
# only generated text is returned (excluding prompt)
if task == "text-generation":
generation_kwargs.setdefault("return_full_text", False)
if stop_words and "stopping_criteria" in generation_kwargs:
raise ValueError(
"Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
"Please specify only one of them."
)
generation_kwargs.setdefault("max_new_tokens", 512)
generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
generation_kwargs["stop_sequences"].extend(stop_words or [])
self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
self.generation_kwargs = generation_kwargs
self.chat_template = chat_template
self.streaming_callback = streaming_callback
self.pipeline = None
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
if isinstance(self.huggingface_pipeline_kwargs["model"], str):
return {"model": self.huggingface_pipeline_kwargs["model"]}
return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
def warm_up(self):
if self.pipeline is None:
self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
serialization_dict = default_to_dict(
self,
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
)
huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
# we don't want to serialize valid tokens
if isinstance(huggingface_pipeline_kwargs["token"], str):
serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"].pop("token")
serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
return serialization_dict
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
"""
Deserialize this component from a dictionary.
"""
torch_and_transformers_import.check() # leave this, cls method
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
return default_from_dict(cls, data)
@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Invoke text generation inference based on the provided messages and generation parameters.
:param messages: A list of ChatMessage instances representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:return: A list containing the generated responses as ChatMessage instances.
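For example, a per-call override of the generation parameters (values are illustrative):
```python
# per-call kwargs are merged on top of the ones set in __init__
generator.run(messages, generation_kwargs={"max_new_tokens": 64, "temperature": 0.2})
```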
"""
if self.pipeline is None:
raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
tokenizer = self.pipeline.tokenizer
# Check and update generation parameters
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
# pipeline call doesn't support stop_sequences, so we need to pop it
stop_words = self._validate_stop_words(stop_words)
# Set up stop words criteria if stop words exist
stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
if stop_words_criteria:
generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
if self.streaming_callback:
num_responses = generation_kwargs.get("num_return_sequences", 1)
if num_responses > 1:
logger.warning(
"Streaming is enabled, but the number of responses is set to %d. "
"Streaming is only supported for single response generation. "
"Setting the number of responses to 1.",
num_responses,
)
generation_kwargs["num_return_sequences"] = 1
# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)
# Prepare the prompt for the model
prepared_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
)
# Avoid some unnecessary warnings in the generation pipeline call
generation_kwargs["pad_token_id"] = (
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
)
# Generate responses
output = self.pipeline(prepared_prompt, **generation_kwargs)
replies = [o.get("generated_text", "") for o in output]
# Remove stop words from replies if present
for stop_word in stop_words:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
# Create ChatMessage instances for each reply
chat_messages = [
self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs)
for r_index, reply in enumerate(replies)
]
return {"replies": chat_messages}
def create_message(
self,
text: str,
index: int,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prompt: str,
generation_kwargs: Dict[str, Any],
) -> ChatMessage:
"""
Create a ChatMessage instance from the provided text, populated with metadata.
:param text: The generated text.
:param index: The index of the generated text.
:param tokenizer: The tokenizer used for generation.
:param prompt: The prompt used for generation.
:param generation_kwargs: The generation parameters.
:return: A ChatMessage instance.
"""
completion_tokens = len(tokenizer.encode(text, add_special_tokens=False))
prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
total_tokens = prompt_token_count + completion_tokens
# not the most sophisticated finish_reason detection, improve later to match
# https://platform.openai.com/docs/guides/text-generation/chat-completions-response-format
finish_reason = (
"length" if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize) else "stop"
)
meta = {
"finish_reason": finish_reason,
"index": index,
"model": self.huggingface_pipeline_kwargs["model"],
"usage": {
"completion_tokens": completion_tokens,
"prompt_tokens": prompt_token_count,
"total_tokens": total_tokens,
},
}
return ChatMessage.from_assistant(text, meta=meta)
def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]:
"""
Validates the provided stop words.
:param stop_words: A list of stop words to validate.
:return: A sanitized list of stop words or None if validation fails.
"""
if stop_words and not all(isinstance(word, str) for word in stop_words):
logger.warning(
"Invalid stop words provided. Stop words must be specified as a list of strings. "
"Ignoring stop words: %s",
stop_words,
)
return None
# deduplicate stop words
stop_words = list(set(stop_words or []))
return stop_words

View File: haystack/components/generators/hf_utils.py

@@ -1,12 +1,15 @@
import inspect
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Callable
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient, HfApi
from huggingface_hub.utils import RepositoryNotFoundError
PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None):
"""
@@ -59,7 +62,9 @@ def check_valid_model(model_id: str, token: Optional[str]) -> None:
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
import torch
from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer
transformers_import.check()
class StopWordsCriteria(StoppingCriteria):
"""
@@ -107,3 +112,19 @@ with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
len_stop_id = stop_id.size(0)
result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
return result
class HFTokenStreamingHandler(TextStreamer):
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stream_handler: Callable[[StreamingChunk], None],
stop_words: Optional[List[str]] = None,
):
super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore
self.token_handler = stream_handler
self.stop_words = stop_words or []
def on_finalized_text(self, word: str, stream_end: bool = False):
word_to_send = word + "\n" if stream_end else word
if word_to_send.strip() not in self.stop_words:
self.token_handler(StreamingChunk(content=word_to_send))

View File

@@ -0,0 +1,19 @@
---
features:
- |
Introducing the HuggingFaceLocalChatGenerator, a new chat-based generator for chat models from the
Hugging Face (HF) model hub. Users can now run inference with chat models in a local runtime, using
familiar HF generation parameters and stop words, and optionally supplying a custom chat template for
message formatting. The component also supports streaming responses and runs on a variety of devices.
Here is an example of how to use the HuggingFaceLocalChatGenerator:
```python
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage
generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")
generator.warm_up()
messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]
print(generator.run(messages))
```
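The component also serializes like other Haystack components. A minimal sketch of the round trip
(the streaming callback is stored as an importable reference, and a string token is dropped from
the serialized data):
```python
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator

generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")
data = generator.to_dict()  # plain dict, suitable for storing in pipeline YAML
restored = HuggingFaceLocalChatGenerator.from_dict(data)
```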

View File

@@ -0,0 +1,207 @@
from unittest.mock import patch, Mock
import pytest
from transformers import PreTrainedTokenizer
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.utils import ComponentDevice
# used to test serialization of streaming_callback
def streaming_callback_handler(x):
return x
@pytest.fixture
def model_info_mock():
with patch(
"haystack.components.generators.chat.hugging_face_local.model_info",
new=Mock(return_value=Mock(pipeline_tag="text2text-generation")),
) as mock:
yield mock
@pytest.fixture
def mock_pipeline_tokenizer():
# Mocking the pipeline
mock_pipeline = Mock(return_value=[{"generated_text": "Berlin is cool"}])
# Mocking the tokenizer
mock_tokenizer = Mock(spec=PreTrainedTokenizer)
mock_tokenizer.encode.return_value = ["Berlin", "is", "cool"]
mock_pipeline.tokenizer = mock_tokenizer
return mock_pipeline
class TestHuggingFaceLocalChatGenerator:
def test_initialize_with_valid_model_and_generation_parameters(self, model_info_mock):
model = "HuggingFaceH4/zephyr-7b-alpha"
generation_kwargs = {"n": 1}
stop_words = ["stop"]
streaming_callback = None
generator = HuggingFaceLocalChatGenerator(
model=model,
generation_kwargs=generation_kwargs,
stop_words=stop_words,
streaming_callback=streaming_callback,
)
assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
assert generator.streaming_callback == streaming_callback
def test_init_custom_token(self):
generator = HuggingFaceLocalChatGenerator(
model="mistralai/Mistral-7B-Instruct-v0.2",
task="text2text-generation",
token="test-token",
device=ComponentDevice.from_str("cpu"),
)
assert generator.huggingface_pipeline_kwargs == {
"model": "mistralai/Mistral-7B-Instruct-v0.2",
"task": "text2text-generation",
"token": "test-token",
"device": "cpu",
}
def test_init_custom_device(self):
generator = HuggingFaceLocalChatGenerator(
model="mistralai/Mistral-7B-Instruct-v0.2",
task="text2text-generation",
device=ComponentDevice.from_str("cpu"),
)
assert generator.huggingface_pipeline_kwargs == {
"model": "mistralai/Mistral-7B-Instruct-v0.2",
"task": "text2text-generation",
"token": None,
"device": "cpu",
}
def test_init_task_parameter(self):
generator = HuggingFaceLocalChatGenerator(task="text2text-generation", device=ComponentDevice.from_str("cpu"))
assert generator.huggingface_pipeline_kwargs == {
"model": "HuggingFaceH4/zephyr-7b-beta",
"task": "text2text-generation",
"token": None,
"device": "cpu",
}
def test_init_task_in_huggingface_pipeline_kwargs(self):
generator = HuggingFaceLocalChatGenerator(
huggingface_pipeline_kwargs={"task": "text2text-generation"}, device=ComponentDevice.from_str("cpu")
)
assert generator.huggingface_pipeline_kwargs == {
"model": "HuggingFaceH4/zephyr-7b-beta",
"task": "text2text-generation",
"token": None,
"device": "cpu",
}
def test_init_task_inferred_from_model_name(self, model_info_mock):
generator = HuggingFaceLocalChatGenerator(
model="mistralai/Mistral-7B-Instruct-v0.2", device=ComponentDevice.from_str("cpu")
)
assert generator.huggingface_pipeline_kwargs == {
"model": "mistralai/Mistral-7B-Instruct-v0.2",
"task": "text2text-generation",
"token": None,
"device": "cpu",
}
def test_init_invalid_task(self):
with pytest.raises(ValueError, match="is not supported."):
HuggingFaceLocalChatGenerator(task="text-classification")
def test_to_dict(self, model_info_mock):
generator = HuggingFaceLocalChatGenerator(
model="NousResearch/Llama-2-7b-chat-hf",
token="token",
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=lambda x: x,
)
# Call the to_dict method
result = generator.to_dict()
init_params = result["init_parameters"]
# Assert that the init_params dictionary contains the expected keys and values
assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
assert "token" not in init_params["huggingface_pipeline_kwargs"]
assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
def test_from_dict(self, model_info_mock):
generator = HuggingFaceLocalChatGenerator(
model="NousResearch/Llama-2-7b-chat-hf",
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=streaming_callback_handler,
)
# Call the to_dict method
result = generator.to_dict()
generator_2 = HuggingFaceLocalChatGenerator.from_dict(result)
assert generator_2.generation_kwargs == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
assert generator_2.streaming_callback is streaming_callback_handler
@patch("haystack.components.generators.chat.hugging_face_local.pipeline")
def test_warm_up(self, pipeline_mock):
generator = HuggingFaceLocalChatGenerator(
model="mistralai/Mistral-7B-Instruct-v0.2",
task="text2text-generation",
device=ComponentDevice.from_str("cpu"),
)
pipeline_mock.assert_not_called()
generator.warm_up()
pipeline_mock.assert_called_once_with(
model="mistralai/Mistral-7B-Instruct-v0.2", task="text2text-generation", token=None, device="cpu"
)
def test_run(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
# Use the mocked pipeline from the fixture and simulate warm_up
generator.pipeline = mock_pipeline_tokenizer
results = generator.run(messages=chat_messages)
assert "replies" in results
assert isinstance(results["replies"][0], ChatMessage)
chat_message = results["replies"][0]
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.content == "Berlin is cool"
def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
# Use the mocked pipeline from the fixture and simulate warm_up
generator.pipeline = mock_pipeline_tokenizer
generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
results = generator.run(messages=chat_messages, generation_kwargs=generation_kwargs)
# check kwargs passed to the pipeline
_, kwargs = generator.pipeline.call_args
assert kwargs["max_new_tokens"] == 100
assert kwargs["temperature"] == 0.8
# replies are properly parsed and returned
assert "replies" in results
assert isinstance(results["replies"][0], ChatMessage)
chat_message = results["replies"][0]
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.content == "Berlin is cool"