feat: Add HuggingFaceTGIChatGenerator Haystack 2.x component (#6199)

* Add ChatHuggingFaceTGIGenerator

* Add release note
---------
Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Vladimir Blagojevic 2023-11-06 09:48:45 +01:00 committed by GitHub
parent 03015877f3
commit d7e1833c40
9 changed files with 664 additions and 8 deletions


@@ -23,7 +23,7 @@ First, run the minimal Haystack installation:
pip install farm-haystack
```
Then, index your data to the DocumentStore, build a RAG pipeline, and ask a question on your data:
```python
from haystack.document_stores import InMemoryDocumentStore


@@ -0,0 +1,3 @@
from haystack.preview.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator
__all__ = ["HuggingFaceTGIChatGenerator"]


@@ -0,0 +1,276 @@
import logging
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Iterable, Callable
from urllib.parse import urlparse
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
from transformers import AutoTokenizer
from haystack.preview import component, default_to_dict, default_from_dict
from haystack.preview.components.generators.hf_utils import check_valid_model, check_generation_params
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.preview.dataclasses import ChatMessage, StreamingChunk
logger = logging.getLogger(__name__)
class HuggingFaceTGIChatGenerator:
"""
Enables text generation using chat-based LLMs hosted on the Hugging Face Hub. This component is designed to
seamlessly run inference on chat-based models deployed on the Text Generation Inference (TGI) backend.
You can use this component for chat LLMs hosted on the rate-limited Hugging Face Inference API tier:
```python
from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.preview.dataclasses import ChatMessage
messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"),
ChatMessage.from_user("What's Natural Language Processing?")]
client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", token="<your-token>")
client.warm_up()
response = client.run(messages, generation_kwargs={"max_new_tokens": 120})
print(response)
```
For chat LLMs hosted on a paid Hugging Face Inference Endpoint (https://huggingface.co/inference-endpoints) or on
your own custom TGI endpoint, you'll need to provide the URL of the endpoint as well as a valid token:
```python
from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.preview.dataclasses import ChatMessage
messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"),
ChatMessage.from_user("What's Natural Language Processing?")]
client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf",
url="<your-tgi-endpoint-url>",
token="<your-token>")
client.warm_up()
response = client.run(messages, generation_kwargs={"max_new_tokens": 120})
print(response)
```
Key Features and Compatibility:
- **Primary Compatibility**: Designed to work seamlessly with any chat-based model deployed using the TGI
framework. For more information on TGI, visit https://github.com/huggingface/text-generation-inference.
- **Hugging Face Inference Endpoints**: Supports inference of TGI chat LLMs deployed on Hugging Face
inference endpoints. For more details, refer to https://huggingface.co/inference-endpoints.
- **Inference API Support**: Supports inference of TGI chat LLMs hosted on the rate-limited Inference
API tier. Learn more about the Inference API at https://huggingface.co/inference-api.
Discover available chat models using the following command:
```
wget -qO- https://api-inference.huggingface.co/framework/text-generation-inference | grep chat
```
and simply use the model ID as the model parameter for this component. You'll also need to provide a valid
Hugging Face API token as the token parameter.
- **Custom TGI Endpoints**: Supports inference of TGI chat LLMs deployed on custom TGI endpoints. Anyone can
deploy their own TGI endpoint using the TGI framework. For more details, refer
to https://huggingface.co/inference-endpoints.
Input and Output Format:
- **ChatMessage Format**: This component uses the ChatMessage format to structure both input and output,
ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on the
ChatMessage format can be found at https://github.com/openai/openai-python/blob/main/chatml.md.
"""
def __init__(
self,
model: str = "meta-llama/Llama-2-13b-chat-hf",
url: Optional[str] = None,
token: Optional[str] = None,
chat_template: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Initialize the HuggingFaceTGIChatGenerator instance.
:param model: A string representing the model path or URL. Default is "meta-llama/Llama-2-13b-chat-hf".
:param url: An optional string representing the URL of the TGI endpoint.
:param token: The Hugging Face token for HTTP bearer authorization.
You can find your HF token at https://huggingface.co/settings/tokens.
:param chat_template: This optional parameter allows you to specify a Jinja template for formatting chat
messages. While high-quality and well-supported chat models typically include their own chat templates
accessible through their tokenizer, there are models that do not offer this feature. For such scenarios,
or if you wish to use a custom template instead of the model's default, you can use this parameter to
set your preferred chat template.
:param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
Some examples: `max_new_tokens`, `temperature`, `top_k`, `top_p`,...
See Hugging Face's [documentation](https://huggingface.co/docs/huggingface_hub/v0.18.0.rc0/en/package_reference/inference_client#huggingface_hub.inference._text_generation.TextGenerationParameters)
for more information.
:param stop_words: An optional list of strings representing the stop words.
:param streaming_callback: An optional callable for handling streaming responses.
"""
if url:
r = urlparse(url)
is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
if not is_valid_url:
raise ValueError(f"Invalid TGI endpoint URL provided: {url}")
check_valid_model(model, token)
# handle generation kwargs setup
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
check_generation_params(generation_kwargs, ["n"])
generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
generation_kwargs["stop_sequences"].extend(stop_words or [])
self.model = model
self.url = url
self.chat_template = chat_template
self.token = token
self.generation_kwargs = generation_kwargs
self.client = InferenceClient(url or model, token=token)
self.streaming_callback = streaming_callback
self.tokenizer = None
def warm_up(self) -> None:
"""
Load the tokenizer.
"""
self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token)
# mypy can't infer that chat_template attribute exists on the object returned by AutoTokenizer.from_pretrained
chat_template = getattr(self.tokenizer, "chat_template", None)
if not chat_template and not self.chat_template:
logger.warning(
"The model '%s' doesn't have a default chat_template, and no chat_template was supplied during "
"this component's initialization. Its possible that the model doesn't support ChatML inference "
"format, potentially leading to unexpected behavior.",
self.model,
)
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: A dictionary containing the serialized component.
"""
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
model=self.model,
url=self.url,
chat_template=self.chat_template,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTGIChatGenerator":
"""
Deserialize this component from a dictionary.
"""
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
return default_from_dict(cls, data)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
# Don't send URL as it is sensitive information
return {"model": self.model}
@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Invoke the text generation inference based on the provided messages and generation parameters.
:param messages: A list of ChatMessage instances representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:return: A list containing the generated responses as ChatMessage instances.
"""
# check generation kwargs given as parameters to override the default ones
additional_params = ["n", "stop_words"]
check_generation_params(generation_kwargs, additional_params)
# update generation kwargs by merging with the default ones
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
num_responses = generation_kwargs.pop("n", 1)
# merge stop_words and stop_sequences into a single list
generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
generation_kwargs["stop_sequences"].extend(generation_kwargs.pop("stop_words", []))
if self.tokenizer is None:
raise RuntimeError("Please call warm_up() before running LLM inference.")
# apply either model's chat template or the user-provided one
prepared_prompt: str = self.tokenizer.apply_chat_template(
conversation=messages, chat_template=self.chat_template, tokenize=False
)
prompt_token_count: int = len(self.tokenizer.encode(prepared_prompt, add_special_tokens=False))
if self.streaming_callback:
if num_responses > 1:
raise ValueError("Cannot stream multiple responses, please set n=1.")
return self._run_streaming(prepared_prompt, prompt_token_count, generation_kwargs)
return self._run_non_streaming(prepared_prompt, prompt_token_count, num_responses, generation_kwargs)
def _run_streaming(
self, prepared_prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]
) -> Dict[str, List[ChatMessage]]:
res: Iterable[TextGenerationStreamResponse] = self.client.text_generation(
prepared_prompt, stream=True, details=True, **generation_kwargs
)
chunk = None
# pylint: disable=not-an-iterable
for chunk in res:
token: Token = chunk.token
if token.special:
continue
chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
stream_chunk = StreamingChunk(token.text, chunk_metadata)
self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method)
message = ChatMessage.from_assistant(chunk.generated_text)
message.metadata.update(
{
"finish_reason": chunk.details.finish_reason.value,
"index": 0,
"model": self.client.model,
"usage": {
"completion_tokens": chunk.details.generated_tokens,
"prompt_tokens": prompt_token_count,
"total_tokens": prompt_token_count + chunk.details.generated_tokens,
},
}
)
return {"replies": [message]}
def _run_non_streaming(
self, prepared_prompt: str, prompt_token_count: int, num_responses: int, generation_kwargs: Dict[str, Any]
) -> Dict[str, List[ChatMessage]]:
chat_messages: List[ChatMessage] = []
for _i in range(num_responses):
tgr: TextGenerationResponse = self.client.text_generation(
prepared_prompt, details=True, **generation_kwargs
)
message = ChatMessage.from_assistant(tgr.generated_text)
message.metadata.update(
{
"finish_reason": tgr.details.finish_reason.value,
"index": _i,
"model": self.client.model,
"usage": {
"completion_tokens": len(tgr.details.tokens),
"prompt_tokens": prompt_token_count,
"total_tokens": prompt_token_count + len(tgr.details.tokens),
},
}
)
chat_messages.append(message)
return {"replies": chat_messages}


@@ -0,0 +1,5 @@
---
preview:
- |
Adds `HuggingFaceTGIChatGenerator` for chat generation. This component supports remote inference for
Hugging Face LLMs via the text-generation-inference (TGI) protocol.


@@ -0,0 +1,11 @@
import pytest
from haystack.preview.dataclasses import ChatMessage
@pytest.fixture
def chat_messages():
return [
ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"),
ChatMessage.from_user("Tell me about Berlin"),
]


@@ -0,0 +1,317 @@
from unittest.mock import patch, MagicMock, Mock
import pytest
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
from huggingface_hub.utils import RepositoryNotFoundError
from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.preview.dataclasses import StreamingChunk, ChatMessage
@pytest.fixture
def mock_check_valid_model():
with patch(
"haystack.preview.components.generators.chat.hugging_face_tgi.check_valid_model", MagicMock(return_value=None)
) as mock:
yield mock
@pytest.fixture
def mock_text_generation():
with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation:
mock_response = Mock()
mock_response.generated_text = "I'm fine, thanks."
details = Mock()
details.finish_reason = MagicMock(field1="value")
details.tokens = [1, 2, 3]
mock_response.details = details
mock_text_generation.return_value = mock_response
yield mock_text_generation
# used to test serialization of streaming_callback
def streaming_callback_handler(x):
return x
class TestHuggingFaceTGIChatGenerator:
@pytest.mark.unit
def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model, mock_auto_tokenizer):
model = "HuggingFaceH4/zephyr-7b-alpha"
generation_kwargs = {"n": 1}
stop_words = ["stop"]
streaming_callback = None
generator = HuggingFaceTGIChatGenerator(
model=model,
generation_kwargs=generation_kwargs,
stop_words=stop_words,
streaming_callback=streaming_callback,
)
generator.warm_up()
assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
assert generator.tokenizer is not None
assert generator.client is not None
assert generator.streaming_callback == streaming_callback
@pytest.mark.unit
def test_to_dict(self, mock_check_valid_model):
# Initialize the HuggingFaceTGIChatGenerator object with valid parameters
generator = HuggingFaceTGIChatGenerator(
model="NousResearch/Llama-2-7b-chat-hf",
token="token",
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=lambda x: x,
)
# Call the to_dict method
result = generator.to_dict()
init_params = result["init_parameters"]
# Assert that the init_params dictionary contains the expected keys and values
assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf"
assert init_params["token"] is None
assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
@pytest.mark.unit
def test_from_dict(self, mock_check_valid_model):
generator = HuggingFaceTGIChatGenerator(
model="NousResearch/Llama-2-7b-chat-hf",
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=streaming_callback_handler,
)
# Call the to_dict method
result = generator.to_dict()
generator_2 = HuggingFaceTGIChatGenerator.from_dict(result)
assert generator_2.model == "NousResearch/Llama-2-7b-chat-hf"
assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
assert generator_2.streaming_callback is streaming_callback_handler
@pytest.mark.unit
def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer):
generator = HuggingFaceTGIChatGenerator()
generator.warm_up()
# Assert that the tokenizer is now initialized
assert generator.tokenizer is not None
@pytest.mark.unit
def test_warm_up_no_chat_template(self, mock_check_valid_model, mock_auto_tokenizer, caplog):
generator = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
# Set chat_template to None for this specific test
mock_auto_tokenizer.chat_template = None
generator.warm_up()
# warning message should be logged
assert "The model 'meta-llama/Llama-2-13b-chat-hf' doesn't have a default chat_template" in caplog.text
@pytest.mark.unit
def test_custom_chat_template(
self, chat_messages, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
):
custom_chat_template = "Here goes some Jinja template"
# mocked method to check if we called apply_chat_template with the custom template
mock_auto_tokenizer.apply_chat_template = MagicMock(return_value="some_value")
generator = HuggingFaceTGIChatGenerator(chat_template=custom_chat_template)
generator.warm_up()
assert generator.chat_template == custom_chat_template
generator.run(messages=chat_messages)
assert mock_auto_tokenizer.apply_chat_template.call_count == 1
# and we indeed called apply_chat_template with the custom template
_, kwargs = mock_auto_tokenizer.apply_chat_template.call_args
assert kwargs["chat_template"] == custom_chat_template
@pytest.mark.unit
def test_initialize_with_invalid_model_path_or_url(self, mock_check_valid_model):
model = "invalid_model"
generation_kwargs = {"n": 1}
stop_words = ["stop"]
streaming_callback = None
mock_check_valid_model.side_effect = ValueError("Invalid model path or url")
with pytest.raises(ValueError):
HuggingFaceTGIChatGenerator(
model=model,
generation_kwargs=generation_kwargs,
stop_words=stop_words,
streaming_callback=streaming_callback,
)
@pytest.mark.unit
def test_initialize_with_invalid_url(self, mock_check_valid_model):
with pytest.raises(ValueError):
HuggingFaceTGIChatGenerator(model="NousResearch/Llama-2-7b-chat-hf", url="invalid_url")
@pytest.mark.unit
def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
# When custom TGI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id
mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
with pytest.raises(RepositoryNotFoundError):
HuggingFaceTGIChatGenerator(model="invalid_model_id", url="https://some_chat_model.com")
@pytest.mark.unit
def test_generate_text_response_with_valid_prompt_and_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
):
model = "meta-llama/Llama-2-13b-chat-hf"
generation_kwargs = {"n": 1}
stop_words = ["stop"]
streaming_callback = None
generator = HuggingFaceTGIChatGenerator(
model=model,
generation_kwargs=generation_kwargs,
stop_words=stop_words,
streaming_callback=streaming_callback,
)
generator.warm_up()
response = generator.run(messages=chat_messages)
# check kwargs passed to text_generation
# note that n was not passed to text_generation because it is not a text generation parameter
_, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop"]}
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
@pytest.mark.unit
def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
):
model = "meta-llama/Llama-2-13b-chat-hf"
token = None
generation_kwargs = {"n": 3}
stop_words = ["stop"]
streaming_callback = None
generator = HuggingFaceTGIChatGenerator(
model=model,
token=token,
generation_kwargs=generation_kwargs,
stop_words=stop_words,
streaming_callback=streaming_callback,
)
generator.warm_up()
response = generator.run(chat_messages)
# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop"]}
# note how n caused n replies to be generated
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 3
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
@pytest.mark.unit
def test_generate_text_with_stop_words(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
):
generator = HuggingFaceTGIChatGenerator()
generator.warm_up()
stop_words = ["stop", "words"]
# Generate text response with stop words
response = generator.run(chat_messages, generation_kwargs={"stop_words": stop_words})
# check kwargs passed to text_generation
# we translate stop_words to stop_sequences
_, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}
# Assert that the response contains the generated replies
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) > 0
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
@pytest.mark.unit
def test_generate_text_with_custom_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
):
# Create an instance of HuggingFaceTGIChatGenerator with no generation parameters
generator = HuggingFaceTGIChatGenerator()
generator.warm_up()
# but then we pass them in run
generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
response = generator.run(chat_messages, generation_kwargs=generation_kwargs)
# again check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8}
# Assert that the response contains the generated replies and the right response
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) > 0
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
assert response["replies"][0].content == "I'm fine, thanks."
@pytest.mark.unit
def test_generate_text_with_streaming_callback(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages
):
streaming_call_count = 0
# Define the streaming callback function
def streaming_callback_fn(chunk: StreamingChunk):
nonlocal streaming_call_count
streaming_call_count += 1
assert isinstance(chunk, StreamingChunk)
# Create an instance of HuggingFaceTGIChatGenerator
generator = HuggingFaceTGIChatGenerator(streaming_callback=streaming_callback_fn)
generator.warm_up()
# Create a fake streamed response
# self needed here, don't remove
def mock_iter(self):
yield TextGenerationStreamResponse(
generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False)
)
yield TextGenerationStreamResponse(
generated_text=None,
token=Token(id=1, text="Ok bye", logprob=0.0, special=False),
details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5),
)
mock_response = Mock(**{"__iter__": mock_iter})
mock_text_generation.return_value = mock_response
# Generate text response with streaming callback
response = generator.run(chat_messages)
# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": [], "stream": True}
# Assert that the streaming callback was called twice
assert streaming_call_count == 2
# Assert that the response contains the generated replies
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) > 0
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
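
The serialization tests above exercise a to_dict/from_dict round trip. Here is a short sketch of that flow, assuming a valid Hugging Face token at construction time (the constructor validates the model id against the Hub) and a module-level callback so it can be restored by import path; the stop word is only an example:

```python
from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator

# Must be a module-level function so it can be serialized by import path.
def my_callback(chunk):
    print(chunk.content, end="", flush=True)

generator = HuggingFaceTGIChatGenerator(
    model="NousResearch/Llama-2-7b-chat-hf",
    token="<your-token>",                       # string tokens are replaced with None in to_dict()
    generation_kwargs={"max_new_tokens": 120},
    stop_words=["</s>"],                        # merged into generation_kwargs["stop_sequences"]
    streaming_callback=my_callback,
)

data = generator.to_dict()
assert data["init_parameters"]["token"] is None        # valid tokens are never serialized
assert data["init_parameters"]["streaming_callback"]   # stored as an importable name

restored = HuggingFaceTGIChatGenerator.from_dict(data)
assert restored.generation_kwargs["stop_sequences"] == ["</s>"]
```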


@@ -0,0 +1,21 @@
from unittest.mock import patch, MagicMock
import pytest
@pytest.fixture
def mock_auto_tokenizer():
"""
Mock transformers.AutoTokenizer.from_pretrained so no real tokenizer is downloaded.
The earlier version of this fixture patched from_pretrained without providing a usable return value,
so components such as HuggingFaceTGIChatGenerator did not receive a usable tokenizer object. This
version returns a MagicMock tokenizer instead; the fixture will be refined further in another PR.
"""
with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
mock_tokenizer = MagicMock()
mock_from_pretrained.return_value = mock_tokenizer
yield mock_tokenizer


@@ -1,4 +1,5 @@
import pytest
from transformers import AutoTokenizer
from haystack.preview.dataclasses import ChatMessage, ChatRole
@@ -29,19 +30,41 @@ def test_from_system_with_valid_content():
@pytest.mark.unit
def test_with_empty_content():
message = ChatMessage("", ChatRole.USER, None)
message = ChatMessage.from_user("")
assert message.content == ""
@pytest.mark.unit
def test_with_invalid_role():
with pytest.raises(TypeError):
ChatMessage("Invalid role", "invalid_role")
@pytest.mark.unit
def test_from_function_with_empty_name():
content = "Function call"
message = ChatMessage.from_function(content, "")
assert message.content == content
assert message.name == ""
@pytest.mark.integration
def test_apply_chat_templating_on_chat_message():
messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
tokenized_messages = tokenizer.apply_chat_template(messages, tokenize=False)
assert tokenized_messages == "<|system|>\nYou are good assistant</s>\n<|user|>\nI have a question</s>\n"
@pytest.mark.integration
def test_apply_custom_chat_templating_on_chat_message():
anthropic_template = (
"{%- for message in messages %}"
"{%- if message.role == 'user' %}\n\nHuman: {{ message.content.strip() }}"
"{%- elif message.role == 'assistant' %}\n\nAssistant: {{ message.content.strip() }}"
"{%- elif message.role == 'function' %}{{ raise('anthropic does not support function calls.') }}"
"{%- elif message.role == 'system' and loop.index == 1 %}{{ message.content }}"
"{%- else %}{{ raise('Invalid message role: ' + message.role) }}"
"{%- endif %}"
"{%- endfor %}"
"\n\nAssistant:"
)
messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
# could be any tokenizer, let's use the one we already likely have in cache
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
tokenized_messages = tokenizer.apply_chat_template(messages, chat_template=anthropic_template, tokenize=False)
assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:"