Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-06-26 22:00:13 +00:00)

refactor: Refactor HuggingFaceLocalChatGenerator (#9455)

* Refactoring to better align run and run_async and reduce duplicate code
* Docstrings and align run and run_async
* More changes
* add missing type
* Refactor async part a bit
* Fix import error
* Fix mypy

parent 379df4ab84
commit c5027d711c
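For orientation before the diff: a minimal usage sketch of the component this commit refactors. The model name and prompt are illustrative assumptions, not taken from this commit; the warm_up() requirement and the "replies" output key match the code shown in the diff below.

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage

# Model name and prompt are placeholders, not part of this commit.
generator = HuggingFaceLocalChatGenerator(model="HuggingFaceTB/SmolLM2-1.7B-Instruct")
generator.warm_up()  # loads the pipeline; run() raises a RuntimeError without it (see diff)

result = generator.run(messages=[ChatMessage.from_user("Summarize this commit in one sentence.")])
print(result["replies"][0].text)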
@@ -7,7 +7,7 @@ import json
import re
import sys
from concurrent.futures import ThreadPoolExecutor
from contextlib import suppress
from contextlib import asynccontextmanager, suppress
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast

from haystack import component, default_from_dict, default_to_dict, logging
@@ -138,7 +138,7 @@ class HuggingFaceLocalChatGenerator:
tools: Optional[Union[List[Tool], Toolset]] = None,
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
async_executor: Optional[ThreadPoolExecutor] = None,
):
) -> None:
"""
Initializes the HuggingFaceLocalChatGenerator component.

@@ -249,14 +249,14 @@ class HuggingFaceLocalChatGenerator:
else async_executor
)

def __del__(self):
def __del__(self) -> None:
"""
Cleanup when the instance is being destroyed.
"""
if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
if self._owns_executor:
self.executor.shutdown(wait=True)

def shutdown(self):
def shutdown(self) -> None:
"""
Explicitly shutdown the executor if we own it.
"""
@@ -271,7 +271,7 @@ class HuggingFaceLocalChatGenerator:
return {"model": self.huggingface_pipeline_kwargs["model"]}
return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

def warm_up(self):
def warm_up(self) -> None:
"""
Initializes the component.
"""
@@ -336,105 +336,41 @@ class HuggingFaceLocalChatGenerator:
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
):
) -> Dict[str, List[ChatMessage]]:
"""
Invoke text generation inference based on the provided messages and generation parameters.

:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param streaming_callback: An optional callable for handling streaming responses.
:param tools:
A list of tools or a Toolset for which the model can prepare calls. If set, it will override
:param tools: A list of tools or a Toolset for which the model can prepare calls. If set, it will override
the `tools` parameter provided during initialization. This parameter can accept either a list
of `Tool` objects or a `Toolset` instance.
:returns:
A list containing the generated responses as ChatMessage instances.
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage instances.
"""
if self.pipeline is None:
raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

tools = tools or self.tools
if tools and streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(list(tools or []))

tokenizer = self.pipeline.tokenizer
# initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
assert tokenizer is not None

# Check and update generation parameters
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
# pipeline call doesn't support stop_sequences, so we need to pop it
stop_words = self._validate_stop_words(stop_words)

# Set up stop words criteria if stop words exist
stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
if stop_words_criteria:
generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
prepared_inputs = self._prepare_inputs(
messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
)

streaming_callback = select_streaming_callback(
self.streaming_callback, streaming_callback, requires_async=False
)
if streaming_callback:
num_responses = generation_kwargs.get("num_return_sequences", 1)
if num_responses > 1:
msg = (
"Streaming is enabled, but the number of responses is set to {num_responses}. "
"Streaming is only supported for single response generation. "
"Setting the number of responses to 1."
)
logger.warning(msg, num_responses=num_responses)
generation_kwargs["num_return_sequences"] = 1

# Get component name and type
component_info = ComponentInfo.from_component(self)
# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
generation_kwargs["streamer"] = HFTokenStreamingHandler(
tokenizer=tokenizer,
prepared_inputs["generation_kwargs"]["streamer"] = HFTokenStreamingHandler(
tokenizer=prepared_inputs["tokenizer"],
stream_handler=streaming_callback,
stop_words=stop_words,
component_info=component_info,
stop_words=prepared_inputs["stop_words"],
component_info=ComponentInfo.from_component(self),
)

# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]

if isinstance(tools, Toolset):
tools = list(tools)

prepared_prompt = tokenizer.apply_chat_template(
hf_messages,
tokenize=False,
chat_template=self.chat_template,
add_generation_prompt=True,
tools=[tc.tool_spec for tc in tools] if tools else None,
)

# prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
assert isinstance(prepared_prompt, str)

# Avoid some unnecessary warnings in the generation pipeline call
generation_kwargs["pad_token_id"] = (
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
)

# We know it's not None because we check it in _prepare_inputs
assert self.pipeline is not None
# Generate responses
output = self.pipeline(prepared_prompt, **generation_kwargs)
replies = [o.get("generated_text", "") for o in output]
output = self.pipeline(prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"])

# Remove stop words from replies if present
if stop_words:
for stop_word in stop_words:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

chat_messages = [
self.create_message(
reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
)
for r_index, reply in enumerate(replies)
]
chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)

return {"replies": chat_messages}

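Aside: the streaming path of the new run() above selects a synchronous callback via select_streaming_callback. A hedged example of calling run() with such a callback, reusing the generator from the earlier sketch; the callback body and prompt are illustrative, not part of this commit.

from haystack.dataclasses import ChatMessage, StreamingChunk

def on_chunk(chunk: StreamingChunk) -> None:
    # Print each streamed token as it arrives.
    print(chunk.content, end="", flush=True)

result = generator.run(
    messages=[ChatMessage.from_user("Write one sentence about local LLMs.")],
    streaming_callback=on_chunk,
)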
@@ -512,7 +448,7 @@ class HuggingFaceLocalChatGenerator:
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
):
) -> Dict[str, List[ChatMessage]]:
"""
Asynchronously invokes text generation inference based on the provided messages and generation parameters.

@@ -527,6 +463,77 @@ class HuggingFaceLocalChatGenerator:
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage instances.
"""
prepared_inputs = self._prepare_inputs(
messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
)

# Validate and select the streaming callback
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

if streaming_callback:
async_handler = AsyncHFTokenStreamingHandler(
tokenizer=prepared_inputs["tokenizer"],
# Cast to AsyncStreamingCallbackT since we know streaming_callback is async
stream_handler=cast(AsyncStreamingCallbackT, streaming_callback),
stop_words=prepared_inputs["stop_words"],
component_info=ComponentInfo.from_component(self),
)
prepared_inputs["generation_kwargs"]["streamer"] = async_handler

# Use async context manager for proper resource cleanup
async with self._manage_queue_processor(async_handler):
output = await asyncio.get_running_loop().run_in_executor(
self.executor,
# have to ignore since assert self.pipeline is not None doesn't work
lambda: self.pipeline( # type: ignore[misc]
prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
),
)
else:
output = await asyncio.get_running_loop().run_in_executor(
self.executor,
# have to ignore since assert self.pipeline is not None doesn't work
lambda: self.pipeline( # type: ignore[misc]
prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
),
)

chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)
return {"replies": chat_messages}

@asynccontextmanager
async def _manage_queue_processor(self, async_handler: "AsyncHFTokenStreamingHandler"):
"""Context manager for proper queue processor lifecycle management."""
queue_processor = asyncio.create_task(async_handler.process_queue())
try:
yield queue_processor
finally:
# Ensure the queue processor is cleaned up properly
try:
await asyncio.wait_for(queue_processor, timeout=0.1)
except asyncio.TimeoutError:
queue_processor.cancel()
with suppress(asyncio.CancelledError):
await queue_processor

def _prepare_inputs(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
) -> Dict[str, Any]:
"""
Prepares the inputs for the Hugging Face pipeline.

:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param streaming_callback: An optional callable for handling streaming responses.
:param tools: A list of tools or a Toolset for which the model can prepare calls.
:returns: A dictionary containing the prepared prompt, tokenizer, generation kwargs, and tools.
:raises RuntimeError: If the generation model has not been loaded.
:raises ValueError: If both tools and streaming_callback are provided.
"""
if self.pipeline is None:
raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

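The _manage_queue_processor helper introduced above is a standard asyncio cleanup pattern: start a background task, yield it, then give it a short grace period before cancelling. A self-contained sketch of the same idea, with purely illustrative names rather than Haystack code:

import asyncio
from contextlib import asynccontextmanager, suppress

@asynccontextmanager
async def managed_task(coro):
    # Run a background task for the duration of the block, then clean it up.
    task = asyncio.create_task(coro)
    try:
        yield task
    finally:
        try:
            # Give the task a short grace period to drain its remaining work.
            await asyncio.wait_for(task, timeout=0.1)
        except asyncio.TimeoutError:
            # Otherwise cancel it and swallow the expected CancelledError.
            task.cancel()
            with suppress(asyncio.CancelledError):
                await task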
@@ -535,6 +542,9 @@ class HuggingFaceLocalChatGenerator:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(list(tools or []))

if isinstance(tools, Toolset):
tools = list(tools)

tokenizer = self.pipeline.tokenizer
# initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
assert tokenizer is not None
@@ -542,6 +552,18 @@ class HuggingFaceLocalChatGenerator:
# Check and update generation parameters
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

# If streaming_callback is provided, ensure that num_return_sequences is set to 1
if streaming_callback:
num_responses = generation_kwargs.get("num_return_sequences", 1)
if num_responses > 1:
msg = (
"Streaming is enabled, but the number of responses is set to {num_responses}. "
"Streaming is only supported for single response generation. "
"Setting the number of responses to 1."
)
logger.warning(msg, num_responses=num_responses)
generation_kwargs["num_return_sequences"] = 1

stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
stop_words = self._validate_stop_words(stop_words)

@@ -550,99 +572,8 @@ class HuggingFaceLocalChatGenerator:
if stop_words_criteria:
generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

# validate and select the streaming callback
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

if streaming_callback:
return await self._run_streaming_async(
messages, tokenizer, generation_kwargs, stop_words, streaming_callback
)

return await self._run_non_streaming_async(messages, tokenizer, generation_kwargs, stop_words, tools)

async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
self,
messages: List[ChatMessage],
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
generation_kwargs: Dict[str, Any],
stop_words: Optional[List[str]],
streaming_callback: AsyncStreamingCallbackT,
):
"""
Handles async streaming generation of responses.
"""
# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]
prepared_prompt = tokenizer.apply_chat_template(
hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
)

# prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
assert isinstance(prepared_prompt, str)

# Avoid some unnecessary warnings in the generation pipeline call
generation_kwargs["pad_token_id"] = (
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
)

# get the component name and type
component_info = ComponentInfo.from_component(self)

async_handler = AsyncHFTokenStreamingHandler(
tokenizer=tokenizer, stream_handler=streaming_callback, stop_words=stop_words, component_info=component_info
)
generation_kwargs["streamer"] = async_handler

# Start queue processing in the background
queue_processor = asyncio.create_task(async_handler.process_queue())

try:
# Generate responses asynchronously
output = await asyncio.get_running_loop().run_in_executor(
self.executor,
lambda: self.pipeline(prepared_prompt, **generation_kwargs), # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
)

replies = [o.get("generated_text", "") for o in output]

# Remove stop words from replies if present
if stop_words:
for stop_word in stop_words:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

chat_messages = [
self.create_message(
reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False
)
for r_index, reply in enumerate(replies)
]

return {"replies": chat_messages}

finally:
try:
await asyncio.wait_for(queue_processor, timeout=0.1)
except asyncio.TimeoutError:
queue_processor.cancel()
with suppress(asyncio.CancelledError):
await queue_processor

async def _run_non_streaming_async( # pylint: disable=too-many-positional-arguments
self,
messages: List[ChatMessage],
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
generation_kwargs: Dict[str, Any],
stop_words: Optional[List[str]],
tools: Optional[Union[List[Tool], Toolset]] = None,
):
"""
Handles async non-streaming generation of responses.
"""
# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]

if isinstance(tools, Toolset):
tools = list(tools)

prepared_prompt = tokenizer.apply_chat_template(
hf_messages,
@@ -651,22 +582,44 @@ class HuggingFaceLocalChatGenerator:
add_generation_prompt=True,
tools=[tc.tool_spec for tc in tools] if tools else None,
)

# prepared_prompt is a string, but transformers has some type issues
prepared_prompt = cast(str, prepared_prompt)
# prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
assert isinstance(prepared_prompt, str)

# Avoid some unnecessary warnings in the generation pipeline call
generation_kwargs["pad_token_id"] = (
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
)

# Generate responses asynchronously
output = await asyncio.get_running_loop().run_in_executor(
self.executor,
lambda: self.pipeline(prepared_prompt, **generation_kwargs), # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
)
return {
"prepared_prompt": prepared_prompt,
"tokenizer": tokenizer,
"generation_kwargs": generation_kwargs,
"tools": tools,
"stop_words": stop_words,
}

replies = [o.get("generated_text", "") for o in output]
def _convert_hf_output_to_chat_messages(
self,
*,
hf_pipeline_output: List[Dict[str, Any]],
prepared_prompt: str,
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
generation_kwargs: Dict[str, Any],
stop_words: Optional[List[str]],
tools: Optional[Union[List[Tool], Toolset]] = None,
) -> List[ChatMessage]:
"""
Converts the HuggingFace pipeline output into a List of ChatMessages

:param hf_pipeline_output: The output from the HuggingFace pipeline.
:param prepared_prompt: The prompt used for generation.
:param tokenizer: The tokenizer used for generation.
:param generation_kwargs: The generation parameters.
:param stop_words: A list of stop words to remove from the replies.
:param tools: A list of tools or a Toolset for which the model can prepare calls.
This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
"""
replies = [o.get("generated_text", "") for o in hf_pipeline_output]

# Remove stop words from replies if present
if stop_words:
@@ -675,9 +628,13 @@ class HuggingFaceLocalChatGenerator:

chat_messages = [
self.create_message(
reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
text=reply,
index=r_index,
tokenizer=tokenizer,
prompt=prepared_prompt,
generation_kwargs=generation_kwargs,
parse_tool_calls=bool(tools),
)
for r_index, reply in enumerate(replies)
]

return {"replies": chat_messages}
return chat_messages
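Net effect of the refactor, sketched as an outline: both entry points now delegate to the shared _prepare_inputs and _convert_hf_output_to_chat_messages helpers. Method bodies are trimmed; the helper names come from this diff, but the outline itself is a paraphrase, not verbatim source.

def run(self, messages, generation_kwargs=None, streaming_callback=None, tools=None):
    prepared = self._prepare_inputs(
        messages=messages, generation_kwargs=generation_kwargs,
        streaming_callback=streaming_callback, tools=tools,
    )
    # optionally attach HFTokenStreamingHandler to prepared["generation_kwargs"]["streamer"]
    output = self.pipeline(prepared["prepared_prompt"], **prepared["generation_kwargs"])
    return {"replies": self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared)}

async def run_async(self, messages, generation_kwargs=None, streaming_callback=None, tools=None):
    prepared = self._prepare_inputs(
        messages=messages, generation_kwargs=generation_kwargs,
        streaming_callback=streaming_callback, tools=tools,
    )
    # optionally attach AsyncHFTokenStreamingHandler and wrap generation in _manage_queue_processor
    output = await asyncio.get_running_loop().run_in_executor(
        self.executor,
        lambda: self.pipeline(prepared["prepared_prompt"], **prepared["generation_kwargs"]),
    )
    return {"replies": self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared)}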