	fix: Add model_max_length model_kwargs parameter to HF PromptNode (#4651)
parent d8ac30fa47
commit 1dd6158244
@@ -48,8 +48,16 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
         kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
         includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
-        For more details about these kwargs, see
+        For more details about pipeline kwargs in general, see
         Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
+
+        This layer supports two additional kwargs: generation_kwargs and model_max_length.
+
+        The generation_kwargs are used to customize text generation for the underlying pipeline. See Hugging
+        Face [docs](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
+        for more details.
+
+        The model_max_length is used to specify the custom sequence length for the underlying pipeline.
         """
         super().__init__(model_name_or_path)
         self.use_auth_token = use_auth_token
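For context, a minimal sketch (not part of the commit) of how a caller passes the two new kwargs through PromptNode; the generation_kwargs values and the prompt are illustrative assumptions, while the model_max_length usage mirrors the test added below:

    from haystack.nodes import PromptNode

    # model_max_length overrides the tokenizer's sequence limit; generation_kwargs
    # are forwarded to the underlying Hugging Face pipeline at generation time.
    node = PromptNode(
        model_kwargs={
            "model_max_length": 1024,
            "generation_kwargs": {"do_sample": True, "top_k": 10},  # assumed values
        }
    )
    result = node.prompt("What is the capital of Germany?")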
@@ -79,6 +87,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
                 "torch_dtype",
                 "device_map",
                 "generation_kwargs",
+                "model_max_length",
             ]
             if key in kwargs
         }
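The hunk above just widens the whitelist that filters arbitrary kwargs down to those the layer understands. A standalone sketch of the pattern (names simplified from the surrounding code):

    supported = ["torch_dtype", "device_map", "generation_kwargs", "model_max_length"]
    kwargs = {"model_max_length": 1024, "unrelated_arg": 42}
    # keep only whitelisted keys; anything else is silently dropped
    model_input_kwargs = {key: kwargs[key] for key in supported if key in kwargs}
    assert model_input_kwargs == {"model_max_length": 1024}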
@@ -89,6 +98,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
 
         # save generation_kwargs for pipeline invocation
         self.generation_kwargs = model_input_kwargs.pop("generation_kwargs", {})
+        model_max_length = model_input_kwargs.pop("model_max_length", None)
 
         torch_dtype = model_input_kwargs.get("torch_dtype")
         if torch_dtype is not None:
@@ -121,6 +131,19 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         # max_length must be set otherwise HFLocalInvocationLayer._ensure_token_limit will fail.
         self.max_length = max_length or self.pipe.model.config.max_length
 
+        # we allow users to override the tokenizer's model_max_length because models like T5 have relative positional
+        # embeddings and can accept sequences of more than 512 tokens
+        if model_max_length is not None:
+            self.pipe.tokenizer.model_max_length = model_max_length
+
+        if self.max_length > self.pipe.tokenizer.model_max_length:
+            logger.warning(
+                "The max_length %s is greater than model_max_length %s. This might result in truncation of the "
+                "generated text. Please lower the max_length (number of answer tokens) parameter!",
+                self.max_length,
+                self.pipe.tokenizer.model_max_length,
+            )
+
     def invoke(self, *args, **kwargs):
         """
         It takes a prompt and returns a list of generated texts using the local Hugging Face transformers model
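The override itself is plain attribute assignment on the pipeline's tokenizer. The same idea outside Haystack, assuming google/flan-t5-base (the model name is an assumption, not taken from the commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")  # assumed model
    print(tokenizer.model_max_length)  # T5 tokenizers report 512 by default
    # T5 uses relative positional embeddings, so longer sequences are safe
    tokenizer.model_max_length = 1024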
@@ -194,9 +217,10 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
 
         :param prompt: Prompt text to be sent to the generative model.
         """
+        model_max_length = self.pipe.tokenizer.model_max_length
         n_prompt_tokens = len(self.pipe.tokenizer.tokenize(prompt))
         n_answer_tokens = self.max_length
-        if (n_prompt_tokens + n_answer_tokens) <= self.pipe.tokenizer.model_max_length:
+        if (n_prompt_tokens + n_answer_tokens) <= model_max_length:
             return prompt
 
         logger.warning(
@@ -204,14 +228,14 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
             "answer length (%s tokens) fit within the max token limit (%s tokens). "
             "Shorten the prompt to prevent it from being cut off",
             n_prompt_tokens,
-            self.pipe.tokenizer.model_max_length - n_answer_tokens,
+            max(0, model_max_length - n_answer_tokens),
             n_answer_tokens,
-            self.pipe.tokenizer.model_max_length,
+            model_max_length,
         )
 
         tokenized_payload = self.pipe.tokenizer.tokenize(prompt)
         decoded_string = self.pipe.tokenizer.convert_tokens_to_string(
-            tokenized_payload[: self.pipe.tokenizer.model_max_length - n_answer_tokens]
+            tokenized_payload[: model_max_length - n_answer_tokens]
         )
         return decoded_string
 
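When the prompt plus answer budget exceeds the limit, only the first model_max_length - n_answer_tokens prompt tokens survive. A standalone sketch of that slice with a real tokenizer (the model name and numbers are assumptions):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("google/flan-t5-small")  # assumed model
    model_max_length, n_answer_tokens = 20, 5
    tokens = tok.tokenize("This is a prompt " * 5)  # roughly 26 tokens
    kept = tokens[: model_max_length - n_answer_tokens]  # first 15 tokens survive
    truncated_prompt = tok.convert_tokens_to_string(kept)

Note that a negative budget (model_max_length smaller than n_answer_tokens) makes the slice empty, which is exactly what the new test below exercises.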
@@ -302,6 +302,30 @@ def test_stop_words(prompt_model):
     assert "capital" in r[0] or "Germany" in r[0]
 
 
+@pytest.mark.unit
+def test_prompt_node_model_max_length(caplog):
+    prompt = "This is a prompt " * 5  # (26 tokens with t5 flan tokenizer)
+
+    # test that model_max_length is set to 1024
+    # test that model doesn't truncate the prompt if it is shorter than
+    # the model max length minus the length of the output
+    # no warning is raised
+    node = PromptNode(model_kwargs={"model_max_length": 1024})
+    assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 1024
+    with caplog.at_level(logging.WARNING):
+        node.prompt(prompt)
+        assert len(caplog.text) <= 0
+
+    # test that model_max_length is set to 10
+    # test that model truncates the prompt if it is longer than the max length (10 tokens)
+    # a warning is raised
+    node = PromptNode(model_kwargs={"model_max_length": 10})
+    assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 10
+    with caplog.at_level(logging.WARNING):
+        node.prompt(prompt)
+        assert "The prompt has been truncated from 26 tokens to 0 tokens" in caplog.text
+
+
 @pytest.mark.unit
 @patch("haystack.nodes.prompt.prompt_node.PromptModel")
 def test_prompt_node_streaming_handler_on_call(mock_model):
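Why the second assertion expects 0 tokens: with model_max_length=10 and PromptNode's default answer budget (max_length defaults to 100; treat the exact value as an assumption here), the kept-token count bottoms out at zero:

    model_max_length = 10   # set via model_kwargs above
    n_prompt_tokens = 26    # "This is a prompt " * 5 with the flan-t5 tokenizer
    n_answer_tokens = 100   # assumed PromptNode default max_length
    # 26 + 100 > 10, so truncation kicks in and the kept-token count is
    kept = max(0, model_max_length - n_answer_tokens)  # -> 0, hence "to 0 tokens"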