Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-28 15:38:36 +00:00)
feat: async support for the HuggingFaceLocalChatGenerator (#8981)
* adding async run method
* passing an optional ThreadPoolExecutor
* adding tests
* adding release notes
* nit: license
* fixing linting
* Update releasenotes/notes/adding-async-huggingface-local-chat-generator-962512f52282d12d.yaml (co-authored by Amna Mubashar <amnahkhan.ak@gmail.com>)
* Use Phi instead (#8982)
* build: drop Python 3.8 support (#8978)
  * draft
  * readd typing_extensions
  * small fix + release note
  * remove ruff target-version
  * Update releasenotes/notes/drop-python-3.8-868710963e794c83.yaml (co-authored by David S. Batista <dsbatista@gmail.com>)
* Update unstable version to 2.12.0-rc0 (#8983) (co-authored by github-actions[bot])
* fix: allow support for `include_usage` in streaming using OpenAIChatGenerator (#8968)
  * fix error in handling usage completion chunk
* ci: improve release notes format checking (#8984)
  * chore: fix invalid release note
  * try improving relnote linting
  * add relnotes path
  * fix bad release note
  * improve reno config
* fix: handle async tests in `HuggingFaceAPIChatGenerator` to prevent error (#8986)
  * add missing asyncio
  * explicitly close connection in the test
* Fix tests (#8990)
* docs: Update docstrings of `BranchJoiner` (#8988)
  * Update docstrings
  * Add a bit more explanatory text
  * Add reno
  * Update haystack/components/joiners/branch.py (suggestions co-authored by Daria Fokina <daria.fokina@deepset.ai>)
  * Fix formatting
* PR comments
* destroying the ThreadPoolExecutor when the generator instance is being destroyed, only if it was not passed externally
* fixing bug in streaming_callback
* PR comments

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
Co-authored-by: Haystack Bot <73523382+HaystackBot@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
parent c4fafd9b04
commit 4c9d08add5
@@ -2,13 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+import asyncio
 import json
 import re
 import sys
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, List, Literal, Optional, Union

 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.utils import (
@@ -123,6 +125,7 @@ class HuggingFaceLocalChatGenerator:
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         tools: Optional[List[Tool]] = None,
         tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
+        async_executor: Optional[ThreadPoolExecutor] = None,
     ):
         """
         Initializes the HuggingFaceLocalChatGenerator component.
@@ -165,6 +168,9 @@
         :param tool_parsing_function:
             A callable that takes a string and returns a list of ToolCall objects or None.
             If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
+        :param async_executor:
+            Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
+            initialized and used
         """
         torch_and_transformers_import.check()

@@ -223,6 +229,27 @@
         self.pipeline = None
         self.tools = tools

+        self._owns_executor = async_executor is None
+        self.executor = (
+            ThreadPoolExecutor(thread_name_prefix=f"async-HFLocalChatGenerator-executor-{id(self)}", max_workers=1)
+            if async_executor is None
+            else async_executor
+        )
+
+    def __del__(self):
+        """
+        Cleanup when the instance is being destroyed.
+        """
+        if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
+            self.executor.shutdown(wait=True)
+
+    def shutdown(self):
+        """
+        Explicitly shutdown the executor if we own it.
+        """
+        if self._owns_executor:
+            self.executor.shutdown(wait=True)
+
     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
         Data that is sent to Posthog for usage analytics.
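The hunk above gives the component a single-worker executor of its own unless one is injected, and it only shuts down an executor it created itself. A minimal usage sketch under that ownership rule (the model id is an illustrative assumption, not part of this diff):

from concurrent.futures import ThreadPoolExecutor

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator

# An application-managed executor: because async_executor is provided,
# _owns_executor is False and neither __del__ nor shutdown() will close it.
shared_executor = ThreadPoolExecutor(max_workers=2)

generator = HuggingFaceLocalChatGenerator(
    model="microsoft/Phi-3.5-mini-instruct",  # assumed model id; any chat-capable HF model applies
    async_executor=shared_executor,
)
generator.warm_up()  # loads the transformers pipeline before run()/run_async()

# ... use generator.run(...) or await generator.run_async(...) ...

# The caller stays responsible for tearing down an executor it passed in.
shared_executor.shutdown(wait=True)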
@@ -332,7 +359,9 @@
         if stop_words_criteria:
             generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

-        streaming_callback = streaming_callback or self.streaming_callback
+        streaming_callback = select_streaming_callback(
+            self.streaming_callback, streaming_callback, requires_async=False
+        )
         if streaming_callback:
             num_responses = generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
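The sync run method now routes the callback through select_streaming_callback with requires_async=False, while run_async (further down) passes requires_async=True, which presumably means the sync entry point expects a plain function and the async one expects a coroutine function. A hedged sketch of the two callback shapes (the print bodies are illustrative):

from haystack.dataclasses import StreamingChunk

def on_chunk(chunk: StreamingChunk) -> None:
    # plain function: the shape run() expects for streaming_callback
    print(chunk.content, end="", flush=True)

async def on_chunk_async(chunk: StreamingChunk) -> None:
    # coroutine function: the shape run_async() can await for each chunk
    print(chunk.content, end="", flush=True)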
@@ -427,7 +456,8 @@
         # If tool calls are detected, don't include the text content since it contains the raw tool call format
         return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)

-    def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]:
+    @staticmethod
+    def _validate_stop_words(stop_words: Optional[List[str]]) -> Optional[List[str]]:
         """
         Validates the provided stop words.

@@ -443,3 +473,146 @@
             return None

         return list(set(stop_words or []))
+
+    @component.output_types(replies=List[ChatMessage])
+    async def run_async(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        tools: Optional[List[Tool]] = None,
+    ):
+        """
+        Asynchronously invokes text generation inference based on the provided messages and generation parameters.
+
+        This is the asynchronous version of the `run` method. It has the same parameters
+        and return values but can be used with `await` in an async code.
+
+        :param messages: A list of ChatMessage objects representing the input messages.
+        :param generation_kwargs: Additional keyword arguments for text generation.
+        :param streaming_callback: An optional callable for handling streaming responses.
+        :param tools: A list of tools for which the model can prepare calls.
+        :returns: A dictionary with the following keys:
+            - `replies`: A list containing the generated responses as ChatMessage instances.
+        """
+        if self.pipeline is None:
+            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
+
+        tools = tools or self.tools
+        if tools and streaming_callback is not None:
+            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+        _check_duplicate_tool_names(tools)
+
+        tokenizer = self.pipeline.tokenizer
+
+        # Check and update generation parameters
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
+        stop_words = self._validate_stop_words(stop_words)
+
+        # Set up stop words criteria if stop words exist
+        stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
+        if stop_words_criteria:
+            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
+
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
+
+        if streaming_callback:
+            return await self._run_streaming_async(
+                messages, tokenizer, generation_kwargs, stop_words, streaming_callback
+            )
+
+        return await self._run_non_streaming_async(messages, tokenizer, generation_kwargs, stop_words, tools)
+
+    async def _run_streaming_async(  # pylint: disable=too-many-positional-arguments
+        self,
+        messages: List[ChatMessage],
+        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
+        generation_kwargs: Dict[str, Any],
+        stop_words: Optional[List[str]],
+        streaming_callback: Callable[[StreamingChunk], None],
+    ):
+        """
+        Handles async streaming generation of responses.
+        """
+        # convert messages to HF format
+        hf_messages = [convert_message_to_hf_format(message) for message in messages]
+        prepared_prompt = tokenizer.apply_chat_template(
+            hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
+        )
+
+        # Avoid some unnecessary warnings in the generation pipeline call
+        generation_kwargs["pad_token_id"] = (
+            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
+        )
+
+        # Set up streaming handler
+        generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
+
+        # Generate responses asynchronously
+        output = await asyncio.get_running_loop().run_in_executor(
+            self.executor,
+            lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
+        )
+
+        replies = [o.get("generated_text", "") for o in output]
+
+        # Remove stop words from replies if present
+        for stop_word in stop_words or []:
+            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
+
+        chat_messages = [
+            self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False)
+            for r_index, reply in enumerate(replies)
+        ]
+
+        return {"replies": chat_messages}
+
+    async def _run_non_streaming_async(  # pylint: disable=too-many-positional-arguments
+        self,
+        messages: List[ChatMessage],
+        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
+        generation_kwargs: Dict[str, Any],
+        stop_words: Optional[List[str]],
+        tools: Optional[List[Tool]] = None,
+    ):
+        """
+        Handles async non-streaming generation of responses.
+        """
+        # convert messages to HF format
+        hf_messages = [convert_message_to_hf_format(message) for message in messages]
+        prepared_prompt = tokenizer.apply_chat_template(
+            hf_messages,
+            tokenize=False,
+            chat_template=self.chat_template,
+            add_generation_prompt=True,
+            tools=[tc.tool_spec for tc in tools] if tools else None,
+        )
+
+        # Avoid some unnecessary warnings in the generation pipeline call
+        generation_kwargs["pad_token_id"] = (
+            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
+        )
+
+        # Generate responses asynchronously
+        output = await asyncio.get_running_loop().run_in_executor(
+            self.executor,
+            lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
+        )
+
+        replies = [o.get("generated_text", "") for o in output]
+
+        # Remove stop words from replies if present
+        for stop_word in stop_words or []:
+            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
+
+        chat_messages = [
+            self.create_message(
+                reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
+            )
+            for r_index, reply in enumerate(replies)
+        ]
+
+        return {"replies": chat_messages}
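Both async helpers hand the blocking transformers pipeline call to the instance's executor via run_in_executor, so the event loop stays free while generation runs. A stripped-down sketch of that pattern outside Haystack (the blocking function is a stand-in, not the real pipeline):

import asyncio
from concurrent.futures import ThreadPoolExecutor

def blocking_generate(prompt: str) -> str:
    # stand-in for the synchronous, compute-bound pipeline call
    return prompt.upper()

async def main() -> None:
    executor = ThreadPoolExecutor(max_workers=1)
    loop = asyncio.get_running_loop()
    # run the blocking call on a worker thread and await its result
    result = await loop.run_in_executor(executor, lambda: blocking_generate("hello"))
    print(result)
    executor.shutdown(wait=True)

asyncio.run(main())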
@@ -0,0 +1,5 @@
+---
+features:
+  - |
+    Add `run_async` method to HuggingFaceLocalChatGenerator. This method internally uses ThreadPoolExecutor to return coroutines
+    that can be awaited.
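A minimal sketch of what the release note describes, awaiting run_async from application code (the model id is an illustrative assumption):

import asyncio

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage

async def main() -> None:
    generator = HuggingFaceLocalChatGenerator(model="microsoft/Phi-3.5-mini-instruct")  # assumed model id
    generator.warm_up()
    result = await generator.run_async(messages=[ChatMessage.from_user("What is the capital of Germany?")])
    print(result["replies"][0].text)

asyncio.run(main())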
@@ -1,18 +1,21 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
 #
 # SPDX-License-Identifier: Apache-2.0
-from unittest.mock import Mock, patch
-from typing import Optional, List
-
-from haystack.dataclasses.streaming_chunk import StreamingChunk
+import asyncio
+import gc
+from typing import Optional, List
+from unittest.mock import Mock, patch
+
 import pytest
 from transformers import PreTrainedTokenizer

 from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
 from haystack.dataclasses import ChatMessage, ChatRole, ToolCall
+from haystack.dataclasses.streaming_chunk import StreamingChunk
+from haystack.tools import Tool
 from haystack.utils import ComponentDevice
 from haystack.utils.auth import Secret
-from haystack.tools import Tool


 # used to test serialization of streaming_callback
@@ -474,3 +477,76 @@ class TestHuggingFaceLocalChatGenerator:
         assert len(results["replies"][0].tool_calls) == 1
         assert results["replies"][0].tool_calls[0].tool_name == "weather"
         assert results["replies"][0].tool_calls[0].arguments == {"city": "Berlin"}
+
+    # Async tests
+
+    async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
+        """Test basic async functionality"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+        generator.pipeline = mock_pipeline_tokenizer
+
+        results = await generator.run_async(messages=chat_messages)
+
+        assert "replies" in results
+        assert isinstance(results["replies"][0], ChatMessage)
+        chat_message = results["replies"][0]
+        assert chat_message.is_from(ChatRole.ASSISTANT)
+        assert chat_message.text == "Berlin is cool"
+
+    async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_tokenizer, tools):
+        """Test async functionality with tools"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model", tools=tools)
+        generator.pipeline = mock_pipeline_tokenizer
+        # Mock the pipeline to return a tool call format
+        generator.pipeline.return_value = [{"generated_text": '{"name": "weather", "arguments": {"city": "Berlin"}}'}]
+
+        messages = [ChatMessage.from_user("What's the weather in Berlin?")]
+        results = await generator.run_async(messages=messages)
+
+        assert len(results["replies"]) == 1
+        message = results["replies"][0]
+        assert message.tool_calls
+        tool_call = message.tool_calls[0]
+        assert isinstance(tool_call, ToolCall)
+        assert tool_call.tool_name == "weather"
+        assert tool_call.arguments == {"city": "Berlin"}
+
+    async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
+        """Test handling of multiple concurrent async requests"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+        generator.pipeline = mock_pipeline_tokenizer
+
+        # Create multiple concurrent requests
+        tasks = [generator.run_async(messages=chat_messages) for _ in range(5)]
+        results = await asyncio.gather(*tasks)
+
+        for result in results:
+            assert "replies" in result
+            assert isinstance(result["replies"][0], ChatMessage)
+            assert result["replies"][0].text == "Berlin is cool"
+
+    async def test_async_error_handling(self, model_info_mock, mock_pipeline_tokenizer):
+        """Test error handling in async context"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+
+        # Test without warm_up
+        with pytest.raises(RuntimeError, match="The generation model has not been loaded"):
+            await generator.run_async(messages=[ChatMessage.from_user("test")])
+
+        # Test with invalid streaming callback
+        generator.pipeline = mock_pipeline_tokenizer
+        with pytest.raises(ValueError, match="Using tools and streaming at the same time is not supported"):
+            await generator.run_async(
+                messages=[ChatMessage.from_user("test")],
+                streaming_callback=lambda x: None,
+                tools=[Tool(name="test", description="test", parameters={}, function=lambda: None)],
+            )
+
+    def test_executor_shutdown(self, model_info_mock, mock_pipeline_tokenizer):
+        with patch("haystack.components.generators.chat.hugging_face_local.pipeline") as mock_pipeline:
+            generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+            executor = generator.executor
+            with patch.object(executor, "shutdown", wraps=executor.shutdown) as mock_shutdown:
+                del generator
+                gc.collect()
+                mock_shutdown.assert_called_once_with(wait=True)
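test_executor_shutdown exercises the ownership rule from __init__ and __del__: an executor the component created is shut down when the instance is garbage collected, while an injected one is left alone. A toy, self-contained sketch of that pattern (the class and names are illustrative, not Haystack APIs):

import gc
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

class OwnsExecutor:
    """Toy stand-in for the ownership pattern the test verifies."""

    def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
        self._owns_executor = executor is None
        self.executor = executor if executor is not None else ThreadPoolExecutor(max_workers=1)

    def __del__(self):
        # only tear down an executor this instance created itself
        if getattr(self, "_owns_executor", False):
            self.executor.shutdown(wait=True)

obj = OwnsExecutor()
del obj
gc.collect()  # mirrors the test: force collection so __del__ runs deterministically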