refactor: Extract HF stop words handling in hf_utils.py (#6745)

* Move StopWordsCriteria to hf_utils.py

* Raise ValueError for invalid StopWordsCriteria tokenizer

* StopWordsCriteria: ensure a padding token exists

* Use proper torch types

* Update unit tests
Vladimir Blagojevic 2024-01-15 17:42:29 +01:00 committed by GitHub
parent 96c0b59aaa
commit 8cafff0645
3 changed files with 57 additions and 45 deletions

haystack/components/generators/hf_utils.py

@@ -1,5 +1,5 @@
import inspect
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
from haystack.lazy_imports import LazyImport
@@ -55,3 +55,55 @@ def check_valid_model(model_id: str, token: Optional[str]) -> None:
allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"]
if not allowed_model:
raise ValueError(f"Model {model_id} is not a text generation model. Please provide a text generation model.")
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
import torch
from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast
class StopWordsCriteria(StoppingCriteria):
"""
Stops text generation if any one of the stop words is generated.
Note: When a stop word is encountered, the generation of new text is stopped.
However, if the stop word is in the prompt itself, it can stop generating new text
prematurely after the first token. This is particularly important for LLMs designed
for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
the output includes both the new text and the original prompt. Therefore, it's important
to make sure your prompt has no stop words.
"""
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stop_words: List[str],
device: Union[str, torch.device] = "cpu",
):
super().__init__()
# check if tokenizer is a valid tokenizer
if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
raise ValueError(
f"Invalid tokenizer provided for StopWordsCriteria - {tokenizer}. "
f"Please provide a valid tokenizer from the HuggingFace Transformers library."
)
if not tokenizer.pad_token:
if tokenizer.eos_token:
tokenizer.pad_token = tokenizer.eos_token
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
self.stop_ids = encoded_stop_words.input_ids.to(device)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
for stop_id in self.stop_ids:
found_stop_word = self.is_stop_word_found(input_ids, stop_id)
if found_stop_word:
return True
return False
def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
generated_text_ids = generated_text_ids[-1]
len_generated_text_ids = generated_text_ids.size(0)
len_stop_id = stop_id.size(0)
result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
return result
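Not part of the commit: a minimal usage sketch of the extracted StopWordsCriteria, plugged into a plain transformers generate() call via StoppingCriteriaList. The model name ("gpt2"), prompt, and stop word are placeholder assumptions. Note that __init__ falls back to the EOS token (or a "[PAD]" token) as padding token precisely so that several stop words of different token lengths can be batch-encoded with padding=True.

# Illustrative sketch only; model, prompt, and stop word are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from haystack.components.generators.hf_utils import StopWordsCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Wrap the criteria in a StoppingCriteriaList, as expected by generate().
stopping = StoppingCriteriaList(
    [StopWordsCriteria(tokenizer=tokenizer, stop_words=["\n\n"], device="cpu")]
)
inputs = tokenizer("Stopping criteria are useful because", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=50, stopping_criteria=stopping)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))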

haystack/components/generators/hugging_face_local.py

@@ -2,6 +2,7 @@ import logging
from typing import Any, Dict, List, Literal, Optional, Union
from haystack import component, default_to_dict, default_from_dict
from haystack.components.generators.hf_utils import StopWordsCriteria
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
@@ -11,49 +12,7 @@ SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
import torch
from huggingface_hub import model_info
from transformers import (
pipeline,
StoppingCriteriaList,
StoppingCriteria,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
class StopWordsCriteria(StoppingCriteria):
"""
Stops text generation if any one of the stop words is generated.
Note: When a stop word is encountered, the generation of new text is stopped.
However, if the stop word is in the prompt itself, it can stop generating new text
prematurely after the first token. This is particularly important for LLMs designed
for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
the output includes both the new text and the original prompt. Therefore, it's important
to make sure your prompt has no stop words.
"""
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stop_words: List[str],
device: Union[str, "torch.device"] = "cpu",
):
super().__init__()
encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
self.stop_ids = encoded_stop_words.input_ids.to(device)
def __call__(self, input_ids: "torch.LongTensor", scores: "torch.FloatTensor", **kwargs) -> bool:
for stop_id in self.stop_ids:
found_stop_word = self.is_stop_word_found(input_ids, stop_id)
if found_stop_word:
return True
return False
def is_stop_word_found(self, generated_text_ids: "torch.Tensor", stop_id: "torch.Tensor") -> bool:
generated_text_ids = generated_text_ids[-1]
len_generated_text_ids = generated_text_ids.size(0)
len_stop_id = stop_id.size(0)
result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
return result
from transformers import StoppingCriteriaList, pipeline
@component
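After the refactor, hugging_face_local.py only imports the shared StopWordsCriteria. Below is a minimal sketch, not the actual Haystack implementation, of how a generator component could turn user-supplied stop words into a StoppingCriteriaList for its pipeline; the helper name and its parameters are assumptions.

# Hypothetical helper (not from this commit) showing how a generator might wire
# the shared StopWordsCriteria into a transformers pipeline.
from typing import List, Optional
from transformers import Pipeline, StoppingCriteriaList
from haystack.components.generators.hf_utils import StopWordsCriteria

def build_stopping_criteria(pipe: Pipeline, stop_words: Optional[List[str]]) -> Optional[StoppingCriteriaList]:
    # No stop words configured: let generate() run without extra criteria.
    if not stop_words:
        return None
    criteria = StopWordsCriteria(tokenizer=pipe.tokenizer, stop_words=stop_words, device=pipe.device)
    return StoppingCriteriaList([criteria])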

Unit tests for HuggingFaceLocalGenerator

@@ -3,6 +3,7 @@ from unittest.mock import patch, Mock
import pytest
import torch
from transformers import PreTrainedTokenizerFast
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
@@ -362,7 +363,7 @@ class TestHuggingFaceLocalGenerator:
# "ambiguously" token comes from "ambiguously". The algorithm will return True for presence of
# "unambiguously" in input_ids1 which is not correct.
-stop_words_criteria = StopWordsCriteria(tokenizer=Mock(), stop_words=["mock data"])
+stop_words_criteria = StopWordsCriteria(tokenizer=Mock(spec=PreTrainedTokenizerFast), stop_words=["mock data"])
# because we are mocking the tokenizer, we need to set the stop words manually
stop_words_criteria.stop_ids = stop_words_id
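Not part of the commit: a complementary test sketch in the same mocking style as above, relying on the imports already shown in this file's diff (Mock, torch, PreTrainedTokenizerFast, StopWordsCriteria). It exercises the suffix-matching behaviour of __call__ directly with hand-picked token ids; the function name and ids are illustrative.

# Hypothetical additional test, mirroring the mocking approach used above.
def test_stop_words_criteria_matches_only_trailing_tokens():
    criteria = StopWordsCriteria(tokenizer=Mock(spec=PreTrainedTokenizerFast), stop_words=["mock data"])
    # Inject the encoded stop word manually, since the tokenizer is mocked.
    criteria.stop_ids = torch.tensor([[7, 8]])
    # Fires only when the generated sequence ends with the stop-word ids ...
    assert criteria(torch.tensor([[1, 2, 7, 8]]), scores=None) is True
    # ... and not when the same ids appear earlier in the sequence.
    assert criteria(torch.tensor([[1, 7, 8, 2]]), scores=None) is False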