Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-04 22:04:01 +00:00)
mv StopWordsCriteria under lazy_import (#6128)
parent 025418c10e
commit fe261b9986
@@ -5,6 +5,15 @@ from copy import deepcopy
 from haystack.preview import component, default_to_dict
 from haystack.preview.lazy_imports import LazyImport
+
+logger = logging.getLogger(__name__)
+
+SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
+
+with LazyImport(
+    message="PyTorch is needed to run this component. Please install it by following the instructions at https://pytorch.org/"
+) as torch_import:
+    import torch
 
 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import model_info
     from transformers import (
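Why the relocation matters: the StoppingCriteria base class, the PreTrainedTokenizer annotations, and the torch.device default are all evaluated when the class definition runs, so defining StopWordsCriteria at module level outside the guards made the whole module fail to import whenever torch or transformers was missing. A minimal sketch of how the guard behaves, assuming haystack's LazyImport context manager records a failed import and re-raises it from check():

from haystack.preview.lazy_imports import LazyImport

with LazyImport(
    message="PyTorch is needed to run this component. Please install it by following the instructions at https://pytorch.org/"
) as torch_import:
    import torch  # if torch is absent, the ImportError is recorded, not raised

# The error surfaces only when torch is actually needed (typically at
# component warm-up or construction time), not at module import time:
torch_import.check()  # raises the recorded ImportError with the message above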
@@ -15,15 +24,41 @@ with LazyImport(message="Run 'pip install transformers'") as transformers_import
         PreTrainedTokenizerFast,
     )
 
-with LazyImport(
-    message="PyTorch is needed to run this component. Please install it by following the instructions at https://pytorch.org/"
-) as torch_import:
-    import torch
-
-logger = logging.getLogger(__name__)
-
-SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
-
+    class StopWordsCriteria(StoppingCriteria):
+        """
+        Stops text generation if any one of the stop words is generated.
+
+        Note: When a stop word is encountered, the generation of new text is stopped.
+        However, if the stop word is in the prompt itself, it can stop generating new text
+        prematurely after the first token. This is particularly important for LLMs designed
+        for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
+        the output includes both the new text and the original prompt. Therefore, it's important
+        to make sure your prompt has no stop words.
+        """
+
+        def __init__(
+            self,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            stop_words: List[str],
+            device: Union[str, torch.device] = "cpu",
+        ):
+            super().__init__()
+            encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
+            self.stop_ids = encoded_stop_words.input_ids.to(device)
+
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+            for stop_id in self.stop_ids:
+                found_stop_word = self.is_stop_word_found(input_ids, stop_id)
+                if found_stop_word:
+                    return True
+            return False
+
+        def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
+            generated_text_ids = generated_text_ids[-1]
+            len_generated_text_ids = generated_text_ids.size(0)
+            len_stop_id = stop_id.size(0)
+            result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
+            return result
 
 
 @component
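For intuition, the stopping test above reduces to a suffix comparison on token ids. A standalone sketch of what is_stop_word_found computes, with made-up tensors (only torch required):

import torch

generated = torch.tensor([[5, 9, 42, 7, 7]])   # batched input_ids, one sequence
stop_id = torch.tensor([7, 7])                 # a stop word encoded as two tokens

seq = generated[-1]                            # last (here: only) sequence in the batch
suffix = seq[seq.size(0) - stop_id.size(0):]   # trailing len(stop_id) tokens
print(bool(all(suffix.eq(stop_id))))           # True -> generation would stop

The all() call iterates Python-side over 0-d boolean tensors; torch.equal(suffix, stop_id) would be an equivalent vectorized check.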
@@ -186,40 +221,3 @@ class HuggingFaceLocalGenerator:
         replies = [reply.replace(stop_word, "").rstrip() for reply in replies for stop_word in self.stop_words]
 
         return {"replies": replies}
-
-
-class StopWordsCriteria(StoppingCriteria):
-    """
-    Stops text generation if any one of the stop words is generated.
-
-    Note: When a stop word is encountered, the generation of new text is stopped.
-    However, if the stop word is in the prompt itself, it can stop generating new text
-    prematurely after the first token. This is particularly important for LLMs designed
-    for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
-    the output includes both the new text and the original prompt. Therefore, it's important
-    to make sure your prompt has no stop words.
-    """
-
-    def __init__(
-        self,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        stop_words: List[str],
-        device: Union[str, torch.device] = "cpu",
-    ):
-        super().__init__()
-        encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
-        self.stop_ids = encoded_stop_words.input_ids.to(device)
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        for stop_id in self.stop_ids:
-            found_stop_word = self.is_stop_word_found(input_ids, stop_id)
-            if found_stop_word:
-                return True
-        return False
-
-    def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
-        generated_text_ids = generated_text_ids[-1]
-        len_generated_text_ids = generated_text_ids.size(0)
-        len_stop_id = stop_id.size(0)
-        result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
-        return result
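Finally, a hypothetical end-to-end sketch of the relocated class in use with a transformers pipeline. The model and stop word are illustrative only; StopWordsCriteria is assumed to be in scope (imported from the file this commit edits), and the pipeline is assumed to forward stopping_criteria through to generate():

from transformers import AutoTokenizer, StoppingCriteriaList, pipeline

model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
pipe = pipeline("text2text-generation", model=model_name, tokenizer=tokenizer)

# StopWordsCriteria as defined in the diff above
criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words=["Paris"])])
print(pipe("What is the capital of France?", stopping_criteria=criteria))

This presumably mirrors what HuggingFaceLocalGenerator does internally when stop_words is set, given the stop-word stripping visible in the run() context of the last hunk.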