From 4c9d08add5ca82600153264ea7b2a9d0fcecadcd Mon Sep 17 00:00:00 2001
From: "David S. Batista"
Date: Thu, 6 Mar 2025 15:57:11 +0100
Subject: [PATCH] feat: async support for the `HuggingFaceLocalChatGenerator`
 (#8981)

* adding async run method

* passing an optional ThreadPoolExecutor

* adding tests

* adding release notes

* nit: license

* fixing linting

* Update releasenotes/notes/adding-async-huggingface-local-chat-generator-962512f52282d12d.yaml

Co-authored-by: Amna Mubashar

* Use Phi instead (#8982)

* build: drop Python 3.8 support (#8978)

* draft

* readd typing_extensions

* small fix + release note

* remove ruff target-version

* Update releasenotes/notes/drop-python-3.8-868710963e794c83.yaml

Co-authored-by: David S. Batista

---------

Co-authored-by: David S. Batista

* Update unstable version to 2.12.0-rc0 (#8983)

Co-authored-by: github-actions[bot]

* fix: allow support for `include_usage` in streaming using OpenAIChatGenerator (#8968)

* fix error in handling usage completion chunk

* ci: improve release notes format checking (#8984)

* chore: fix invalid release note

* try improving relnote linting

* add relnotes path

* fix bad release note

* improve reno config

* fix: handle async tests in `HuggingFaceAPIChatGenerator` to prevent error (#8986)

* add missing asyncio

* explicitly close connection in the test

* Fix tests (#8990)

* docs: Update docstrings of `BranchJoiner` (#8988)

* Update docstrings

* Add a bit more explanatory text

* Add reno

* Update haystack/components/joiners/branch.py

Co-authored-by: Daria Fokina

* Update haystack/components/joiners/branch.py

Co-authored-by: Daria Fokina

* Update haystack/components/joiners/branch.py

Co-authored-by: Daria Fokina

* Update haystack/components/joiners/branch.py

Co-authored-by: Daria Fokina

* Fix formatting

---------

Co-authored-by: Daria Fokina

* PR comments

* destroying ThreadPoolExecutor when the generator instance is being destroyed, only if it was not passed externally

* fixing bug in streaming_callback

* PR comments

---------

Co-authored-by: Amna Mubashar
Co-authored-by: Sebastian Husch Lee
Co-authored-by: Stefano Fiorucci
Co-authored-by: Haystack Bot <73523382+HaystackBot@users.noreply.github.com>
Co-authored-by: github-actions[bot]
Co-authored-by: Daria Fokina
---
 .../generators/chat/hugging_face_local.py     | 179 +++++++++++++++++-
 ...local-chat-generator-962512f52282d12d.yaml |   5 +
 .../chat/test_hugging_face_local.py           |  84 +++++++-
 3 files changed, 261 insertions(+), 7 deletions(-)
 create mode 100644 releasenotes/notes/adding-async-huggingface-local-chat-generator-962512f52282d12d.yaml

diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py
index f54c50181..e4ad4ea80 100644
--- a/haystack/components/generators/chat/hugging_face_local.py
+++ b/haystack/components/generators/chat/hugging_face_local.py
@@ -2,13 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import asyncio
 import json
 import re
 import sys
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.utils import (
@@ -123,6 +125,7 @@ class HuggingFaceLocalChatGenerator:
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         tools: Optional[List[Tool]] = None,
         tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
+        async_executor: Optional[ThreadPoolExecutor] = None,
     ):
         """
         Initializes the HuggingFaceLocalChatGenerator component.
@@ -165,6 +168,9 @@ class HuggingFaceLocalChatGenerator:
         :param tool_parsing_function:
             A callable that takes a string and returns a list of ToolCall objects or None.
             If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
+        :param async_executor:
+            Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
+            initialized and used.
         """
         torch_and_transformers_import.check()
 
@@ -223,6 +229,27 @@ class HuggingFaceLocalChatGenerator:
         self.pipeline = None
         self.tools = tools
 
+        self._owns_executor = async_executor is None
+        self.executor = (
+            ThreadPoolExecutor(thread_name_prefix=f"async-HFLocalChatGenerator-executor-{id(self)}", max_workers=1)
+            if async_executor is None
+            else async_executor
+        )
+
+    def __del__(self):
+        """
+        Cleanup when the instance is being destroyed.
+        """
+        if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
+            self.executor.shutdown(wait=True)
+
+    def shutdown(self):
+        """
+        Explicitly shutdown the executor if we own it.
+        """
+        if self._owns_executor:
+            self.executor.shutdown(wait=True)
+
     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
         Data that is sent to Posthog for usage analytics.
@@ -332,7 +359,9 @@ class HuggingFaceLocalChatGenerator:
         if stop_words_criteria:
             generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
 
-        streaming_callback = streaming_callback or self.streaming_callback
+        streaming_callback = select_streaming_callback(
+            self.streaming_callback, streaming_callback, requires_async=False
+        )
         if streaming_callback:
             num_responses = generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
@@ -427,7 +456,8 @@ class HuggingFaceLocalChatGenerator:
         # If tool calls are detected, don't include the text content since it contains the raw tool call format
         return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)
 
-    def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]:
+    @staticmethod
+    def _validate_stop_words(stop_words: Optional[List[str]]) -> Optional[List[str]]:
         """
         Validates the provided stop words.
 
@@ -443,3 +473,146 @@ class HuggingFaceLocalChatGenerator:
             return None
 
         return list(set(stop_words or []))
+
+    @component.output_types(replies=List[ChatMessage])
+    async def run_async(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        tools: Optional[List[Tool]] = None,
+    ):
+        """
+        Asynchronously invokes text generation inference based on the provided messages and generation parameters.
+
+        This is the asynchronous version of the `run` method. It has the same parameters
+        and return values but can be used with `await` in async code.
+
+        :param messages: A list of ChatMessage objects representing the input messages.
+        :param generation_kwargs: Additional keyword arguments for text generation.
+        :param streaming_callback: An optional callable for handling streaming responses.
+        :param tools: A list of tools for which the model can prepare calls.
+        :returns: A dictionary with the following keys:
+            - `replies`: A list containing the generated responses as ChatMessage instances.
+        """
+        if self.pipeline is None:
+            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
+
+        tools = tools or self.tools
+        if tools and streaming_callback is not None:
+            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+        _check_duplicate_tool_names(tools)
+
+        tokenizer = self.pipeline.tokenizer
+
+        # Check and update generation parameters
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
+        stop_words = self._validate_stop_words(stop_words)
+
+        # Set up stop words criteria if stop words exist
+        stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
+        if stop_words_criteria:
+            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
+
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
+
+        if streaming_callback:
+            return await self._run_streaming_async(
+                messages, tokenizer, generation_kwargs, stop_words, streaming_callback
+            )
+
+        return await self._run_non_streaming_async(messages, tokenizer, generation_kwargs, stop_words, tools)
+
+    async def _run_streaming_async(  # pylint: disable=too-many-positional-arguments
+        self,
+        messages: List[ChatMessage],
+        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
+        generation_kwargs: Dict[str, Any],
+        stop_words: Optional[List[str]],
+        streaming_callback: Callable[[StreamingChunk], None],
+    ):
+        """
+        Handles async streaming generation of responses.
+        """
+        # convert messages to HF format
+        hf_messages = [convert_message_to_hf_format(message) for message in messages]
+        prepared_prompt = tokenizer.apply_chat_template(
+            hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
+        )
+
+        # Avoid some unnecessary warnings in the generation pipeline call
+        generation_kwargs["pad_token_id"] = (
+            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
+        )
+
+        # Set up streaming handler
+        generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
+
+        # Generate responses asynchronously
+        output = await asyncio.get_running_loop().run_in_executor(
+            self.executor,
+            lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
+        )
+
+        replies = [o.get("generated_text", "") for o in output]
+
+        # Remove stop words from replies if present
+        for stop_word in stop_words or []:
+            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
+
+        chat_messages = [
+            self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False)
+            for r_index, reply in enumerate(replies)
+        ]
+
+        return {"replies": chat_messages}
+
+    async def _run_non_streaming_async(  # pylint: disable=too-many-positional-arguments
+        self,
+        messages: List[ChatMessage],
+        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
+        generation_kwargs: Dict[str, Any],
+        stop_words: Optional[List[str]],
+        tools: Optional[List[Tool]] = None,
+    ):
+        """
+        Handles async non-streaming generation of responses.
+        """
+        # convert messages to HF format
+        hf_messages = [convert_message_to_hf_format(message) for message in messages]
+        prepared_prompt = tokenizer.apply_chat_template(
+            hf_messages,
+            tokenize=False,
+            chat_template=self.chat_template,
+            add_generation_prompt=True,
+            tools=[tc.tool_spec for tc in tools] if tools else None,
+        )
+
+        # Avoid some unnecessary warnings in the generation pipeline call
+        generation_kwargs["pad_token_id"] = (
+            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
+        )
+
+        # Generate responses asynchronously
+        output = await asyncio.get_running_loop().run_in_executor(
+            self.executor,
+            lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
+        )
+
+        replies = [o.get("generated_text", "") for o in output]
+
+        # Remove stop words from replies if present
+        for stop_word in stop_words or []:
+            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
+
+        chat_messages = [
+            self.create_message(
+                reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
+            )
+            for r_index, reply in enumerate(replies)
+        ]
+
+        return {"replies": chat_messages}
diff --git a/releasenotes/notes/adding-async-huggingface-local-chat-generator-962512f52282d12d.yaml b/releasenotes/notes/adding-async-huggingface-local-chat-generator-962512f52282d12d.yaml
new file mode 100644
index 000000000..c2323a512
--- /dev/null
+++ b/releasenotes/notes/adding-async-huggingface-local-chat-generator-962512f52282d12d.yaml
@@ -0,0 +1,5 @@
+---
+features:
+  - |
+    Add `run_async` method to HuggingFaceLocalChatGenerator. This method internally uses a ThreadPoolExecutor to return coroutines
+    that can be awaited.
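
For context, a minimal usage sketch of the new async API (not part of the patch itself); the model name, prompt, and event-loop setup below are illustrative assumptions, while `run_async`, `async_executor`, `warm_up`, and `shutdown` come from the change above:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
    from haystack.dataclasses import ChatMessage


    async def main():
        # Optionally share an executor across components; if omitted, the generator
        # creates a single-threaded executor and shuts it down when it is destroyed.
        executor = ThreadPoolExecutor(max_workers=1)
        generator = HuggingFaceLocalChatGenerator(
            model="HuggingFaceTB/SmolLM2-360M-Instruct",  # illustrative model choice
            async_executor=executor,
        )
        generator.warm_up()  # loads the pipeline; run_async raises RuntimeError otherwise

        result = await generator.run_async(
            messages=[ChatMessage.from_user("What is the capital of Germany?")]
        )
        print(result["replies"][0].text)

        generator.shutdown()  # no-op here: the executor was passed in externally
        executor.shutdown(wait=True)


    asyncio.run(main())
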
diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py
index 38d25ec91..6a173a4ab 100644
--- a/test/components/generators/chat/test_hugging_face_local.py
+++ b/test/components/generators/chat/test_hugging_face_local.py
@@ -1,18 +1,21 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH
 #
 # SPDX-License-Identifier: Apache-2.0
-from unittest.mock import Mock, patch
-from typing import Optional, List
-from haystack.dataclasses.streaming_chunk import StreamingChunk
+import asyncio
+import gc
+from typing import Optional, List
+from unittest.mock import Mock, patch
+
 import pytest
 from transformers import PreTrainedTokenizer
 
 from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
 from haystack.dataclasses import ChatMessage, ChatRole, ToolCall
+from haystack.dataclasses.streaming_chunk import StreamingChunk
+from haystack.tools import Tool
 from haystack.utils import ComponentDevice
 from haystack.utils.auth import Secret
-from haystack.tools import Tool
 
 
 # used to test serialization of streaming_callback
@@ -474,3 +477,76 @@ class TestHuggingFaceLocalChatGenerator:
         assert len(results["replies"][0].tool_calls) == 1
         assert results["replies"][0].tool_calls[0].tool_name == "weather"
         assert results["replies"][0].tool_calls[0].arguments == {"city": "Berlin"}
+
+    # Async tests
+
+    async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
+        """Test basic async functionality"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+        generator.pipeline = mock_pipeline_tokenizer
+
+        results = await generator.run_async(messages=chat_messages)
+
+        assert "replies" in results
+        assert isinstance(results["replies"][0], ChatMessage)
+        chat_message = results["replies"][0]
+        assert chat_message.is_from(ChatRole.ASSISTANT)
+        assert chat_message.text == "Berlin is cool"
+
+    async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_tokenizer, tools):
+        """Test async functionality with tools"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model", tools=tools)
+        generator.pipeline = mock_pipeline_tokenizer
+        # Mock the pipeline to return a tool call format
+        generator.pipeline.return_value = [{"generated_text": '{"name": "weather", "arguments": {"city": "Berlin"}}'}]
+
+        messages = [ChatMessage.from_user("What's the weather in Berlin?")]
+        results = await generator.run_async(messages=messages)
+
+        assert len(results["replies"]) == 1
+        message = results["replies"][0]
+        assert message.tool_calls
+        tool_call = message.tool_calls[0]
+        assert isinstance(tool_call, ToolCall)
+        assert tool_call.tool_name == "weather"
+        assert tool_call.arguments == {"city": "Berlin"}
+
+    async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
+        """Test handling of multiple concurrent async requests"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+        generator.pipeline = mock_pipeline_tokenizer
+
+        # Create multiple concurrent requests
+        tasks = [generator.run_async(messages=chat_messages) for _ in range(5)]
+        results = await asyncio.gather(*tasks)
+
+        for result in results:
+            assert "replies" in result
+            assert isinstance(result["replies"][0], ChatMessage)
+            assert result["replies"][0].text == "Berlin is cool"
+
+    async def test_async_error_handling(self, model_info_mock, mock_pipeline_tokenizer):
+        """Test error handling in async context"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+
+        # Test without warm_up
+        with pytest.raises(RuntimeError, match="The generation model has not been loaded"):
+            await generator.run_async(messages=[ChatMessage.from_user("test")])
+
+        # Test that tools and streaming cannot be combined
+        generator.pipeline = mock_pipeline_tokenizer
+        with pytest.raises(ValueError, match="Using tools and streaming at the same time is not supported"):
+            await generator.run_async(
+                messages=[ChatMessage.from_user("test")],
+                streaming_callback=lambda x: None,
+                tools=[Tool(name="test", description="test", parameters={}, function=lambda: None)],
+            )
+
+    def test_executor_shutdown(self, model_info_mock, mock_pipeline_tokenizer):
+        with patch("haystack.components.generators.chat.hugging_face_local.pipeline") as mock_pipeline:
+            generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+            executor = generator.executor
+            with patch.object(executor, "shutdown", wraps=executor.shutdown) as mock_shutdown:
+                del generator
+                gc.collect()
+                mock_shutdown.assert_called_once_with(wait=True)
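
The tests above exercise the core concurrency pattern this patch relies on: the blocking transformers pipeline call is pushed onto a ThreadPoolExecutor via run_in_executor so the event loop stays responsive, and with max_workers=1 concurrent run_async calls are serialized on a single worker thread. Below is a standalone sketch of that pattern; slow_pipeline and generate are illustrative stand-ins, not names from the patch:

    import asyncio
    import time
    from concurrent.futures import ThreadPoolExecutor

    executor = ThreadPoolExecutor(max_workers=1)


    def slow_pipeline(prompt: str) -> str:
        # Stand-in for the blocking transformers pipeline call.
        time.sleep(1)
        return f"reply to: {prompt}"


    async def generate(prompt: str) -> str:
        loop = asyncio.get_running_loop()
        # Offload the blocking call so the event loop keeps serving other coroutines;
        # with max_workers=1, concurrent calls queue up on the single worker thread.
        return await loop.run_in_executor(executor, lambda: slow_pipeline(prompt))


    async def main():
        # Several requests awaited concurrently, mirroring test_concurrent_async_requests.
        replies = await asyncio.gather(*(generate(f"question {i}") for i in range(3)))
        print(replies)
        executor.shutdown(wait=True)


    asyncio.run(main())
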