mv StopWordsCriteria under lazy_import (#6128)

Stefano Fiorucci 2023-10-19 17:48:59 +02:00 committed by GitHub
parent 025418c10e
commit fe261b9986
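Why the move matters: StopWordsCriteria subclasses transformers.StoppingCriteria and uses torch.device in a type hint, so defining it at module level makes the whole module fail to import whenever torch or transformers is absent. Inside the with LazyImport(...) block, the class is only defined if the imports succeed; if they fail, the rest of the block is skipped and the error is replayed later with an install hint. A minimal sketch of how such a guard can behave (an illustration only, not Haystack's actual LazyImport implementation; LazyImportSketch and its methods are made-up names):

    class LazyImportSketch:
        """Context manager that swallows ImportError and replays it on demand."""

        def __init__(self, message: str):
            self.message = message
            self.import_error = None

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # ModuleNotFoundError is a subclass of ImportError, so use issubclass.
            if exc_type is not None and issubclass(exc_type, ImportError):
                # Remember the failure instead of crashing at module-import time.
                self.import_error = exc_value
                return True  # suppress the exception
            return False

        def check(self):
            # Call this right before the optional dependency is actually needed.
            if self.import_error is not None:
                raise ImportError(self.message) from self.import_error

Because an ImportError raised by import torch skips everything after the failing line in the with body, a class defined below the imports simply never comes into existence when the dependency is missing, instead of crashing the module at import time.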


@@ -5,6 +5,15 @@ from copy import deepcopy
 from haystack.preview import component, default_to_dict
 from haystack.preview.lazy_imports import LazyImport
 
+logger = logging.getLogger(__name__)
+
+SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
+
+with LazyImport(
+    message="PyTorch is needed to run this component. Please install it by following the instructions at https://pytorch.org/"
+) as torch_import:
+    import torch
+
 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import model_info
     from transformers import (
@@ -15,15 +24,41 @@ with LazyImport(message="Run 'pip install transformers'") as transformers_import
         PreTrainedTokenizerFast,
     )
 
-with LazyImport(
-    message="PyTorch is needed to run this component. Please install it by following the instructions at https://pytorch.org/"
-) as torch_import:
-    import torch
-
+    class StopWordsCriteria(StoppingCriteria):
+        """
+        Stops text generation if any one of the stop words is generated.
+
-logger = logging.getLogger(__name__)
-
+        Note: When a stop word is encountered, the generation of new text is stopped.
+        However, if the stop word is in the prompt itself, it can stop generating new text
+        prematurely after the first token. This is particularly important for LLMs designed
+        for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
+        the output includes both the new text and the original prompt. Therefore, it's important
+        to make sure your prompt has no stop words.
+        """
+
+        def __init__(
+            self,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            stop_words: List[str],
+            device: Union[str, torch.device] = "cpu",
+        ):
+            super().__init__()
+            encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
+            self.stop_ids = encoded_stop_words.input_ids.to(device)
+
-SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
-
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+            for stop_id in self.stop_ids:
+                found_stop_word = self.is_stop_word_found(input_ids, stop_id)
+                if found_stop_word:
+                    return True
+            return False
+
+        def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
+            generated_text_ids = generated_text_ids[-1]
+            len_generated_text_ids = generated_text_ids.size(0)
+            len_stop_id = stop_id.size(0)
+            result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
+            return result
 
 
 @component
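For reference (not part of this commit), a criteria object like this is consumed through transformers' StoppingCriteriaList and passed to generate(). A minimal usage sketch, assuming torch and transformers are installed; "gpt2", the stop word, and the prompt are placeholder choices:

    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; StopWordsCriteria tokenizes with padding=True
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Generation halts once the generated ids end with the stop word's ids.
    criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words=["Paris"])])
    inputs = tokenizer("The capital of France is", return_tensors="pt")
    output_ids = model.generate(**inputs, stopping_criteria=criteria, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))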
@@ -186,40 +221,3 @@ class HuggingFaceLocalGenerator:
             replies = [reply.replace(stop_word, "").rstrip() for reply in replies for stop_word in self.stop_words]
 
         return {"replies": replies}
-
-
-class StopWordsCriteria(StoppingCriteria):
-    """
-    Stops text generation if any one of the stop words is generated.
-
-    Note: When a stop word is encountered, the generation of new text is stopped.
-    However, if the stop word is in the prompt itself, it can stop generating new text
-    prematurely after the first token. This is particularly important for LLMs designed
-    for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
-    the output includes both the new text and the original prompt. Therefore, it's important
-    to make sure your prompt has no stop words.
-    """
-
-    def __init__(
-        self,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        stop_words: List[str],
-        device: Union[str, torch.device] = "cpu",
-    ):
-        super().__init__()
-        encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
-        self.stop_ids = encoded_stop_words.input_ids.to(device)
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        for stop_id in self.stop_ids:
-            found_stop_word = self.is_stop_word_found(input_ids, stop_id)
-            if found_stop_word:
-                return True
-        return False
-
-    def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
-        generated_text_ids = generated_text_ids[-1]
-        len_generated_text_ids = generated_text_ids.size(0)
-        len_stop_id = stop_id.size(0)
-        result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
-        return result
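The block removed here is the same class re-added under the transformers lazy import above; only its location and indentation change. Its core test compares the trailing len(stop_id) generated token ids against the stop word's ids. A small sketch of that suffix comparison with made-up token ids (the numbers are illustrative, not real vocabulary entries):

    import torch

    generated = torch.tensor([101, 7592, 2088, 999])  # hypothetical generated sequence
    stop = torch.tensor([2088, 999])                  # hypothetical ids of one stop word

    # is_stop_word_found slices off the trailing len(stop) ids and checks element-wise equality.
    suffix = generated[generated.size(0) - stop.size(0):]
    print(all(suffix.eq(stop)))  # True: the sequence ends with the stop word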