Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-06-26 22:00:13 +00:00
refactor: Refactor HuggingFaceLocalChatGenerator (#9455)

* Refactoring to better align run and run_async and reduce duplicate code
* Docstrings and align run and run_async
* More changes
* add missing type
* Refactor async part a bit
* Fix import error
* Fix mypy
This commit is contained in:
parent 379df4ab84
commit c5027d711c
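
The public behavior of run() and run_async() is unchanged by this refactor; both now share _prepare_inputs() and _convert_hf_output_to_chat_messages(). For orientation, a minimal usage sketch of the component, assuming a typical Haystack setup (the model name and the example messages are illustrative, not part of this commit):

    # Illustrative sketch only; model choice and prompts are assumptions, not part of the diff.
    import asyncio

    from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
    from haystack.dataclasses import ChatMessage

    generator = HuggingFaceLocalChatGenerator(model="HuggingFaceTB/SmolLM2-360M-Instruct")
    generator.warm_up()  # loads the pipeline; run()/run_async() raise RuntimeError without it

    messages = [ChatMessage.from_user("What is the capital of France?")]

    # Synchronous path
    result = generator.run(messages=messages)
    print(result["replies"][0].text)

    # Asynchronous path; generation is offloaded to the component's executor
    async def main():
        result = await generator.run_async(messages=messages)
        print(result["replies"][0].text)

    asyncio.run(main())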
@@ -7,7 +7,7 @@ import json
 import re
 import sys
 from concurrent.futures import ThreadPoolExecutor
-from contextlib import suppress
+from contextlib import asynccontextmanager, suppress
 from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast

 from haystack import component, default_from_dict, default_to_dict, logging
@@ -138,7 +138,7 @@ class HuggingFaceLocalChatGenerator:
         tools: Optional[Union[List[Tool], Toolset]] = None,
         tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
         async_executor: Optional[ThreadPoolExecutor] = None,
-    ):
+    ) -> None:
         """
         Initializes the HuggingFaceLocalChatGenerator component.

@@ -249,14 +249,14 @@ class HuggingFaceLocalChatGenerator:
             else async_executor
         )

-    def __del__(self):
+    def __del__(self) -> None:
         """
         Cleanup when the instance is being destroyed.
         """
-        if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
+        if self._owns_executor:
             self.executor.shutdown(wait=True)

-    def shutdown(self):
+    def shutdown(self) -> None:
         """
         Explicitly shutdown the executor if we own it.
         """
@@ -271,7 +271,7 @@ class HuggingFaceLocalChatGenerator:
             return {"model": self.huggingface_pipeline_kwargs["model"]}
         return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

-    def warm_up(self):
+    def warm_up(self) -> None:
         """
         Initializes the component.
         """
@@ -336,105 +336,41 @@ class HuggingFaceLocalChatGenerator:
         generation_kwargs: Optional[Dict[str, Any]] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
         tools: Optional[Union[List[Tool], Toolset]] = None,
-    ):
+    ) -> Dict[str, List[ChatMessage]]:
         """
         Invoke text generation inference based on the provided messages and generation parameters.

         :param messages: A list of ChatMessage objects representing the input messages.
         :param generation_kwargs: Additional keyword arguments for text generation.
         :param streaming_callback: An optional callable for handling streaming responses.
-        :param tools:
-            A list of tools or a Toolset for which the model can prepare calls. If set, it will override
+        :param tools: A list of tools or a Toolset for which the model can prepare calls. If set, it will override
             the `tools` parameter provided during initialization. This parameter can accept either a list
             of `Tool` objects or a `Toolset` instance.
-        :returns:
-            A list containing the generated responses as ChatMessage instances.
+        :returns: A dictionary with the following keys:
+            - `replies`: A list containing the generated responses as ChatMessage instances.
         """
-        if self.pipeline is None:
-            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
-
-        tools = tools or self.tools
-        if tools and streaming_callback is not None:
-            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
-        _check_duplicate_tool_names(list(tools or []))
-
-        tokenizer = self.pipeline.tokenizer
-        # initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
-        assert tokenizer is not None
-
-        # Check and update generation parameters
-        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
-
-        stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
-        # pipeline call doesn't support stop_sequences, so we need to pop it
-        stop_words = self._validate_stop_words(stop_words)
-
-        # Set up stop words criteria if stop words exist
-        stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
-        if stop_words_criteria:
-            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
+        prepared_inputs = self._prepare_inputs(
+            messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
+        )

         streaming_callback = select_streaming_callback(
             self.streaming_callback, streaming_callback, requires_async=False
         )
         if streaming_callback:
-            num_responses = generation_kwargs.get("num_return_sequences", 1)
-            if num_responses > 1:
-                msg = (
-                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
-                    "Streaming is only supported for single response generation. "
-                    "Setting the number of responses to 1."
-                )
-                logger.warning(msg, num_responses=num_responses)
-                generation_kwargs["num_return_sequences"] = 1
-
-            # Get component name and type
-            component_info = ComponentInfo.from_component(self)
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
-            generation_kwargs["streamer"] = HFTokenStreamingHandler(
-                tokenizer=tokenizer,
+            prepared_inputs["generation_kwargs"]["streamer"] = HFTokenStreamingHandler(
+                tokenizer=prepared_inputs["tokenizer"],
                 stream_handler=streaming_callback,
-                stop_words=stop_words,
-                component_info=component_info,
+                stop_words=prepared_inputs["stop_words"],
+                component_info=ComponentInfo.from_component(self),
             )

-        # convert messages to HF format
-        hf_messages = [convert_message_to_hf_format(message) for message in messages]
-
-        if isinstance(tools, Toolset):
-            tools = list(tools)
-
-        prepared_prompt = tokenizer.apply_chat_template(
-            hf_messages,
-            tokenize=False,
-            chat_template=self.chat_template,
-            add_generation_prompt=True,
-            tools=[tc.tool_spec for tc in tools] if tools else None,
-        )
-
-        # prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
-        assert isinstance(prepared_prompt, str)
-
-        # Avoid some unnecessary warnings in the generation pipeline call
-        generation_kwargs["pad_token_id"] = (
-            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
-        )
-
+        # We know it's not None because we check it in _prepare_inputs
+        assert self.pipeline is not None
         # Generate responses
-        output = self.pipeline(prepared_prompt, **generation_kwargs)
-        replies = [o.get("generated_text", "") for o in output]
-
-        # Remove stop words from replies if present
-        if stop_words:
-            for stop_word in stop_words:
-                replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
-
-        chat_messages = [
-            self.create_message(
-                reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
-            )
-            for r_index, reply in enumerate(replies)
-        ]
-
+        output = self.pipeline(prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"])
+
+        chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)
         return {"replies": chat_messages}

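Note on the resulting structure (not part of the diff itself): _prepare_inputs() returns a plain dict whose keys (prepared_prompt, tokenizer, generation_kwargs, stop_words, tools) match the keyword-only parameters of _convert_hf_output_to_chat_messages(), so run() and run_async() can forward it as **prepared_inputs alongside the pipeline output. A tiny sketch of that hand-off pattern, with hypothetical names:

    # Hypothetical sketch of the dict hand-off pattern; names and logic are illustrative only.
    from typing import Any, Dict, List

    def prepare(text: str) -> Dict[str, Any]:
        # Keys intentionally mirror the keyword-only parameters of consume() below
        return {"prompt": text.upper(), "stop_words": ["\n"]}

    def consume(*, output: List[str], prompt: str, stop_words: List[str]) -> List[str]:
        # An unexpected key would raise TypeError, which keeps both helpers in sync
        return [o.strip("".join(stop_words)) for o in output if o != prompt]

    prepared = prepare("hello")
    replies = consume(output=["HELLO", "world\n"], **prepared)
    print(replies)  # ['world']
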
@@ -512,7 +448,7 @@ class HuggingFaceLocalChatGenerator:
         generation_kwargs: Optional[Dict[str, Any]] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
         tools: Optional[Union[List[Tool], Toolset]] = None,
-    ):
+    ) -> Dict[str, List[ChatMessage]]:
         """
         Asynchronously invokes text generation inference based on the provided messages and generation parameters.

@@ -527,6 +463,77 @@ class HuggingFaceLocalChatGenerator:
         :returns: A dictionary with the following keys:
             - `replies`: A list containing the generated responses as ChatMessage instances.
         """
+        prepared_inputs = self._prepare_inputs(
+            messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
+        )
+
+        # Validate and select the streaming callback
+        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
+
+        if streaming_callback:
+            async_handler = AsyncHFTokenStreamingHandler(
+                tokenizer=prepared_inputs["tokenizer"],
+                # Cast to AsyncStreamingCallbackT since we know streaming_callback is async
+                stream_handler=cast(AsyncStreamingCallbackT, streaming_callback),
+                stop_words=prepared_inputs["stop_words"],
+                component_info=ComponentInfo.from_component(self),
+            )
+            prepared_inputs["generation_kwargs"]["streamer"] = async_handler
+
+            # Use async context manager for proper resource cleanup
+            async with self._manage_queue_processor(async_handler):
+                output = await asyncio.get_running_loop().run_in_executor(
+                    self.executor,
+                    # have to ignore since assert self.pipeline is not None doesn't work
+                    lambda: self.pipeline(  # type: ignore[misc]
+                        prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
+                    ),
+                )
+        else:
+            output = await asyncio.get_running_loop().run_in_executor(
+                self.executor,
+                # have to ignore since assert self.pipeline is not None doesn't work
+                lambda: self.pipeline(  # type: ignore[misc]
+                    prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
+                ),
+            )
+
+        chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)
+        return {"replies": chat_messages}
+
+    @asynccontextmanager
+    async def _manage_queue_processor(self, async_handler: "AsyncHFTokenStreamingHandler"):
+        """Context manager for proper queue processor lifecycle management."""
+        queue_processor = asyncio.create_task(async_handler.process_queue())
+        try:
+            yield queue_processor
+        finally:
+            # Ensure the queue processor is cleaned up properly
+            try:
+                await asyncio.wait_for(queue_processor, timeout=0.1)
+            except asyncio.TimeoutError:
+                queue_processor.cancel()
+                with suppress(asyncio.CancelledError):
+                    await queue_processor
+
+    def _prepare_inputs(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Prepares the inputs for the Hugging Face pipeline.
+
+        :param messages: A list of ChatMessage objects representing the input messages.
+        :param generation_kwargs: Additional keyword arguments for text generation.
+        :param streaming_callback: An optional callable for handling streaming responses.
+        :param tools: A list of tools or a Toolset for which the model can prepare calls.
+        :returns: A dictionary containing the prepared prompt, tokenizer, generation kwargs, and tools.
+        :raises RuntimeError: If the generation model has not been loaded.
+        :raises ValueError: If both tools and streaming_callback are provided.
+        """
         if self.pipeline is None:
             raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

@@ -535,6 +542,9 @@ class HuggingFaceLocalChatGenerator:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
         _check_duplicate_tool_names(list(tools or []))

+        if isinstance(tools, Toolset):
+            tools = list(tools)
+
         tokenizer = self.pipeline.tokenizer
         # initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
         assert tokenizer is not None
@@ -542,6 +552,18 @@ class HuggingFaceLocalChatGenerator:
         # Check and update generation parameters
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

+        # If streaming_callback is provided, ensure that num_return_sequences is set to 1
+        if streaming_callback:
+            num_responses = generation_kwargs.get("num_return_sequences", 1)
+            if num_responses > 1:
+                msg = (
+                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
+                    "Streaming is only supported for single response generation. "
+                    "Setting the number of responses to 1."
+                )
+                logger.warning(msg, num_responses=num_responses)
+                generation_kwargs["num_return_sequences"] = 1
+
         stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
         stop_words = self._validate_stop_words(stop_words)

@@ -550,99 +572,8 @@ class HuggingFaceLocalChatGenerator:
         if stop_words_criteria:
             generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

-        # validate and select the streaming callback
-        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
-
-        if streaming_callback:
-            return await self._run_streaming_async(
-                messages, tokenizer, generation_kwargs, stop_words, streaming_callback
-            )
-
-        return await self._run_non_streaming_async(messages, tokenizer, generation_kwargs, stop_words, tools)
-
-    async def _run_streaming_async(  # pylint: disable=too-many-positional-arguments
-        self,
-        messages: List[ChatMessage],
-        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
-        generation_kwargs: Dict[str, Any],
-        stop_words: Optional[List[str]],
-        streaming_callback: AsyncStreamingCallbackT,
-    ):
-        """
-        Handles async streaming generation of responses.
-        """
         # convert messages to HF format
         hf_messages = [convert_message_to_hf_format(message) for message in messages]
-        prepared_prompt = tokenizer.apply_chat_template(
-            hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
-        )
-
-        # prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
-        assert isinstance(prepared_prompt, str)
-
-        # Avoid some unnecessary warnings in the generation pipeline call
-        generation_kwargs["pad_token_id"] = (
-            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
-        )
-
-        # get the component name and type
-        component_info = ComponentInfo.from_component(self)
-
-        async_handler = AsyncHFTokenStreamingHandler(
-            tokenizer=tokenizer, stream_handler=streaming_callback, stop_words=stop_words, component_info=component_info
-        )
-        generation_kwargs["streamer"] = async_handler
-
-        # Start queue processing in the background
-        queue_processor = asyncio.create_task(async_handler.process_queue())
-
-        try:
-            # Generate responses asynchronously
-            output = await asyncio.get_running_loop().run_in_executor(
-                self.executor,
-                lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
-            )
-
-            replies = [o.get("generated_text", "") for o in output]
-
-            # Remove stop words from replies if present
-            if stop_words:
-                for stop_word in stop_words:
-                    replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
-
-            chat_messages = [
-                self.create_message(
-                    reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False
-                )
-                for r_index, reply in enumerate(replies)
-            ]
-
-            return {"replies": chat_messages}
-
-        finally:
-            try:
-                await asyncio.wait_for(queue_processor, timeout=0.1)
-            except asyncio.TimeoutError:
-                queue_processor.cancel()
-                with suppress(asyncio.CancelledError):
-                    await queue_processor
-
-    async def _run_non_streaming_async(  # pylint: disable=too-many-positional-arguments
-        self,
-        messages: List[ChatMessage],
-        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
-        generation_kwargs: Dict[str, Any],
-        stop_words: Optional[List[str]],
-        tools: Optional[Union[List[Tool], Toolset]] = None,
-    ):
-        """
-        Handles async non-streaming generation of responses.
-        """
-        # convert messages to HF format
-        hf_messages = [convert_message_to_hf_format(message) for message in messages]
-
-        if isinstance(tools, Toolset):
-            tools = list(tools)

         prepared_prompt = tokenizer.apply_chat_template(
             hf_messages,
@@ -651,22 +582,44 @@ class HuggingFaceLocalChatGenerator:
             add_generation_prompt=True,
             tools=[tc.tool_spec for tc in tools] if tools else None,
         )
-        # prepared_prompt is a string, but transformers has some type issues
-        prepared_prompt = cast(str, prepared_prompt)
+        # prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
+        assert isinstance(prepared_prompt, str)

         # Avoid some unnecessary warnings in the generation pipeline call
         generation_kwargs["pad_token_id"] = (
             generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
         )

-        # Generate responses asynchronously
-        output = await asyncio.get_running_loop().run_in_executor(
-            self.executor,
-            lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
-        )
-
-        replies = [o.get("generated_text", "") for o in output]
+        return {
+            "prepared_prompt": prepared_prompt,
+            "tokenizer": tokenizer,
+            "generation_kwargs": generation_kwargs,
+            "tools": tools,
+            "stop_words": stop_words,
+        }
+
+    def _convert_hf_output_to_chat_messages(
+        self,
+        *,
+        hf_pipeline_output: List[Dict[str, Any]],
+        prepared_prompt: str,
+        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
+        generation_kwargs: Dict[str, Any],
+        stop_words: Optional[List[str]],
+        tools: Optional[Union[List[Tool], Toolset]] = None,
+    ) -> List[ChatMessage]:
+        """
+        Converts the HuggingFace pipeline output into a List of ChatMessages
+
+        :param hf_pipeline_output: The output from the HuggingFace pipeline.
+        :param prepared_prompt: The prompt used for generation.
+        :param tokenizer: The tokenizer used for generation.
+        :param generation_kwargs: The generation parameters.
+        :param stop_words: A list of stop words to remove from the replies.
+        :param tools: A list of tools or a Toolset for which the model can prepare calls.
+            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
+        """
+        replies = [o.get("generated_text", "") for o in hf_pipeline_output]

         # Remove stop words from replies if present
         if stop_words:
@@ -675,9 +628,13 @@ class HuggingFaceLocalChatGenerator:

         chat_messages = [
             self.create_message(
-                reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
+                text=reply,
+                index=r_index,
+                tokenizer=tokenizer,
+                prompt=prepared_prompt,
+                generation_kwargs=generation_kwargs,
+                parse_tool_calls=bool(tools),
             )
             for r_index, reply in enumerate(replies)
         ]
-
-        return {"replies": chat_messages}
+        return chat_messages
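
For reference, the cleanup behavior introduced by _manage_queue_processor, shown in isolation as a minimal standalone sketch (the worker coroutine, the names, and the 0.1s grace period are illustrative assumptions, not code from this diff):

    # Standalone sketch of the asynccontextmanager cleanup pattern used above.
    import asyncio
    from contextlib import asynccontextmanager, suppress

    @asynccontextmanager
    async def managed_task(coro):
        task = asyncio.create_task(coro)
        try:
            yield task
        finally:
            try:
                # Give the task a short grace period to finish on its own
                await asyncio.wait_for(task, timeout=0.1)
            except asyncio.TimeoutError:
                # Otherwise cancel it and swallow the expected CancelledError
                task.cancel()
                with suppress(asyncio.CancelledError):
                    await task

    async def demo():
        async def worker():
            while True:
                await asyncio.sleep(1)

        async with managed_task(worker()):
            await asyncio.sleep(0)  # the worker is cancelled and awaited on exit

    asyncio.run(demo())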