fix: Make the HuggingFaceLocalChatGenerator compatible with the new ChatMessage; serialize chat_template (#8663)

* message conversion function

* HF API with tools

* right test file + hf_hub version

* release note

* fix for new ChatMessage; serialize chat_template

* feedback
Stefano Fiorucci 2024-12-19 15:12:12 +01:00 committed by GitHub
parent 2bc58d2987
commit f4d9c2bb91
3 changed files with 49 additions and 1 deletion


@@ -25,6 +25,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_an
     from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
         HFTokenStreamingHandler,
         StopWordsCriteria,
+        convert_message_to_hf_format,
         deserialize_hf_model_kwargs,
         serialize_hf_model_kwargs,
     )
@@ -201,6 +202,7 @@ class HuggingFaceLocalChatGenerator:
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
             token=self.token.to_dict() if self.token else None,
+            chat_template=self.chat_template,
         )

         huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
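With `chat_template` included in `to_dict`, the template now survives a serialization round trip, which the updated tests below verify. A minimal sketch of what this enables (the model name and template string are illustrative placeholders; `task` is set explicitly on the assumption that this avoids a Hub metadata lookup):

```python
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator

# Illustrative placeholders: any model/template accepted by the component works here
generator = HuggingFaceLocalChatGenerator(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    task="text-generation",
    chat_template="{% for m in messages %}{{ m['content'] }}{% endfor %}",
)

# The template is part of the serialized init parameters...
data = generator.to_dict()
assert data["init_parameters"]["chat_template"] == generator.chat_template

# ...and from_dict restores it
restored = HuggingFaceLocalChatGenerator.from_dict(data)
assert restored.chat_template == generator.chat_template
```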
@@ -270,9 +272,11 @@ class HuggingFaceLocalChatGenerator:
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)

+        hf_messages = [convert_message_to_hf_format(message) for message in messages]
+
         # Prepare the prompt for the model
         prepared_prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
+            hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
         )

         # Avoid some unnecessary warnings in the generation pipeline call
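For reference, a sketch of the conversion this hunk introduces: `convert_message_to_hf_format` (imported above from `haystack.utils.hf`) turns a Haystack `ChatMessage` into the plain role/content dict that `tokenizer.apply_chat_template` expects. The output shown in the comments assumes text-only messages:

```python
from haystack.dataclasses import ChatMessage
from haystack.utils.hf import convert_message_to_hf_format

messages = [
    ChatMessage.from_system("You are a helpful assistant."),
    ChatMessage.from_user("What is the capital of Germany?"),
]

hf_messages = [convert_message_to_hf_format(m) for m in messages]
# For text-only messages, each entry should look like:
# {"role": "system", "content": "You are a helpful assistant."}
# {"role": "user", "content": "What is the capital of Germany?"}
```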


@@ -0,0 +1,7 @@
+---
+fixes:
+  - |
+    Make the HuggingFaceLocalChatGenerator compatible with the new ChatMessage format, by converting the messages to
+    the format expected by Hugging Face.
+    Serialize the chat_template parameter.


@@ -135,6 +135,7 @@ class TestHuggingFaceLocalChatGenerator:
             generation_kwargs={"n": 5},
             stop_words=["stop", "words"],
             streaming_callback=lambda x: x,
+            chat_template="irrelevant",
         )

         # Call the to_dict method
@@ -146,6 +147,7 @@ class TestHuggingFaceLocalChatGenerator:
         assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
         assert "token" not in init_params["huggingface_pipeline_kwargs"]
         assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
+        assert init_params["chat_template"] == "irrelevant"

     def test_from_dict(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
@@ -153,6 +155,7 @@ class TestHuggingFaceLocalChatGenerator:
             generation_kwargs={"n": 5},
             stop_words=["stop", "words"],
             streaming_callback=streaming_callback_handler,
+            chat_template="irrelevant",
         )
         # Call the to_dict method
         result = generator.to_dict()
@@ -162,6 +165,7 @@ class TestHuggingFaceLocalChatGenerator:
         assert generator_2.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
         assert generator_2.generation_kwargs == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
         assert generator_2.streaming_callback is streaming_callback_handler
+        assert generator_2.chat_template == "irrelevant"

     @patch("haystack.components.generators.chat.hugging_face_local.pipeline")
     def test_warm_up(self, pipeline_mock, monkeypatch):
@@ -218,3 +222,36 @@ class TestHuggingFaceLocalChatGenerator:
         chat_message = results["replies"][0]
         assert chat_message.is_from(ChatRole.ASSISTANT)
         assert chat_message.text == "Berlin is cool"
+
+    @patch("haystack.components.generators.chat.hugging_face_local.convert_message_to_hf_format")
+    def test_messages_conversion_is_called(self, mock_convert, model_info_mock):
+        generator = HuggingFaceLocalChatGenerator(model="fake-model")
+
+        messages = [ChatMessage.from_user("Hello"), ChatMessage.from_assistant("Hi there")]
+
+        with patch.object(generator, "pipeline") as mock_pipeline:
+            mock_pipeline.tokenizer.apply_chat_template.return_value = "test prompt"
+            mock_pipeline.return_value = [{"generated_text": "test response"}]
+
+            generator.warm_up()
+            generator.run(messages)
+
+        assert mock_convert.call_count == 2
+        mock_convert.assert_any_call(messages[0])
+        mock_convert.assert_any_call(messages[1])
+
+    @pytest.mark.integration
+    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    def test_live_run(self):
+        messages = [ChatMessage.from_user("Please create a summary about the following topic: Climate change")]
+
+        llm = HuggingFaceLocalChatGenerator(
+            model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50}
+        )
+        llm.warm_up()

+        result = llm.run(messages)
+
+        assert "replies" in result
+        assert isinstance(result["replies"][0], ChatMessage)
+        assert "climate change" in result["replies"][0].text.lower()