From 1dd61582448d65b8c0080bd6181c4dc28b39581e Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Fri, 14 Apr 2023 15:40:42 +0200
Subject: [PATCH] fix: Add model_max_length model_kwargs parameter to HF
 PromptNode (#4651)

---
 .../prompt/invocation_layer/hugging_face.py | 34 ++++++++++++++++---
 test/prompt/test_prompt_node.py             | 24 +++++++++++++
 2 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py
index 7a8a6d61e..e162b422d 100644
--- a/haystack/nodes/prompt/invocation_layer/hugging_face.py
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -48,8 +48,16 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
         kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
         includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
-        For more details about these kwargs, see
+        For more details about pipeline kwargs in general, see
         Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
+
+        This layer supports two additional kwargs: generation_kwargs and model_max_length.
+
+        The generation_kwargs are used to customize text generation for the underlying pipeline. See Hugging
+        Face [docs](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
+        for more details.
+
+        The model_max_length is used to specify the custom sequence length for the underlying pipeline.
         """
         super().__init__(model_name_or_path)
         self.use_auth_token = use_auth_token
@@ -79,6 +87,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
                 "torch_dtype",
                 "device_map",
                 "generation_kwargs",
+                "model_max_length",
             ]
             if key in kwargs
         }
@@ -89,6 +98,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
 
         # save generation_kwargs for pipeline invocation
         self.generation_kwargs = model_input_kwargs.pop("generation_kwargs", {})
+        model_max_length = model_input_kwargs.pop("model_max_length", None)
 
         torch_dtype = model_input_kwargs.get("torch_dtype")
         if torch_dtype is not None:
@@ -121,6 +131,19 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         # max_length must be set otherwise HFLocalInvocationLayer._ensure_token_limit will fail.
         self.max_length = max_length or self.pipe.model.config.max_length
 
+        # we allow users to override the tokenizer's model_max_length because models like T5 have relative positional
+        # embeddings and can accept sequences of more than 512 tokens
+        if model_max_length is not None:
+            self.pipe.tokenizer.model_max_length = model_max_length
+
+        if self.max_length > self.pipe.tokenizer.model_max_length:
+            logger.warning(
+                "The max_length %s is greater than model_max_length %s. This might result in truncation of the "
+                "generated text. Please lower the max_length (number of answer tokens) parameter!",
+                self.max_length,
+                self.pipe.tokenizer.model_max_length,
+            )
+
     def invoke(self, *args, **kwargs):
         """
         It takes a prompt and returns a list of generated texts using the local Hugging Face transformers model
@@ -194,9 +217,10 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
 
         :param prompt: Prompt text to be sent to the generative model.
         """
+        model_max_length = self.pipe.tokenizer.model_max_length
         n_prompt_tokens = len(self.pipe.tokenizer.tokenize(prompt))
         n_answer_tokens = self.max_length
-        if (n_prompt_tokens + n_answer_tokens) <= self.pipe.tokenizer.model_max_length:
+        if (n_prompt_tokens + n_answer_tokens) <= model_max_length:
             return prompt
 
         logger.warning(
@@ -204,14 +228,14 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
             "answer length (%s tokens) fit within the max token limit (%s tokens). "
             "Shorten the prompt to prevent it from being cut off",
             n_prompt_tokens,
-            self.pipe.tokenizer.model_max_length - n_answer_tokens,
+            max(0, model_max_length - n_answer_tokens),
             n_answer_tokens,
-            self.pipe.tokenizer.model_max_length,
+            model_max_length,
         )
 
         tokenized_payload = self.pipe.tokenizer.tokenize(prompt)
         decoded_string = self.pipe.tokenizer.convert_tokens_to_string(
-            tokenized_payload[: self.pipe.tokenizer.model_max_length - n_answer_tokens]
+            tokenized_payload[: model_max_length - n_answer_tokens]
         )
         return decoded_string
 
diff --git a/test/prompt/test_prompt_node.py b/test/prompt/test_prompt_node.py
index d6c918854..fa97a8757 100644
--- a/test/prompt/test_prompt_node.py
+++ b/test/prompt/test_prompt_node.py
@@ -302,6 +302,30 @@ def test_stop_words(prompt_model):
     assert "capital" in r[0] or "Germany" in r[0]
 
 
+@pytest.mark.unit
+def test_prompt_node_model_max_length(caplog):
+    prompt = "This is a prompt " * 5  # (26 tokens with t5 flan tokenizer)
+
+    # test that model_max_length is set to 1024
+    # test that model doesn't truncate the prompt if it is shorter than
+    # the model max length minus the length of the output
+    # no warning is raised
+    node = PromptNode(model_kwargs={"model_max_length": 1024})
+    assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 1024
+    with caplog.at_level(logging.WARNING):
+        node.prompt(prompt)
+        assert len(caplog.text) <= 0
+
+    # test that model_max_length is set to 10
+    # test that model truncates the prompt if it is longer than the max length (10 tokens)
+    # a warning is raised
+    node = PromptNode(model_kwargs={"model_max_length": 10})
+    assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 10
+    with caplog.at_level(logging.WARNING):
+        node.prompt(prompt)
+        assert "The prompt has been truncated from 26 tokens to 0 tokens" in caplog.text
+
+
 @pytest.mark.unit
 @patch("haystack.nodes.prompt.prompt_node.PromptModel")
 def test_prompt_node_streaming_handler_on_call(mock_model):