Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-12-30 08:37:20 +00:00
refactor: Extract HF stop words handling in hf_utils.py (#6745)
* Move StopWordsCriteria to hf_utils.py
* Raise ValueError for invalid StopWordsCriteria tokenizer
* StopWordsCriteria, make sure padding token exists
* Use proper torch types
* Update unit tests
parent 96c0b59aaa
commit 8cafff0645
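The refactor pulls StopWordsCriteria out of the generator module and into hf_utils.py, adds an explicit tokenizer type check, and falls back to the EOS token (or a newly added "[PAD]" token) when the tokenizer has no padding token. Below is a minimal usage sketch of the extracted class with plain transformers, assuming transformers[torch] is installed; the model name "gpt2" and the stop word are stand-in example values, not part of the commit:

from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

from haystack.components.generators.hf_utils import StopWordsCriteria

# "gpt2" is only a stand-in model; any causal LM with a tokenizer works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# The constructor tokenizes the stop words itself. gpt2 ships without a pad token,
# so the new fallback sets pad_token = eos_token before padding the encoded stop words.
criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words=["\n\n"])])

inputs = tokenizer("Haystack is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=40, stopping_criteria=criteria)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))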
haystack/components/generators/hf_utils.py

@@ -1,5 +1,5 @@
 import inspect
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from haystack.lazy_imports import LazyImport

@@ -55,3 +55,55 @@ def check_valid_model(model_id: str, token: Optional[str]) -> None:
     allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"]
     if not allowed_model:
         raise ValueError(f"Model {model_id} is not a text generation model. Please provide a text generation model.")
+
+
+with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
+    import torch
+    from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+    class StopWordsCriteria(StoppingCriteria):
+        """
+        Stops text generation if any one of the stop words is generated.
+
+        Note: When a stop word is encountered, the generation of new text is stopped.
+        However, if the stop word is in the prompt itself, it can stop generating new text
+        prematurely after the first token. This is particularly important for LLMs designed
+        for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
+        the output includes both the new text and the original prompt. Therefore, it's important
+        to make sure your prompt has no stop words.
+        """
+
+        def __init__(
+            self,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            stop_words: List[str],
+            device: Union[str, torch.device] = "cpu",
+        ):
+            super().__init__()
+            # check if tokenizer is a valid tokenizer
+            if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
+                raise ValueError(
+                    f"Invalid tokenizer provided for StopWordsCriteria - {tokenizer}. "
+                    f"Please provide a valid tokenizer from the HuggingFace Transformers library."
+                )
+            if not tokenizer.pad_token:
+                if tokenizer.eos_token:
+                    tokenizer.pad_token = tokenizer.eos_token
+                else:
+                    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+            encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
+            self.stop_ids = encoded_stop_words.input_ids.to(device)
+
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+            for stop_id in self.stop_ids:
+                found_stop_word = self.is_stop_word_found(input_ids, stop_id)
+                if found_stop_word:
+                    return True
+            return False
+
+        def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
+            generated_text_ids = generated_text_ids[-1]
+            len_generated_text_ids = generated_text_ids.size(0)
+            len_stop_id = stop_id.size(0)
+            result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
+            return result
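The stopping check operates on token ids: __call__ takes the last sequence in input_ids, and is_stop_word_found tests whether its trailing tokens match the encoded stop word, one stop word at a time. A small worked sketch of that trailing-token comparison, with made-up token ids for illustration only:

import torch

generated_text_ids = torch.tensor([[5, 17, 42, 99, 7, 3]])  # batch with a single generated sequence
stop_id = torch.tensor([7, 3])                               # ids of one encoded stop word

last_sequence = generated_text_ids[-1]                           # same slice the class takes
tail = last_sequence[last_sequence.size(0) - stop_id.size(0):]   # its final len(stop_id) tokens
print(all(tail.eq(stop_id)))                                     # True: the sequence ends with the stop word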
haystack/components/generators/hugging_face_local.py

@@ -2,6 +2,7 @@ import logging
 from typing import Any, Dict, List, Literal, Optional, Union

 from haystack import component, default_to_dict, default_from_dict
+from haystack.components.generators.hf_utils import StopWordsCriteria
 from haystack.lazy_imports import LazyImport

 logger = logging.getLogger(__name__)
@@ -11,49 +12,7 @@ SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
 with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
     import torch
     from huggingface_hub import model_info
-    from transformers import (
-        pipeline,
-        StoppingCriteriaList,
-        StoppingCriteria,
-        PreTrainedTokenizer,
-        PreTrainedTokenizerFast,
-    )
-
-    class StopWordsCriteria(StoppingCriteria):
-        """
-        Stops text generation if any one of the stop words is generated.
-
-        Note: When a stop word is encountered, the generation of new text is stopped.
-        However, if the stop word is in the prompt itself, it can stop generating new text
-        prematurely after the first token. This is particularly important for LLMs designed
-        for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
-        the output includes both the new text and the original prompt. Therefore, it's important
-        to make sure your prompt has no stop words.
-        """
-
-        def __init__(
-            self,
-            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-            stop_words: List[str],
-            device: Union[str, "torch.device"] = "cpu",
-        ):
-            super().__init__()
-            encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
-            self.stop_ids = encoded_stop_words.input_ids.to(device)
-
-        def __call__(self, input_ids: "torch.LongTensor", scores: "torch.FloatTensor", **kwargs) -> bool:
-            for stop_id in self.stop_ids:
-                found_stop_word = self.is_stop_word_found(input_ids, stop_id)
-                if found_stop_word:
-                    return True
-            return False
-
-        def is_stop_word_found(self, generated_text_ids: "torch.Tensor", stop_id: "torch.Tensor") -> bool:
-            generated_text_ids = generated_text_ids[-1]
-            len_generated_text_ids = generated_text_ids.size(0)
-            len_stop_id = stop_id.size(0)
-            result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
-            return result
+    from transformers import StoppingCriteriaList, pipeline


 @component
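With the class extracted, hugging_face_local.py keeps only the transformers imports it still uses (StoppingCriteriaList and pipeline) and takes StopWordsCriteria from hf_utils. A hedged sketch of the general wiring pattern for a text-generation pipeline follows; it is an illustration under the assumption that extra generation kwargs are forwarded to generate(), not a copy of HuggingFaceLocalGenerator's internals, and "gpt2" and the stop word are again stand-in values:

from transformers import StoppingCriteriaList, pipeline

from haystack.components.generators.hf_utils import StopWordsCriteria

generator = pipeline("text-generation", model="gpt2")

# Build the criteria with the pipeline's own tokenizer and device so the
# stop word ids live on the same device as the generated ids.
stopping_criteria = StoppingCriteriaList(
    [StopWordsCriteria(tokenizer=generator.tokenizer, stop_words=["\nUser:"], device=generator.device)]
)

# Keyword arguments the text-generation pipeline does not consume are passed through to generate().
result = generator("Assistant:", max_new_tokens=50, stopping_criteria=stopping_criteria)
print(result[0]["generated_text"])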
@@ -3,6 +3,7 @@ from unittest.mock import patch, Mock

 import pytest
 import torch
+from transformers import PreTrainedTokenizerFast

 from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria

@@ -362,7 +363,7 @@ class TestHuggingFaceLocalGenerator:
         # "ambiguously" token comes from "ambiguously". The algorithm will return True for presence of
         # "unambiguously" in input_ids1 which is not correct.

-        stop_words_criteria = StopWordsCriteria(tokenizer=Mock(), stop_words=["mock data"])
+        stop_words_criteria = StopWordsCriteria(tokenizer=Mock(spec=PreTrainedTokenizerFast), stop_words=["mock data"])
         # because we are mocking the tokenizer, we need to set the stop words manually
         stop_words_criteria.stop_ids = stop_words_id

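The unit test now builds the tokenizer mock with spec=PreTrainedTokenizerFast because the extracted constructor rejects anything that is not a PreTrainedTokenizer or PreTrainedTokenizerFast; a bare Mock() would trip the new ValueError. A tiny sketch of why the spec matters, illustrative and not part of the commit:

from unittest.mock import Mock

from transformers import PreTrainedTokenizerFast

bare_mock = Mock()
spec_mock = Mock(spec=PreTrainedTokenizerFast)

print(isinstance(bare_mock, PreTrainedTokenizerFast))  # False -> rejected by StopWordsCriteria
print(isinstance(spec_mock, PreTrainedTokenizerFast))  # True  -> passes the new isinstance check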