mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-01 12:23:31 +00:00
mv StopWordsCriteria under lazy_import (#6128)
This commit is contained in:
parent
025418c10e
commit
fe261b9986
@ -5,6 +5,15 @@ from copy import deepcopy
|
||||
from haystack.preview import component, default_to_dict
|
||||
from haystack.preview.lazy_imports import LazyImport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
|
||||
|
||||
with LazyImport(
|
||||
message="PyTorch is needed to run this component. Please install it by following the instructions at https://pytorch.org/"
|
||||
) as torch_import:
|
||||
import torch
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
from huggingface_hub import model_info
|
||||
from transformers import (
|
||||
@ -15,15 +24,41 @@ with LazyImport(message="Run 'pip install transformers'") as transformers_import
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
|
||||
with LazyImport(
|
||||
message="PyTorch is needed to run this component. Please install it by following the instructions at https://pytorch.org/"
|
||||
) as torch_import:
|
||||
import torch
|
||||
class StopWordsCriteria(StoppingCriteria):
|
||||
"""
|
||||
Stops text generation if any one of the stop words is generated.
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
Note: When a stop word is encountered, the generation of new text is stopped.
|
||||
However, if the stop word is in the prompt itself, it can stop generating new text
|
||||
prematurely after the first token. This is particularly important for LLMs designed
|
||||
for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
|
||||
the output includes both the new text and the original prompt. Therefore, it's important
|
||||
to make sure your prompt has no stop words.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
stop_words: List[str],
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
|
||||
self.stop_ids = encoded_stop_words.input_ids.to(device)
|
||||
|
||||
SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
for stop_id in self.stop_ids:
|
||||
found_stop_word = self.is_stop_word_found(input_ids, stop_id)
|
||||
if found_stop_word:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
|
||||
generated_text_ids = generated_text_ids[-1]
|
||||
len_generated_text_ids = generated_text_ids.size(0)
|
||||
len_stop_id = stop_id.size(0)
|
||||
result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
|
||||
return result
|
||||
|
||||
|
||||
@component
|
||||
@ -186,40 +221,3 @@ class HuggingFaceLocalGenerator:
|
||||
replies = [reply.replace(stop_word, "").rstrip() for reply in replies for stop_word in self.stop_words]
|
||||
|
||||
return {"replies": replies}
|
||||
|
||||
|
||||
class StopWordsCriteria(StoppingCriteria):
|
||||
"""
|
||||
Stops text generation if any one of the stop words is generated.
|
||||
|
||||
Note: When a stop word is encountered, the generation of new text is stopped.
|
||||
However, if the stop word is in the prompt itself, it can stop generating new text
|
||||
prematurely after the first token. This is particularly important for LLMs designed
|
||||
for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
|
||||
the output includes both the new text and the original prompt. Therefore, it's important
|
||||
to make sure your prompt has no stop words.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
stop_words: List[str],
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
|
||||
self.stop_ids = encoded_stop_words.input_ids.to(device)
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
for stop_id in self.stop_ids:
|
||||
found_stop_word = self.is_stop_word_found(input_ids, stop_id)
|
||||
if found_stop_word:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
|
||||
generated_text_ids = generated_text_ids[-1]
|
||||
len_generated_text_ids = generated_text_ids.size(0)
|
||||
len_stop_id = stop_id.size(0)
|
||||
result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
|
||||
return result
|
||||
|
Loading…
x
Reference in New Issue
Block a user