feat: async support for the HuggingFaceLocalChatGenerator (#8981)

* adding async run method

* passing an optional ThreadPoolExecutor

* adding tests

* adding release notes

* nit: license

* fixing linting

* Update releasenotes/notes/adding-async-huggingface-local-chat-generator-962512f52282d12d.yaml

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>

* Use Phi instead (#8982)

* build: drop Python 3.8 support (#8978)

* draft

* readd typing_extensions

* small fix + release note

* remove ruff target-version

* Update releasenotes/notes/drop-python-3.8-868710963e794c83.yaml

Co-authored-by: David S. Batista <dsbatista@gmail.com>

---------

Co-authored-by: David S. Batista <dsbatista@gmail.com>

* Update unstable version to 2.12.0-rc0 (#8983)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* fix: allow support for `include_usage` in streaming using OpenAIChatGenerator (#8968)

* fix error in handling usage completion chunk

* ci: improve release notes format checking (#8984)

* chore: fix invalid release note

* try improving relnote linting

* add relnotes path

* fix bad release note

* improve reno config

* fix: handle async tests in `HuggingFaceAPIChatGenerator` to prevent error (#8986)

* add missing asyncio

* explicitly close connection in the test

* Fix tests (#8990)

* docs: Update docstrings of `BranchJoiner` (#8988)

* Update docstrings

* Add a bit more explanatory text

* Add reno

* Update haystack/components/joiners/branch.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/joiners/branch.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/joiners/branch.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/joiners/branch.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Fix formatting

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* PR comments

* destroying the ThreadPoolExecutor when the generator instance is destroyed, only if it was not passed in externally

* fixing bug in streaming_callback

* PR comments

---------

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
Co-authored-by: Haystack Bot <73523382+HaystackBot@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
David S. Batista 2025-03-06 15:57:11 +01:00 committed by GitHub
parent c4fafd9b04
commit 4c9d08add5
3 changed files with 261 additions and 7 deletions
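
For orientation, a minimal usage sketch of the async API this commit adds; the model name and prompt are placeholders rather than part of the change:

import asyncio

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage

async def main():
    # any locally runnable HF chat model can go here; the name is only an example
    generator = HuggingFaceLocalChatGenerator(model="Qwen/Qwen2.5-0.5B-Instruct")
    generator.warm_up()  # run_async raises RuntimeError if the model has not been loaded
    result = await generator.run_async(messages=[ChatMessage.from_user("What is the capital of Germany?")])
    print(result["replies"][0].text)

asyncio.run(main())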

View File: haystack/components/generators/chat/hugging_face_local.py

@@ -2,13 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import re
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import (
@@ -123,6 +125,7 @@ class HuggingFaceLocalChatGenerator:
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
async_executor: Optional[ThreadPoolExecutor] = None,
):
"""
Initializes the HuggingFaceLocalChatGenerator component.
@@ -165,6 +168,9 @@ class HuggingFaceLocalChatGenerator:
:param tool_parsing_function:
A callable that takes a string and returns a list of ToolCall objects or None.
If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
:param async_executor:
Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
initialized and used.
"""
torch_and_transformers_import.check()
@@ -223,6 +229,27 @@ class HuggingFaceLocalChatGenerator:
self.pipeline = None
self.tools = tools
self._owns_executor = async_executor is None
self.executor = (
ThreadPoolExecutor(thread_name_prefix=f"async-HFLocalChatGenerator-executor-{id(self)}", max_workers=1)
if async_executor is None
else async_executor
)
def __del__(self):
"""
Cleanup when the instance is being destroyed.
"""
if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
self.executor.shutdown(wait=True)
def shutdown(self):
"""
Explicitly shutdown the executor if we own it.
"""
if self._owns_executor:
self.executor.shutdown(wait=True)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
@@ -332,7 +359,9 @@ class HuggingFaceLocalChatGenerator:
if stop_words_criteria:
generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
streaming_callback = streaming_callback or self.streaming_callback
streaming_callback = select_streaming_callback(
self.streaming_callback, streaming_callback, requires_async=False
)
if streaming_callback:
num_responses = generation_kwargs.get("num_return_sequences", 1)
if num_responses > 1:
@@ -427,7 +456,8 @@ class HuggingFaceLocalChatGenerator:
# If tool calls are detected, don't include the text content since it contains the raw tool call format
return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)
def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]:
@staticmethod
def _validate_stop_words(stop_words: Optional[List[str]]) -> Optional[List[str]]:
"""
Validates the provided stop words.
@@ -443,3 +473,146 @@ class HuggingFaceLocalChatGenerator:
return None
return list(set(stop_words or []))
@component.output_types(replies=List[ChatMessage])
async def run_async(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Asynchronously invokes text generation inference based on the provided messages and generation parameters.
This is the asynchronous version of the `run` method. It has the same parameters
and return values but can be used with `await` in async code.
:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param streaming_callback: An optional callable for handling streaming responses.
:param tools: A list of tools for which the model can prepare calls.
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage instances.
"""
if self.pipeline is None:
raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
tools = tools or self.tools
if tools and streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)
tokenizer = self.pipeline.tokenizer
# Check and update generation parameters
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
stop_words = self._validate_stop_words(stop_words)
# Set up stop words criteria if stop words exist
stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
if stop_words_criteria:
generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
# validate and select the streaming callback
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
if streaming_callback:
return await self._run_streaming_async(
messages, tokenizer, generation_kwargs, stop_words, streaming_callback
)
return await self._run_non_streaming_async(messages, tokenizer, generation_kwargs, stop_words, tools)
async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
self,
messages: List[ChatMessage],
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
generation_kwargs: Dict[str, Any],
stop_words: Optional[List[str]],
streaming_callback: Callable[[StreamingChunk], None],
):
"""
Handles async streaming generation of responses.
"""
# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]
prepared_prompt = tokenizer.apply_chat_template(
hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
)
# Avoid some unnecessary warnings in the generation pipeline call
generation_kwargs["pad_token_id"] = (
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
)
# Set up streaming handler
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
# Generate responses asynchronously
output = await asyncio.get_running_loop().run_in_executor(
self.executor,
lambda: self.pipeline(prepared_prompt, **generation_kwargs), # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
)
replies = [o.get("generated_text", "") for o in output]
# Remove stop words from replies if present
for stop_word in stop_words or []:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
chat_messages = [
self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False)
for r_index, reply in enumerate(replies)
]
return {"replies": chat_messages}
async def _run_non_streaming_async( # pylint: disable=too-many-positional-arguments
self,
messages: List[ChatMessage],
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
generation_kwargs: Dict[str, Any],
stop_words: Optional[List[str]],
tools: Optional[List[Tool]] = None,
):
"""
Handles async non-streaming generation of responses.
"""
# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]
prepared_prompt = tokenizer.apply_chat_template(
hf_messages,
tokenize=False,
chat_template=self.chat_template,
add_generation_prompt=True,
tools=[tc.tool_spec for tc in tools] if tools else None,
)
# Avoid some unnecessary warnings in the generation pipeline call
generation_kwargs["pad_token_id"] = (
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
)
# Generate responses asynchronously
output = await asyncio.get_running_loop().run_in_executor(
self.executor,
lambda: self.pipeline(prepared_prompt, **generation_kwargs), # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
)
replies = [o.get("generated_text", "") for o in output]
# Remove stop words from replies if present
for stop_word in stop_words or []:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
chat_messages = [
self.create_message(
reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
)
for r_index, reply in enumerate(replies)
]
return {"replies": chat_messages}

View File: releasenotes/notes/adding-async-huggingface-local-chat-generator-962512f52282d12d.yaml

@@ -0,0 +1,5 @@
---
features:
- |
Add `run_async` method to HuggingFaceLocalChatGenerator. Internally, the method runs the blocking model call in a ThreadPoolExecutor so that it can be awaited.
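
A sketch of what the note enables in practice: several generations awaited concurrently, mirroring the new concurrency test below. `generator` is assumed to be a warmed-up HuggingFaceLocalChatGenerator; with the default internally created executor (max_workers=1) the blocking pipeline calls still run one at a time, but the awaits can overlap with other work on the event loop.

import asyncio

async def ask_concurrently(generator, messages):
    # schedule five generations at once and wait for all of them to finish
    tasks = [generator.run_async(messages=messages) for _ in range(5)]
    return await asyncio.gather(*tasks)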

View File

@@ -1,18 +1,21 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock, patch
from typing import Optional, List
from haystack.dataclasses.streaming_chunk import StreamingChunk
import asyncio
import gc
from typing import Optional, List
from unittest.mock import Mock, patch
import pytest
from transformers import PreTrainedTokenizer
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage, ChatRole, ToolCall
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import Tool
from haystack.utils import ComponentDevice
from haystack.utils.auth import Secret
from haystack.tools import Tool
# used to test serialization of streaming_callback
@@ -474,3 +477,76 @@ class TestHuggingFaceLocalChatGenerator:
assert len(results["replies"][0].tool_calls) == 1
assert results["replies"][0].tool_calls[0].tool_name == "weather"
assert results["replies"][0].tool_calls[0].arguments == {"city": "Berlin"}
# Async tests
async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
"""Test basic async functionality"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model")
generator.pipeline = mock_pipeline_tokenizer
results = await generator.run_async(messages=chat_messages)
assert "replies" in results
assert isinstance(results["replies"][0], ChatMessage)
chat_message = results["replies"][0]
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.text == "Berlin is cool"
async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_tokenizer, tools):
"""Test async functionality with tools"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model", tools=tools)
generator.pipeline = mock_pipeline_tokenizer
# Mock the pipeline to return a tool call format
generator.pipeline.return_value = [{"generated_text": '{"name": "weather", "arguments": {"city": "Berlin"}}'}]
messages = [ChatMessage.from_user("What's the weather in Berlin?")]
results = await generator.run_async(messages=messages)
assert len(results["replies"]) == 1
message = results["replies"][0]
assert message.tool_calls
tool_call = message.tool_calls[0]
assert isinstance(tool_call, ToolCall)
assert tool_call.tool_name == "weather"
assert tool_call.arguments == {"city": "Berlin"}
async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
"""Test handling of multiple concurrent async requests"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model")
generator.pipeline = mock_pipeline_tokenizer
# Create multiple concurrent requests
tasks = [generator.run_async(messages=chat_messages) for _ in range(5)]
results = await asyncio.gather(*tasks)
for result in results:
assert "replies" in result
assert isinstance(result["replies"][0], ChatMessage)
assert result["replies"][0].text == "Berlin is cool"
async def test_async_error_handling(self, model_info_mock, mock_pipeline_tokenizer):
"""Test error handling in async context"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model")
# Test without warm_up
with pytest.raises(RuntimeError, match="The generation model has not been loaded"):
await generator.run_async(messages=[ChatMessage.from_user("test")])
# Test with invalid streaming callback
generator.pipeline = mock_pipeline_tokenizer
with pytest.raises(ValueError, match="Using tools and streaming at the same time is not supported"):
await generator.run_async(
messages=[ChatMessage.from_user("test")],
streaming_callback=lambda x: None,
tools=[Tool(name="test", description="test", parameters={}, function=lambda: None)],
)
def test_executor_shutdown(self, model_info_mock, mock_pipeline_tokenizer):
with patch("haystack.components.generators.chat.hugging_face_local.pipeline") as mock_pipeline:
generator = HuggingFaceLocalChatGenerator(model="mocked-model")
executor = generator.executor
with patch.object(executor, "shutdown", wraps=executor.shutdown) as mock_shutdown:
del generator
gc.collect()
mock_shutdown.assert_called_once_with(wait=True)
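
Finally, a short sketch of the cleanup contract that the last test verifies: explicit `shutdown()` is the deterministic path, `__del__` is the garbage-collection fallback, and both only close an executor the component created itself. The model name is a placeholder.

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator

generator = HuggingFaceLocalChatGenerator(model="any-local-chat-model")
# ... await generator.run_async(...) from async code ...
generator.shutdown()  # closes the internally created executor; a no-op if one was passed in
# If shutdown() is never called, __del__ performs the same cleanup once the
# instance is garbage-collected, as test_executor_shutdown checks above.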