From 8cafff0645c45420f32b9894de778452b295b7fd Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Mon, 15 Jan 2024 17:42:29 +0100
Subject: [PATCH] refactor: Extract HF stop words handling in `hf_utils.py` (#6745)

* Move StopWordsCriteria to hf_utils.py
* Raise ValueError for invalid StopWordsCriteria tokenizer
* StopWordsCriteria, make sure padding token exists
* Use proper torch types
* Update unit tests
---
 haystack/components/generators/hf_utils.py    | 54 ++++++++++++++++++-
 .../generators/hugging_face_local.py          | 45 +---------------
 .../test_hugging_face_local_generator.py      |  3 +-
 3 files changed, 57 insertions(+), 45 deletions(-)

diff --git a/haystack/components/generators/hf_utils.py b/haystack/components/generators/hf_utils.py
index 832a99628..93dd2d750 100644
--- a/haystack/components/generators/hf_utils.py
+++ b/haystack/components/generators/hf_utils.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from haystack.lazy_imports import LazyImport
 
@@ -55,3 +55,55 @@ def check_valid_model(model_id: str, token: Optional[str]) -> None:
     allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"]
     if not allowed_model:
         raise ValueError(f"Model {model_id} is not a text generation model. Please provide a text generation model.")
+
+
+with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
+    import torch
+    from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+    class StopWordsCriteria(StoppingCriteria):
+        """
+        Stops text generation if any one of the stop words is generated.
+
+        Note: When a stop word is encountered, the generation of new text is stopped.
+        However, if the stop word is in the prompt itself, it can stop generating new text
+        prematurely after the first token. This is particularly important for LLMs designed
+        for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
+        the output includes both the new text and the original prompt. Therefore, it's important
+        to make sure your prompt has no stop words.
+        """
+
+        def __init__(
+            self,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            stop_words: List[str],
+            device: Union[str, torch.device] = "cpu",
+        ):
+            super().__init__()
+            # check if tokenizer is a valid tokenizer
+            if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
+                raise ValueError(
+                    f"Invalid tokenizer provided for StopWordsCriteria - {tokenizer}. "
+                    f"Please provide a valid tokenizer from the HuggingFace Transformers library."
+                )
+            if not tokenizer.pad_token:
+                if tokenizer.eos_token:
+                    tokenizer.pad_token = tokenizer.eos_token
+                else:
+                    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+            encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
+            self.stop_ids = encoded_stop_words.input_ids.to(device)
+
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+            for stop_id in self.stop_ids:
+                found_stop_word = self.is_stop_word_found(input_ids, stop_id)
+                if found_stop_word:
+                    return True
+            return False
+
+        def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
+            generated_text_ids = generated_text_ids[-1]
+            len_generated_text_ids = generated_text_ids.size(0)
+            len_stop_id = stop_id.size(0)
+            result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
+            return result
diff --git a/haystack/components/generators/hugging_face_local.py b/haystack/components/generators/hugging_face_local.py
index 765f79eeb..6c2fc3e1f 100644
--- a/haystack/components/generators/hugging_face_local.py
+++ b/haystack/components/generators/hugging_face_local.py
@@ -2,6 +2,7 @@ import logging
 from typing import Any, Dict, List, Literal, Optional, Union
 
 from haystack import component, default_to_dict, default_from_dict
+from haystack.components.generators.hf_utils import StopWordsCriteria
 from haystack.lazy_imports import LazyImport
 
 logger = logging.getLogger(__name__)
@@ -11,49 +12,7 @@ SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
 with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
     import torch
     from huggingface_hub import model_info
-    from transformers import (
-        pipeline,
-        StoppingCriteriaList,
-        StoppingCriteria,
-        PreTrainedTokenizer,
-        PreTrainedTokenizerFast,
-    )
-
-    class StopWordsCriteria(StoppingCriteria):
-        """
-        Stops text generation if any one of the stop words is generated.
-
-        Note: When a stop word is encountered, the generation of new text is stopped.
-        However, if the stop word is in the prompt itself, it can stop generating new text
-        prematurely after the first token. This is particularly important for LLMs designed
-        for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
-        the output includes both the new text and the original prompt. Therefore, it's important
-        to make sure your prompt has no stop words.
-        """
-
-        def __init__(
-            self,
-            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-            stop_words: List[str],
-            device: Union[str, "torch.device"] = "cpu",
-        ):
-            super().__init__()
-            encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
-            self.stop_ids = encoded_stop_words.input_ids.to(device)
-
-        def __call__(self, input_ids: "torch.LongTensor", scores: "torch.FloatTensor", **kwargs) -> bool:
-            for stop_id in self.stop_ids:
-                found_stop_word = self.is_stop_word_found(input_ids, stop_id)
-                if found_stop_word:
-                    return True
-            return False
-
-        def is_stop_word_found(self, generated_text_ids: "torch.Tensor", stop_id: "torch.Tensor") -> bool:
-            generated_text_ids = generated_text_ids[-1]
-            len_generated_text_ids = generated_text_ids.size(0)
-            len_stop_id = stop_id.size(0)
-            result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
-            return result
+    from transformers import StoppingCriteriaList, pipeline
 
 
 @component
diff --git a/test/components/generators/test_hugging_face_local_generator.py b/test/components/generators/test_hugging_face_local_generator.py
index 6cc984474..9ac3d1443 100644
--- a/test/components/generators/test_hugging_face_local_generator.py
+++ b/test/components/generators/test_hugging_face_local_generator.py
@@ -3,6 +3,7 @@ from unittest.mock import patch, Mock
 
 import pytest
 import torch
+from transformers import PreTrainedTokenizerFast
 
 from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
 
@@ -362,7 +363,7 @@ class TestHuggingFaceLocalGenerator:
         # "ambiguously" token comes from "ambiguously". The algorithm will return True for presence of
         # "unambiguously" in input_ids1 which is not correct.
 
-        stop_words_criteria = StopWordsCriteria(tokenizer=Mock(), stop_words=["mock data"])
+        stop_words_criteria = StopWordsCriteria(tokenizer=Mock(spec=PreTrainedTokenizerFast), stop_words=["mock data"])
         # because we are mocking the tokenizer, we need to set the stop words manually
         stop_words_criteria.stop_ids = stop_words_id
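Usage sketch (not part of the patch above): the snippet below shows how the relocated StopWordsCriteria could be plugged into a transformers pipeline through a StoppingCriteriaList, mirroring how HuggingFaceLocalGenerator wires its stopping criteria after this refactor. The model name, stop word, and prompt are placeholders chosen only for illustration.

    # Illustrative sketch; "gpt2", the stop word, and the prompt are placeholders.
    from transformers import AutoTokenizer, StoppingCriteriaList, pipeline

    from haystack.components.generators.hf_utils import StopWordsCriteria

    model_name = "gpt2"  # placeholder model, small enough to run locally
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    generator = pipeline("text-generation", model=model_name, tokenizer=tokenizer)

    # Stop generating as soon as any stop word shows up in the newly generated tokens.
    stop_criteria = StopWordsCriteria(tokenizer=tokenizer, stop_words=["Observation:"], device="cpu")
    stopping_criteria = StoppingCriteriaList([stop_criteria])

    result = generator("Haystack is an open-source framework for", stopping_criteria=stopping_criteria, max_new_tokens=30)
    print(result[0]["generated_text"])

As the class docstring warns, the stop words must not appear in the prompt itself, otherwise generation can halt right after the first token.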