mirror of https://github.com/deepset-ai/haystack.git
fix: Make the HuggingFaceLocalChatGenerator compatible with the new ChatMessage; serialize chat_template (#8663)
* message conversion function
* hfapi w tools
* right test file + hf_hub version
* release note
* fix for new chatmessage; serialize chat_template
* feedback
parent 2bc58d2987
commit f4d9c2bb91
@@ -25,6 +25,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_an
 from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
     HFTokenStreamingHandler,
     StopWordsCriteria,
+    convert_message_to_hf_format,
     deserialize_hf_model_kwargs,
     serialize_hf_model_kwargs,
 )
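The newly imported helper bridges the new ChatMessage (which stores its content as typed parts) and the plain role/content dicts that Hugging Face chat templates consume. A minimal sketch of that kind of conversion for text-only messages; the function name to_hf_dict is illustrative, and the real implementation in haystack.utils.hf.convert_message_to_hf_format may handle more cases:

from haystack.dataclasses import ChatMessage

def to_hf_dict(message: ChatMessage) -> dict:
    # HF chat templates expect {"role": str, "content": str};
    # .text flattens the message's text parts into a single string.
    return {"role": message.role.value, "content": message.text or ""}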
@@ -201,6 +202,7 @@ class HuggingFaceLocalChatGenerator:
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
             token=self.token.to_dict() if self.token else None,
+            chat_template=self.chat_template,
         )

         huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
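With chat_template now included in to_dict(), a custom template survives a serialization round trip. A sketch of that round trip, assuming network access to resolve the model (the unit tests below mock this); the template string "irrelevant" is the same placeholder the tests use:

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator

generator = HuggingFaceLocalChatGenerator(model="Qwen/Qwen2.5-0.5B-Instruct", chat_template="irrelevant")
data = generator.to_dict()
assert data["init_parameters"]["chat_template"] == "irrelevant"

# from_dict restores the template on the new instance
restored = HuggingFaceLocalChatGenerator.from_dict(data)
assert restored.chat_template == "irrelevant"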
@@ -270,9 +272,11 @@ class HuggingFaceLocalChatGenerator:
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)

+        hf_messages = [convert_message_to_hf_format(message) for message in messages]
+
         # Prepare the prompt for the model
         prepared_prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
+            hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
         )

         # Avoid some unnecessary warnings in the generation pipeline call
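The switch from messages to hf_messages is the heart of the fix: transformers' apply_chat_template consumes plain role/content dicts, not ChatMessage objects. A standalone sketch of that call (the model name is illustrative, assuming any model that ships a chat template):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
hf_messages = [{"role": "user", "content": "What is the capital of Germany?"}]
# Render the conversation into the model's chat-formatted prompt string
prompt = tokenizer.apply_chat_template(hf_messages, tokenize=False, add_generation_prompt=True)
print(prompt)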
@@ -0,0 +1,7 @@
+---
+fixes:
+  - |
+    Make the HuggingFaceLocalChatGenerator compatible with the new ChatMessage format, by converting the messages to
+    the format expected by Hugging Face.
+
+    Serialize the chat_template parameter.
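End to end, the fixed component accepts new-format ChatMessage objects directly and converts them internally before templating. A usage sketch mirroring the integration test added below (model name and prompt are illustrative):

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage

llm = HuggingFaceLocalChatGenerator(model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50})
llm.warm_up()  # loads the local pipeline
result = llm.run([ChatMessage.from_user("Summarize climate change in one sentence.")])
print(result["replies"][0].text)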
@@ -135,6 +135,7 @@ class TestHuggingFaceLocalChatGenerator:
             generation_kwargs={"n": 5},
             stop_words=["stop", "words"],
             streaming_callback=lambda x: x,
+            chat_template="irrelevant",
         )

         # Call the to_dict method
@@ -146,6 +147,7 @@ class TestHuggingFaceLocalChatGenerator:
         assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
         assert "token" not in init_params["huggingface_pipeline_kwargs"]
         assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
+        assert init_params["chat_template"] == "irrelevant"

     def test_from_dict(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
@@ -153,6 +155,7 @@ class TestHuggingFaceLocalChatGenerator:
             generation_kwargs={"n": 5},
             stop_words=["stop", "words"],
             streaming_callback=streaming_callback_handler,
+            chat_template="irrelevant",
         )
         # Call the to_dict method
         result = generator.to_dict()
@@ -162,6 +165,7 @@ class TestHuggingFaceLocalChatGenerator:
         assert generator_2.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
         assert generator_2.generation_kwargs == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
         assert generator_2.streaming_callback is streaming_callback_handler
+        assert generator_2.chat_template == "irrelevant"

     @patch("haystack.components.generators.chat.hugging_face_local.pipeline")
     def test_warm_up(self, pipeline_mock, monkeypatch):
@@ -218,3 +222,36 @@ class TestHuggingFaceLocalChatGenerator:
         chat_message = results["replies"][0]
         assert chat_message.is_from(ChatRole.ASSISTANT)
         assert chat_message.text == "Berlin is cool"
+
+    @patch("haystack.components.generators.chat.hugging_face_local.convert_message_to_hf_format")
+    def test_messages_conversion_is_called(self, mock_convert, model_info_mock):
+        generator = HuggingFaceLocalChatGenerator(model="fake-model")
+
+        messages = [ChatMessage.from_user("Hello"), ChatMessage.from_assistant("Hi there")]
+
+        with patch.object(generator, "pipeline") as mock_pipeline:
+            mock_pipeline.tokenizer.apply_chat_template.return_value = "test prompt"
+            mock_pipeline.return_value = [{"generated_text": "test response"}]
+
+            generator.warm_up()
+            generator.run(messages)
+
+        assert mock_convert.call_count == 2
+        mock_convert.assert_any_call(messages[0])
+        mock_convert.assert_any_call(messages[1])
+
+    @pytest.mark.integration
+    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    def test_live_run(self):
+        messages = [ChatMessage.from_user("Please create a summary about the following topic: Climate change")]
+
+        llm = HuggingFaceLocalChatGenerator(
+            model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50}
+        )
+        llm.warm_up()
+
+        result = llm.run(messages)
+
+        assert "replies" in result
+        assert isinstance(result["replies"][0], ChatMessage)
+        assert "climate change" in result["replies"][0].text.lower()