feat: Add HuggingFaceTGIChatGenerator Haystack 2.x component (#6199)
* Add HuggingFaceTGIChatGenerator
* Add release note

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
This commit is contained in:
parent 03015877f3
commit d7e1833c40
@@ -23,7 +23,7 @@ First, run the minimal Haystack installation:
pip install farm-haystack
```

Then, index your data to the DocumentStore, build a RAG pipeline, and ask a question on your data:

```python
from haystack.document_stores import InMemoryDocumentStore

3  haystack/preview/components/generators/chat/__init__.py  (new file)
@@ -0,0 +1,3 @@
from haystack.preview.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator

__all__ = ["HuggingFaceTGIChatGenerator"]
276  haystack/preview/components/generators/chat/hugging_face_tgi.py  (new file)
@@ -0,0 +1,276 @@
import logging
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Iterable, Callable
from urllib.parse import urlparse

from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
from transformers import AutoTokenizer

from haystack.preview import component, default_to_dict, default_from_dict
from haystack.preview.components.generators.hf_utils import check_valid_model, check_generation_params
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.preview.dataclasses import ChatMessage, StreamingChunk

logger = logging.getLogger(__name__)

@component
class HuggingFaceTGIChatGenerator:
    """
    Enables text generation using Hugging Face Hub hosted chat-based LLMs. This component is designed to run
    inference seamlessly against chat-based models deployed on the Text Generation Inference (TGI) backend.

    You can use this component with chat LLMs hosted on the rate-limited Hugging Face Inference API tier:

    ```python
    from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator
    from haystack.preview.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", token="<your-token>")
    client.warm_up()
    response = client.run(messages, generation_kwargs={"max_new_tokens": 120})
    print(response)
    ```

    For chat LLMs hosted on a paid Hugging Face Inference Endpoint (https://huggingface.co/inference-endpoints)
    and/or your own custom TGI endpoint, you'll need to provide the URL of the endpoint as well as a valid token:

    ```python
    from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator
    from haystack.preview.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf",
                                         url="<your-tgi-endpoint-url>",
                                         token="<your-token>")
    client.warm_up()
    response = client.run(messages, generation_kwargs={"max_new_tokens": 120})
    print(response)
    ```

    Key Features and Compatibility:
    - **Primary Compatibility**: Designed to work seamlessly with any chat-based model deployed using the TGI
      framework. For more information on TGI, visit https://github.com/huggingface/text-generation-inference.
    - **Hugging Face Inference Endpoints**: Supports inference of TGI chat LLMs deployed on Hugging Face
      inference endpoints. For more details, refer to https://huggingface.co/inference-endpoints.
    - **Inference API Support**: Supports inference of TGI chat LLMs hosted on the rate-limited Inference
      API tier. Learn more about the Inference API at https://huggingface.co/inference-api.
      Discover available chat models using the following command:
      ```
      wget -qO- https://api-inference.huggingface.co/framework/text-generation-inference | grep chat
      ```
      and simply use the model ID as the model parameter for this component. You'll also need to provide a valid
      Hugging Face API token as the token parameter.
    - **Custom TGI Endpoints**: Supports inference of TGI chat LLMs deployed on custom TGI endpoints. Anyone can
      deploy their own TGI endpoint using the TGI framework. For more details, refer
      to https://huggingface.co/inference-endpoints.

    Input and Output Format:
    - **ChatMessage Format**: This component uses the ChatMessage format to structure both input and output,
      ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on
      the ChatMessage format can be found at https://github.com/openai/openai-python/blob/main/chatml.md.
    """

    def __init__(
        self,
        model: str = "meta-llama/Llama-2-13b-chat-hf",
        url: Optional[str] = None,
        token: Optional[str] = None,
        chat_template: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Initialize the HuggingFaceTGIChatGenerator instance.

        :param model: A string representing the model path or URL. Default is "meta-llama/Llama-2-13b-chat-hf".
        :param url: An optional string representing the URL of the TGI endpoint.
        :param chat_template: This optional parameter allows you to specify a Jinja template for formatting chat
            messages. While high-quality and well-supported chat models typically include their own chat templates
            accessible through their tokenizer, there are models that do not offer this feature. For such scenarios,
            or if you wish to use a custom template instead of the model's default, you can use this parameter to
            set your preferred chat template.
        :param token: The Hugging Face token for HTTP bearer authorization.
            You can find your HF token at https://huggingface.co/settings/tokens.
        :param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
            Some examples: `max_new_tokens`, `temperature`, `top_k`, `top_p`,...
            See Hugging Face's [documentation](https://huggingface.co/docs/huggingface_hub/v0.18.0.rc0/en/package_reference/inference_client#huggingface_hub.inference._text_generation.TextGenerationParameters)
            for more information.
        :param stop_words: An optional list of strings representing the stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        """
        if url:
            r = urlparse(url)
            is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
            if not is_valid_url:
                raise ValueError(f"Invalid TGI endpoint URL provided: {url}")

        check_valid_model(model, token)

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        check_generation_params(generation_kwargs, ["n"])
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])

        self.model = model
        self.url = url
        self.chat_template = chat_template
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.client = InferenceClient(url or model, token=token)
        self.streaming_callback = streaming_callback
        self.tokenizer = None

    def warm_up(self) -> None:
        """
        Load the tokenizer.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token)
        # mypy can't infer that chat_template attribute exists on the object returned by AutoTokenizer.from_pretrained
        chat_template = getattr(self.tokenizer, "chat_template", None)
        if not chat_template and not self.chat_template:
            logger.warning(
                "The model '%s' doesn't have a default chat_template, and no chat_template was supplied during "
                "this component's initialization. It's possible that the model doesn't support ChatML inference "
                "format, potentially leading to unexpected behavior.",
                self.model,
            )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :return: A dictionary containing the serialized component.
        """
        callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            model=self.model,
            url=self.url,
            chat_template=self.chat_template,
            token=self.token if not isinstance(self.token, str) else None,  # don't serialize valid tokens
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTGIChatGenerator":
        """
        Deserialize this component from a dictionary.
        """
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
        return default_from_dict(cls, data)

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        # Don't send URL as it is sensitive information
        return {"model": self.model}

    @component.output_types(replies=List[ChatMessage])
    def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Invoke the text generation inference based on the provided messages and generation parameters.

        :param messages: A list of ChatMessage instances representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :return: A list containing the generated responses as ChatMessage instances.
        """

        # check generation kwargs given as parameters to override the default ones
        additional_params = ["n", "stop_words"]
        check_generation_params(generation_kwargs, additional_params)

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
        num_responses = generation_kwargs.pop("n", 1)

        # merge stop_words and stop_sequences into a single list
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(generation_kwargs.pop("stop_words", []))

        if self.tokenizer is None:
            raise RuntimeError("Please call warm_up() before running LLM inference.")

        # apply either model's chat template or the user-provided one
        prepared_prompt: str = self.tokenizer.apply_chat_template(
            conversation=messages, chat_template=self.chat_template, tokenize=False
        )
        prompt_token_count: int = len(self.tokenizer.encode(prepared_prompt, add_special_tokens=False))

        if self.streaming_callback:
            if num_responses > 1:
                raise ValueError("Cannot stream multiple responses, please set n=1.")

            return self._run_streaming(prepared_prompt, prompt_token_count, generation_kwargs)

        return self._run_non_streaming(prepared_prompt, prompt_token_count, num_responses, generation_kwargs)

    def _run_streaming(
        self, prepared_prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]
    ) -> Dict[str, List[ChatMessage]]:
        res: Iterable[TextGenerationStreamResponse] = self.client.text_generation(
            prepared_prompt, stream=True, details=True, **generation_kwargs
        )
        chunk = None
        # pylint: disable=not-an-iterable
        for chunk in res:
            token: Token = chunk.token
            if token.special:
                continue
            chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
            stream_chunk = StreamingChunk(token.text, chunk_metadata)
            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)

        message = ChatMessage.from_assistant(chunk.generated_text)
        message.metadata.update(
            {
                "finish_reason": chunk.details.finish_reason.value,
                "index": 0,
                "model": self.client.model,
                "usage": {
                    "completion_tokens": chunk.details.generated_tokens,
                    "prompt_tokens": prompt_token_count,
                    "total_tokens": prompt_token_count + chunk.details.generated_tokens,
                },
            }
        )
        return {"replies": [message]}

    def _run_non_streaming(
        self, prepared_prompt: str, prompt_token_count: int, num_responses: int, generation_kwargs: Dict[str, Any]
    ) -> Dict[str, List[ChatMessage]]:
        chat_messages: List[ChatMessage] = []
        for _i in range(num_responses):
            tgr: TextGenerationResponse = self.client.text_generation(
                prepared_prompt, details=True, **generation_kwargs
            )
            message = ChatMessage.from_assistant(tgr.generated_text)
            message.metadata.update(
                {
                    "finish_reason": tgr.details.finish_reason.value,
                    "index": _i,
                    "model": self.client.model,
                    "usage": {
                        "completion_tokens": len(tgr.details.tokens),
                        "prompt_tokens": prompt_token_count,
                        "total_tokens": prompt_token_count + len(tgr.details.tokens),
                    },
                }
            )
            chat_messages.append(message)
        return {"replies": chat_messages}
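The class docstring above only shows blocking calls; the sketch below (not part of the diff) exercises the streaming path handled by `_run_streaming`. It is a minimal, hedged example: it assumes network access to the rate-limited Inference API, uses a placeholder token, and assumes `StreamingChunk` exposes the token text via its first field, `content`, as constructed in `_run_streaming` above.

```python
from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.preview.dataclasses import ChatMessage, StreamingChunk


def print_chunk(chunk: StreamingChunk) -> None:
    # Each chunk wraps one generated token; special tokens are skipped upstream.
    print(chunk.content, end="", flush=True)


client = HuggingFaceTGIChatGenerator(
    model="meta-llama/Llama-2-70b-chat-hf",
    token="<your-token>",  # placeholder, see https://huggingface.co/settings/tokens
    streaming_callback=print_chunk,
)
client.warm_up()  # loads the tokenizer used to apply the chat template
result = client.run(
    [ChatMessage.from_user("What's Natural Language Processing?")],
    generation_kwargs={"max_new_tokens": 120},
)
print(result["replies"][0].content)
```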
@@ -0,0 +1,5 @@
---
preview:
  - |
    Adds `HuggingFaceTGIChatGenerator` for text and chat generation. This component supports remote inference for
    Hugging Face LLMs via the text-generation-inference (TGI) protocol.
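To complement the release note, here is a hedged sketch of the serialization round trip implemented by `to_dict`/`from_dict` above. The callback name is hypothetical; the sketch assumes the callback is a module-level function so `serialize_callback_handler` can record it by import path, and it assumes network access because `check_valid_model` runs on instantiation.

```python
from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.preview.dataclasses import StreamingChunk


def my_callback(chunk: StreamingChunk) -> None:  # hypothetical module-level handler
    print(chunk.content)


generator = HuggingFaceTGIChatGenerator(
    model="meta-llama/Llama-2-13b-chat-hf",
    token="<your-token>",  # placeholder; string tokens are dropped by to_dict
    generation_kwargs={"max_new_tokens": 120},
    streaming_callback=my_callback,
)

data = generator.to_dict()
assert data["init_parameters"]["token"] is None  # valid tokens are never serialized
restored = HuggingFaceTGIChatGenerator.from_dict(data)  # resolves the callback by name
assert restored.streaming_callback is my_callback
```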
0  test/preview/components/generators/chat/__init__.py  (new file)
11  test/preview/components/generators/chat/conftest.py  (new file)
@@ -0,0 +1,11 @@
import pytest

from haystack.preview.dataclasses import ChatMessage


@pytest.fixture
def chat_messages():
    return [
        ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"),
        ChatMessage.from_user("Tell me about Berlin"),
    ]
317  test/preview/components/generators/chat/test_hugging_face_tgi.py  (new file)
@@ -0,0 +1,317 @@
from unittest.mock import patch, MagicMock, Mock

import pytest
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
from huggingface_hub.utils import RepositoryNotFoundError

from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.preview.dataclasses import StreamingChunk, ChatMessage


@pytest.fixture
def mock_check_valid_model():
    with patch(
        "haystack.preview.components.generators.chat.hugging_face_tgi.check_valid_model", MagicMock(return_value=None)
    ) as mock:
        yield mock


@pytest.fixture
def mock_text_generation():
    with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation:
        mock_response = Mock()
        mock_response.generated_text = "I'm fine, thanks."
        details = Mock()
        details.finish_reason = MagicMock(field1="value")
        details.tokens = [1, 2, 3]
        mock_response.details = details
        mock_text_generation.return_value = mock_response
        yield mock_text_generation


# used to test serialization of streaming_callback
def streaming_callback_handler(x):
    return x

class TestHuggingFaceTGIChatGenerator:
    @pytest.mark.unit
    def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model, mock_auto_tokenizer):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceTGIChatGenerator(
            model=model,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )
        generator.warm_up()

        assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
        assert generator.tokenizer is not None
        assert generator.client is not None
        assert generator.streaming_callback == streaming_callback

    @pytest.mark.unit
    def test_to_dict(self, mock_check_valid_model):
        # Initialize the HuggingFaceTGIChatGenerator object with valid parameters
        generator = HuggingFaceTGIChatGenerator(
            model="NousResearch/Llama-2-7b-chat-hf",
            token="token",
            generation_kwargs={"n": 5},
            stop_words=["stop", "words"],
            streaming_callback=lambda x: x,
        )

        # Call the to_dict method
        result = generator.to_dict()
        init_params = result["init_parameters"]

        # Assert that the init_params dictionary contains the expected keys and values
        assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf"
        assert init_params["token"] is None
        assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}

    @pytest.mark.unit
    def test_from_dict(self, mock_check_valid_model):
        generator = HuggingFaceTGIChatGenerator(
            model="NousResearch/Llama-2-7b-chat-hf",
            generation_kwargs={"n": 5},
            stop_words=["stop", "words"],
            streaming_callback=streaming_callback_handler,
        )
        # Call the to_dict method
        result = generator.to_dict()

        generator_2 = HuggingFaceTGIChatGenerator.from_dict(result)
        assert generator_2.model == "NousResearch/Llama-2-7b-chat-hf"
        assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
        assert generator_2.streaming_callback is streaming_callback_handler

    @pytest.mark.unit
    def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer):
        generator = HuggingFaceTGIChatGenerator()
        generator.warm_up()

        # Assert that the tokenizer is now initialized
        assert generator.tokenizer is not None

    @pytest.mark.unit
    def test_warm_up_no_chat_template(self, mock_check_valid_model, mock_auto_tokenizer, caplog):
        generator = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")

        # Set chat_template to None for this specific test
        mock_auto_tokenizer.chat_template = None
        generator.warm_up()

        # warning message should be logged
        assert "The model 'meta-llama/Llama-2-13b-chat-hf' doesn't have a default chat_template" in caplog.text

    @pytest.mark.unit
    def test_custom_chat_template(
        self, chat_messages, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
    ):
        custom_chat_template = "Here goes some Jinja template"

        # mocked method to check if we called apply_chat_template with the custom template
        mock_auto_tokenizer.apply_chat_template = MagicMock(return_value="some_value")

        generator = HuggingFaceTGIChatGenerator(chat_template=custom_chat_template)
        generator.warm_up()

        assert generator.chat_template == custom_chat_template

        generator.run(messages=chat_messages)
        assert mock_auto_tokenizer.apply_chat_template.call_count == 1

        # and we indeed called apply_chat_template with the custom template
        _, kwargs = mock_auto_tokenizer.apply_chat_template.call_args
        assert kwargs["chat_template"] == custom_chat_template

    @pytest.mark.unit
    def test_initialize_with_invalid_model_path_or_url(self, mock_check_valid_model):
        model = "invalid_model"
        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
        streaming_callback = None

        mock_check_valid_model.side_effect = ValueError("Invalid model path or url")

        with pytest.raises(ValueError):
            HuggingFaceTGIChatGenerator(
                model=model,
                generation_kwargs=generation_kwargs,
                stop_words=stop_words,
                streaming_callback=streaming_callback,
            )

    @pytest.mark.unit
    def test_initialize_with_invalid_url(self, mock_check_valid_model):
        with pytest.raises(ValueError):
            HuggingFaceTGIChatGenerator(model="NousResearch/Llama-2-7b-chat-hf", url="invalid_url")

    @pytest.mark.unit
    def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
        # When a custom TGI endpoint is used via URL, the model must still be a valid Hugging Face Hub model id
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceTGIChatGenerator(model="invalid_model_id", url="https://some_chat_model.com")

    @pytest.mark.unit
    def test_generate_text_response_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
    ):
        model = "meta-llama/Llama-2-13b-chat-hf"
        generation_kwargs = {"n": 1}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceTGIChatGenerator(
            model=model,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )
        generator.warm_up()

        response = generator.run(messages=chat_messages)

        # check kwargs passed to text_generation
        # note that n is not a text generation parameter, so it was not passed on to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": ["stop"]}

        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.unit
    def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
    ):
        model = "meta-llama/Llama-2-13b-chat-hf"
        token = None
        generation_kwargs = {"n": 3}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceTGIChatGenerator(
            model=model,
            token=token,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )
        generator.warm_up()

        response = generator.run(chat_messages)

        # check kwargs passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": ["stop"]}

        # note how n caused n replies to be generated
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 3
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.unit
    def test_generate_text_with_stop_words(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
    ):
        generator = HuggingFaceTGIChatGenerator()
        generator.warm_up()

        stop_words = ["stop", "words"]

        # Generate text response with stop words
        response = generator.run(chat_messages, generation_kwargs={"stop_words": stop_words})

        # check kwargs passed to text_generation
        # we translate stop_words to stop_sequences
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.unit
    def test_generate_text_with_custom_generation_parameters(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
    ):
        # Create an instance of HuggingFaceTGIChatGenerator with no generation parameters
        generator = HuggingFaceTGIChatGenerator()
        generator.warm_up()

        # but then we pass them in run
        generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
        response = generator.run(chat_messages, generation_kwargs=generation_kwargs)

        # again check kwargs passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8}

        # Assert that the response contains the generated replies and the right response
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
        assert response["replies"][0].content == "I'm fine, thanks."

    @pytest.mark.unit
    def test_generate_text_with_streaming_callback(
        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
    ):
        streaming_call_count = 0

        # Define the streaming callback function
        def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        # Create an instance of HuggingFaceTGIChatGenerator
        generator = HuggingFaceTGIChatGenerator(streaming_callback=streaming_callback_fn)
        generator.warm_up()

        # Create a fake streamed response
        # self needed here, don't remove
        def mock_iter(self):
            yield TextGenerationStreamResponse(
                generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False)
            )
            yield TextGenerationStreamResponse(
                generated_text=None,
                token=Token(id=1, text="Ok bye", logprob=0.0, special=False),
                details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5),
            )

        mock_response = Mock(**{"__iter__": mock_iter})
        mock_text_generation.return_value = mock_response

        # Generate text response with streaming callback
        response = generator.run(chat_messages)

        # check kwargs passed to text_generation
        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": [], "stream": True}

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

21  test/preview/components/generators/conftest.py  (new file)
@@ -0,0 +1,21 @@
from unittest.mock import patch, MagicMock

import pytest


@pytest.fixture
def mock_auto_tokenizer():
    """
    In the original mock_auto_tokenizer fixture, we were mocking the transformers.AutoTokenizer.from_pretrained
    method directly, but we were not providing a return value for this method. Therefore, when from_pretrained
    was called within HuggingFaceTGIChatGenerator, it returned None because that's the default behavior of a
    MagicMock object when a return value isn't specified.

    We will update the mock_auto_tokenizer fixture to return a MagicMock object when from_pretrained is called
    in another PR. For now, we will use this fixture to mock the AutoTokenizer.from_pretrained method.
    """

    with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
        mock_tokenizer = MagicMock()
        mock_from_pretrained.return_value = mock_tokenizer
        yield mock_tokenizer
@@ -1,4 +1,5 @@
import pytest
from transformers import AutoTokenizer

from haystack.preview.dataclasses import ChatMessage, ChatRole

@@ -29,19 +30,41 @@ def test_from_system_with_valid_content():

@pytest.mark.unit
def test_with_empty_content():
    message = ChatMessage("", ChatRole.USER, None)
    message = ChatMessage.from_user("")
    assert message.content == ""


@pytest.mark.unit
def test_with_invalid_role():
    with pytest.raises(TypeError):
        ChatMessage("Invalid role", "invalid_role")


@pytest.mark.unit
def test_from_function_with_empty_name():
    content = "Function call"
    message = ChatMessage.from_function(content, "")
    assert message.content == content
    assert message.name == ""


@pytest.mark.integration
def test_apply_chat_templating_on_chat_message():
    messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    tokenized_messages = tokenizer.apply_chat_template(messages, tokenize=False)
    assert tokenized_messages == "<|system|>\nYou are good assistant</s>\n<|user|>\nI have a question</s>\n"


@pytest.mark.integration
def test_apply_custom_chat_templating_on_chat_message():
    anthropic_template = (
        "{%- for message in messages %}"
        "{%- if message.role == 'user' %}\n\nHuman: {{ message.content.strip() }}"
        "{%- elif message.role == 'assistant' %}\n\nAssistant: {{ message.content.strip() }}"
        "{%- elif message.role == 'function' %}{{ raise('anthropic does not support function calls.') }}"
        "{%- elif message.role == 'system' and loop.index == 1 %}{{ message.content }}"
        "{%- else %}{{ raise('Invalid message role: ' + message.role) }}"
        "{%- endif %}"
        "{%- endfor %}"
        "\n\nAssistant:"
    )
    messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
    # could be any tokenizer, let's use the one we already likely have in cache
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    tokenized_messages = tokenizer.apply_chat_template(messages, chat_template=anthropic_template, tokenize=False)
    assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:"
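The custom Jinja template exercised above can also be handed straight to the new component through its `chat_template` parameter, which overrides the tokenizer's own template in `apply_chat_template`. A minimal, hedged sketch follows (placeholder token; assumes network access to validate the model, download its tokenizer, and reach a TGI-served instance of the chosen model):

```python
from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.preview.dataclasses import ChatMessage

# Any Jinja template accepted by tokenizer.apply_chat_template works here.
anthropic_style_template = (
    "{%- for message in messages %}"
    "{%- if message.role == 'user' %}\n\nHuman: {{ message.content.strip() }}"
    "{%- elif message.role == 'assistant' %}\n\nAssistant: {{ message.content.strip() }}"
    "{%- elif message.role == 'system' and loop.index == 1 %}{{ message.content }}"
    "{%- endif %}"
    "{%- endfor %}"
    "\n\nAssistant:"
)

client = HuggingFaceTGIChatGenerator(
    model="HuggingFaceH4/zephyr-7b-beta",
    token="<your-token>",  # placeholder
    chat_template=anthropic_style_template,  # used instead of the tokenizer's default template
)
client.warm_up()
response = client.run(
    [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")],
    generation_kwargs={"max_new_tokens": 120},
)
print(response["replies"][0].content)
```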