refactor: Refactor HuggingFaceLocalChatGenerator (#9455)

* Refactoring to better align run and run_async and reduce duplicate code

* Docstrings and align run and run_async

* More changes

* Add missing type

* Refactor async part a bit

* Fix import error

* Fix mypy
Sebastian Husch Lee, 2025-06-13 15:38:00 +02:00, committed by GitHub
parent 379df4ab84 · commit c5027d711c

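At a glance, the refactor collapses the duplicated prompt-preparation and output-conversion logic into two shared helpers, so run and run_async now differ only in how the pipeline call is executed. A rough sketch of the resulting shape (method and key names are taken from this diff; everything else is elided):

class HuggingFaceLocalChatGenerator:  # structure only; the real implementation is in the diff below
    def run(self, messages, generation_kwargs=None, streaming_callback=None, tools=None):
        prepared = self._prepare_inputs(
            messages=messages,
            generation_kwargs=generation_kwargs,
            streaming_callback=streaming_callback,
            tools=tools,
        )
        # Synchronous path: call the pipeline directly.
        output = self.pipeline(prepared["prepared_prompt"], **prepared["generation_kwargs"])
        replies = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared)
        return {"replies": replies}

    async def run_async(self, messages, generation_kwargs=None, streaming_callback=None, tools=None):
        prepared = self._prepare_inputs(
            messages=messages,
            generation_kwargs=generation_kwargs,
            streaming_callback=streaming_callback,
            tools=tools,
        )
        # Asynchronous path: the same call, offloaded to the executor.
        output = await asyncio.get_running_loop().run_in_executor(
            self.executor,
            lambda: self.pipeline(prepared["prepared_prompt"], **prepared["generation_kwargs"]),
        )
        replies = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared)
        return {"replies": replies}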

@@ -7,7 +7,7 @@ import json
import re
import sys
from concurrent.futures import ThreadPoolExecutor
from contextlib import suppress
from contextlib import asynccontextmanager, suppress
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
from haystack import component, default_from_dict, default_to_dict, logging
@@ -138,7 +138,7 @@ class HuggingFaceLocalChatGenerator:
tools: Optional[Union[List[Tool], Toolset]] = None,
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
async_executor: Optional[ThreadPoolExecutor] = None,
):
) -> None:
"""
Initializes the HuggingFaceLocalChatGenerator component.
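The constructor hunk above only changes the return annotation, but for orientation, here is a minimal usage sketch of the component (the model name and generation kwargs are illustrative assumptions, not taken from this diff):

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage

generator = HuggingFaceLocalChatGenerator(
    model="Qwen/Qwen2.5-0.5B-Instruct",         # illustrative; any local chat model
    generation_kwargs={"max_new_tokens": 128},
)
generator.warm_up()  # loads the pipeline; run() raises RuntimeError without it
result = generator.run(messages=[ChatMessage.from_user("Hello!")])
print(result["replies"][0].text)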
@@ -249,14 +249,14 @@ class HuggingFaceLocalChatGenerator:
else async_executor
)
def __del__(self):
def __del__(self) -> None:
"""
Clean up when the instance is being destroyed.
"""
if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
if self._owns_executor:
self.executor.shutdown(wait=True)
def shutdown(self):
def shutdown(self) -> None:
"""
Explicitly shut down the executor if we own it.
"""
@@ -271,7 +271,7 @@ class HuggingFaceLocalChatGenerator:
return {"model": self.huggingface_pipeline_kwargs["model"]}
return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
def warm_up(self):
def warm_up(self) -> None:
"""
Initializes the component.
"""
@@ -336,105 +336,41 @@ class HuggingFaceLocalChatGenerator:
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
):
) -> Dict[str, List[ChatMessage]]:
"""
Invoke text generation inference based on the provided messages and generation parameters.
:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param streaming_callback: An optional callable for handling streaming responses.
:param tools:
A list of tools or a Toolset for which the model can prepare calls. If set, it will override
:param tools: A list of tools or a Toolset for which the model can prepare calls. If set, it will override
the `tools` parameter provided during initialization. This parameter can accept either a list
of `Tool` objects or a `Toolset` instance.
:returns:
A list containing the generated responses as ChatMessage instances.
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage instances.
"""
if self.pipeline is None:
raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
tools = tools or self.tools
if tools and streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(list(tools or []))
tokenizer = self.pipeline.tokenizer
# initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
assert tokenizer is not None
# Check and update generation parameters
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
# pipeline call doesn't support stop_sequences, so we need to pop it
stop_words = self._validate_stop_words(stop_words)
# Set up stop words criteria if stop words exist
stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
if stop_words_criteria:
generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
prepared_inputs = self._prepare_inputs(
messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
)
streaming_callback = select_streaming_callback(
self.streaming_callback, streaming_callback, requires_async=False
)
if streaming_callback:
num_responses = generation_kwargs.get("num_return_sequences", 1)
if num_responses > 1:
msg = (
"Streaming is enabled, but the number of responses is set to {num_responses}. "
"Streaming is only supported for single response generation. "
"Setting the number of responses to 1."
)
logger.warning(msg, num_responses=num_responses)
generation_kwargs["num_return_sequences"] = 1
# Get component name and type
component_info = ComponentInfo.from_component(self)
# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
generation_kwargs["streamer"] = HFTokenStreamingHandler(
tokenizer=tokenizer,
prepared_inputs["generation_kwargs"]["streamer"] = HFTokenStreamingHandler(
tokenizer=prepared_inputs["tokenizer"],
stream_handler=streaming_callback,
stop_words=stop_words,
component_info=component_info,
stop_words=prepared_inputs["stop_words"],
component_info=ComponentInfo.from_component(self),
)
# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]
if isinstance(tools, Toolset):
tools = list(tools)
prepared_prompt = tokenizer.apply_chat_template(
hf_messages,
tokenize=False,
chat_template=self.chat_template,
add_generation_prompt=True,
tools=[tc.tool_spec for tc in tools] if tools else None,
)
# prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
assert isinstance(prepared_prompt, str)
# Avoid some unnecessary warnings in the generation pipeline call
generation_kwargs["pad_token_id"] = (
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
)
# We know it's not None because we check it in _prepare_inputs
assert self.pipeline is not None
# Generate responses
output = self.pipeline(prepared_prompt, **generation_kwargs)
replies = [o.get("generated_text", "") for o in output]
output = self.pipeline(prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"])
# Remove stop words from replies if present
if stop_words:
for stop_word in stop_words:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
chat_messages = [
self.create_message(
reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
)
for r_index, reply in enumerate(replies)
]
chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)
return {"replies": chat_messages}
@@ -512,7 +448,7 @@ class HuggingFaceLocalChatGenerator:
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
):
) -> Dict[str, List[ChatMessage]]:
"""
Asynchronously invokes text generation inference based on the provided messages and generation parameters.
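Since run_async now returns the same {"replies": [...]} dictionary as run, a caller only needs an event loop to drive it. A minimal sketch (model name assumed):

import asyncio

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage

async def main() -> None:
    generator = HuggingFaceLocalChatGenerator(model="Qwen/Qwen2.5-0.5B-Instruct")
    generator.warm_up()
    result = await generator.run_async(messages=[ChatMessage.from_user("Hello!")])
    print(result["replies"][0].text)

asyncio.run(main())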
@@ -527,6 +463,77 @@ class HuggingFaceLocalChatGenerator:
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage instances.
"""
prepared_inputs = self._prepare_inputs(
messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
)
# Validate and select the streaming callback
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
if streaming_callback:
async_handler = AsyncHFTokenStreamingHandler(
tokenizer=prepared_inputs["tokenizer"],
# Cast to AsyncStreamingCallbackT since we know streaming_callback is async
stream_handler=cast(AsyncStreamingCallbackT, streaming_callback),
stop_words=prepared_inputs["stop_words"],
component_info=ComponentInfo.from_component(self),
)
prepared_inputs["generation_kwargs"]["streamer"] = async_handler
# Use async context manager for proper resource cleanup
async with self._manage_queue_processor(async_handler):
output = await asyncio.get_running_loop().run_in_executor(
self.executor,
# have to ignore since assert self.pipeline is not None doesn't work
lambda: self.pipeline( # type: ignore[misc]
prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
),
)
else:
output = await asyncio.get_running_loop().run_in_executor(
self.executor,
# have to ignore since assert self.pipeline is not None doesn't work
lambda: self.pipeline( # type: ignore[misc]
prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
),
)
chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)
return {"replies": chat_messages}
@asynccontextmanager
async def _manage_queue_processor(self, async_handler: "AsyncHFTokenStreamingHandler"):
"""Context manager for proper queue processor lifecycle management."""
queue_processor = asyncio.create_task(async_handler.process_queue())
try:
yield queue_processor
finally:
# Ensure the queue processor is cleaned up properly
try:
await asyncio.wait_for(queue_processor, timeout=0.1)
except asyncio.TimeoutError:
queue_processor.cancel()
with suppress(asyncio.CancelledError):
await queue_processor
def _prepare_inputs(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
) -> Dict[str, Any]:
"""
Prepares the inputs for the Hugging Face pipeline.
:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param streaming_callback: An optional callable for handling streaming responses.
:param tools: A list of tools or a Toolset for which the model can prepare calls.
:returns: A dictionary containing the prepared prompt, tokenizer, generation kwargs, and tools.
:raises RuntimeError: If the generation model has not been loaded.
:raises ValueError: If both tools and streaming_callback are provided.
"""
if self.pipeline is None:
raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
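_manage_queue_processor replaces the try/finally teardown that was previously inlined in _run_streaming_async. The pattern generalizes: start a background task, hand it to the caller, then on exit wait briefly for it to drain and cancel it if it has not finished. A self-contained sketch of the same lifecycle (standard library only; worker and the 0.1 s grace period are illustrative):

import asyncio
from contextlib import asynccontextmanager, suppress

@asynccontextmanager
async def managed_task(coro, grace: float = 0.1):
    task = asyncio.create_task(coro)
    try:
        yield task
    finally:
        try:
            # Give the task a short grace period to finish on its own.
            await asyncio.wait_for(task, timeout=grace)
        except asyncio.TimeoutError:
            task.cancel()
            with suppress(asyncio.CancelledError):
                await task

async def demo() -> None:
    async def worker():
        while True:  # stands in for AsyncHFTokenStreamingHandler.process_queue()
            await asyncio.sleep(0.01)

    async with managed_task(worker()):
        await asyncio.sleep(0.05)  # the "real" work happens while the task runs

asyncio.run(demo())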
@@ -535,6 +542,9 @@ class HuggingFaceLocalChatGenerator:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(list(tools or []))
if isinstance(tools, Toolset):
tools = list(tools)
tokenizer = self.pipeline.tokenizer
# initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
assert tokenizer is not None
@@ -542,6 +552,18 @@ class HuggingFaceLocalChatGenerator:
# Check and update generation parameters
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
# If streaming_callback is provided, ensure that num_return_sequences is set to 1
if streaming_callback:
num_responses = generation_kwargs.get("num_return_sequences", 1)
if num_responses > 1:
msg = (
"Streaming is enabled, but the number of responses is set to {num_responses}. "
"Streaming is only supported for single response generation. "
"Setting the number of responses to 1."
)
logger.warning(msg, num_responses=num_responses)
generation_kwargs["num_return_sequences"] = 1
stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
stop_words = self._validate_stop_words(stop_words)
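stop_words and stop_sequences are merged here because, as the comment removed from run notes, the pipeline call itself does not accept stop_sequences; they are enforced through transformers' stopping-criteria hook instead. StopWordsCriteria is Haystack's implementation; the sketch below only illustrates the protocol it plugs into (the criterion itself is a simplified stand-in, not the real one):

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class LastTokenCriteria(StoppingCriteria):
    """Toy criterion: stop once every sequence in the batch ends with a given token id."""

    def __init__(self, stop_token_id: int):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Returning True halts generation for the whole batch.
        return bool((input_ids[:, -1] == self.stop_token_id).all())

criteria = StoppingCriteriaList([LastTokenCriteria(stop_token_id=2)])  # 2 is often </s>; illustrative
# wired up the same way as in the diff:
# generation_kwargs["stopping_criteria"] = criteria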
@@ -550,99 +572,8 @@ class HuggingFaceLocalChatGenerator:
if stop_words_criteria:
generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
# validate and select the streaming callback
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
if streaming_callback:
return await self._run_streaming_async(
messages, tokenizer, generation_kwargs, stop_words, streaming_callback
)
return await self._run_non_streaming_async(messages, tokenizer, generation_kwargs, stop_words, tools)
async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
self,
messages: List[ChatMessage],
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
generation_kwargs: Dict[str, Any],
stop_words: Optional[List[str]],
streaming_callback: AsyncStreamingCallbackT,
):
"""
Handles async streaming generation of responses.
"""
# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]
prepared_prompt = tokenizer.apply_chat_template(
hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
)
# prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
assert isinstance(prepared_prompt, str)
# Avoid some unnecessary warnings in the generation pipeline call
generation_kwargs["pad_token_id"] = (
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
)
# get the component name and type
component_info = ComponentInfo.from_component(self)
async_handler = AsyncHFTokenStreamingHandler(
tokenizer=tokenizer, stream_handler=streaming_callback, stop_words=stop_words, component_info=component_info
)
generation_kwargs["streamer"] = async_handler
# Start queue processing in the background
queue_processor = asyncio.create_task(async_handler.process_queue())
try:
# Generate responses asynchronously
output = await asyncio.get_running_loop().run_in_executor(
self.executor,
lambda: self.pipeline(prepared_prompt, **generation_kwargs), # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
)
replies = [o.get("generated_text", "") for o in output]
# Remove stop words from replies if present
if stop_words:
for stop_word in stop_words:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
chat_messages = [
self.create_message(
reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False
)
for r_index, reply in enumerate(replies)
]
return {"replies": chat_messages}
finally:
try:
await asyncio.wait_for(queue_processor, timeout=0.1)
except asyncio.TimeoutError:
queue_processor.cancel()
with suppress(asyncio.CancelledError):
await queue_processor
async def _run_non_streaming_async( # pylint: disable=too-many-positional-arguments
self,
messages: List[ChatMessage],
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
generation_kwargs: Dict[str, Any],
stop_words: Optional[List[str]],
tools: Optional[Union[List[Tool], Toolset]] = None,
):
"""
Handles async non-streaming generation of responses.
"""
# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]
if isinstance(tools, Toolset):
tools = list(tools)
prepared_prompt = tokenizer.apply_chat_template(
hf_messages,
@@ -651,22 +582,44 @@ class HuggingFaceLocalChatGenerator:
add_generation_prompt=True,
tools=[tc.tool_spec for tc in tools] if tools else None,
)
# prepared_prompt is a string, but transformers has some type issues
prepared_prompt = cast(str, prepared_prompt)
# prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
assert isinstance(prepared_prompt, str)
# Avoid some unnecessary warnings in the generation pipeline call
generation_kwargs["pad_token_id"] = (
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
)
# Generate responses asynchronously
output = await asyncio.get_running_loop().run_in_executor(
self.executor,
lambda: self.pipeline(prepared_prompt, **generation_kwargs), # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
)
return {
"prepared_prompt": prepared_prompt,
"tokenizer": tokenizer,
"generation_kwargs": generation_kwargs,
"tools": tools,
"stop_words": stop_words,
}
replies = [o.get("generated_text", "") for o in output]
def _convert_hf_output_to_chat_messages(
self,
*,
hf_pipeline_output: List[Dict[str, Any]],
prepared_prompt: str,
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
generation_kwargs: Dict[str, Any],
stop_words: Optional[List[str]],
tools: Optional[Union[List[Tool], Toolset]] = None,
) -> List[ChatMessage]:
"""
Converts the HuggingFace pipeline output into a list of ChatMessage objects.
:param hf_pipeline_output: The output from the HuggingFace pipeline.
:param prepared_prompt: The prompt used for generation.
:param tokenizer: The tokenizer used for generation.
:param generation_kwargs: The generation parameters.
:param stop_words: A list of stop words to remove from the replies.
:param tools: A list of tools or a Toolset for which the model can prepare calls.
This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
"""
replies = [o.get("generated_text", "") for o in hf_pipeline_output]
# Remove stop words from replies if present
if stop_words:
@@ -675,9 +628,13 @@ class HuggingFaceLocalChatGenerator:
chat_messages = [
self.create_message(
reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
text=reply,
index=r_index,
tokenizer=tokenizer,
prompt=prepared_prompt,
generation_kwargs=generation_kwargs,
parse_tool_calls=bool(tools),
)
for r_index, reply in enumerate(replies)
]
return {"replies": chat_messages}
return chat_messages
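Because _prepare_inputs returns a plain dict that is later splatted into _convert_hf_output_to_chat_messages as hf_pipeline_output=output, **prepared_inputs, the two helpers share an implicit keyword contract. One way to make that contract checkable (a suggestion, not part of this commit) would be a TypedDict:

from typing import Any, Dict, List, Optional, TypedDict

class PreparedInputs(TypedDict):
    prepared_prompt: str
    tokenizer: Any  # PreTrainedTokenizer | PreTrainedTokenizerFast in the real signatures
    generation_kwargs: Dict[str, Any]
    tools: Optional[List[Any]]  # List[Tool] after the Toolset has been expanded
    stop_words: Optional[List[str]]

The keyword-only parameters of _convert_hf_output_to_chat_messages line up with these keys one-to-one, which is what makes the **prepared_inputs splat safe in both run and run_async.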