fix: Add model_max_length model_kwargs parameter to HF PromptNode (#4651)

Vladimir Blagojevic 2023-04-14 15:40:42 +02:00 committed by GitHub
parent d8ac30fa47
commit 1dd6158244
2 changed files with 53 additions and 5 deletions

View File

@@ -48,8 +48,16 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
        all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
        kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
        includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
        For more details about pipeline kwargs in general, see
        Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).

        This layer supports two additional kwargs: generation_kwargs and model_max_length.

        The generation_kwargs are used to customize text generation for the underlying pipeline. See Hugging
        Face [docs](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
        for more details.

        The model_max_length is used to specify a custom maximum sequence length for the underlying pipeline.
        """
        super().__init__(model_name_or_path)
        self.use_auth_token = use_auth_token
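As a quick usage sketch (not part of the diff; the model name and generation settings are illustrative), both new kwargs are passed through PromptNode's model_kwargs and travel down to this invocation layer, which ignores anything it does not support:

from haystack.nodes import PromptNode

node = PromptNode(
    "google/flan-t5-base",
    max_length=100,  # number of answer tokens to generate
    model_kwargs={
        "model_max_length": 1024,  # raise the tokenizer's sequence limit
        "generation_kwargs": {"do_sample": True, "top_k": 10},  # HF text-generation settings
    },
)
print(node.prompt("What is the capital of Germany?"))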
@@ -79,6 +87,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
                "torch_dtype",
                "device_map",
                "generation_kwargs",
                "model_max_length",
            ]
            if key in kwargs
        }
@@ -89,6 +98,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
        # save generation_kwargs for pipeline invocation
        self.generation_kwargs = model_input_kwargs.pop("generation_kwargs", {})
        model_max_length = model_input_kwargs.pop("model_max_length", None)

        torch_dtype = model_input_kwargs.get("torch_dtype")
        if torch_dtype is not None:
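The two hunks above follow a filter-then-pop pattern: unrelated kwargs are dropped first, then the kwargs that are not meant for the HF pipeline() call itself are pulled out. A standalone sketch of the same pattern (function and variable names are illustrative only):

def split_kwargs(**kwargs):
    # keep only the kwargs this layer understands
    supported = ["torch_dtype", "device_map", "generation_kwargs", "model_max_length"]
    model_input_kwargs = {key: kwargs[key] for key in supported if key in kwargs}
    # separate the kwargs that should not be forwarded to transformers.pipeline()
    generation_kwargs = model_input_kwargs.pop("generation_kwargs", {})
    model_max_length = model_input_kwargs.pop("model_max_length", None)
    return model_input_kwargs, generation_kwargs, model_max_length

print(split_kwargs(model_max_length=1024, unrelated_kwarg="ignored"))
# -> ({}, {}, 1024)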
@@ -121,6 +131,19 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
        # max_length must be set otherwise HFLocalInvocationLayer._ensure_token_limit will fail.
        self.max_length = max_length or self.pipe.model.config.max_length

        # we allow users to override the tokenizer's model_max_length because models like T5 have relative positional
        # embeddings and can accept sequences of more than 512 tokens
        if model_max_length is not None:
            self.pipe.tokenizer.model_max_length = model_max_length

        if self.max_length > self.pipe.tokenizer.model_max_length:
            logger.warning(
                "The max_length %s is greater than model_max_length %s. This might result in truncation of the "
                "generated text. Please lower the max_length (number of answer tokens) parameter!",
                self.max_length,
                self.pipe.tokenizer.model_max_length,
            )

    def invoke(self, *args, **kwargs):
        """
        It takes a prompt and returns a list of generated texts using the local Hugging Face transformers model
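The tokenizer override matters mostly for T5-style models; a rough standalone illustration (the model name is just an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
print(tokenizer.model_max_length)  # T5 tokenizers typically report 512

# T5 uses relative position embeddings, so longer sequences are generally workable;
# raising the limit lets the pipeline accept prompts beyond 512 tokens.
tokenizer.model_max_length = 1024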
@@ -194,9 +217,10 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
        :param prompt: Prompt text to be sent to the generative model.
        """
        model_max_length = self.pipe.tokenizer.model_max_length
        n_prompt_tokens = len(self.pipe.tokenizer.tokenize(prompt))
        n_answer_tokens = self.max_length
        if (n_prompt_tokens + n_answer_tokens) <= model_max_length:
            return prompt

        logger.warning(
@@ -204,14 +228,14 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
            "answer length (%s tokens) fit within the max token limit (%s tokens). "
            "Shorten the prompt to prevent it from being cut off",
            n_prompt_tokens,
            max(0, model_max_length - n_answer_tokens),
            n_answer_tokens,
            model_max_length,
        )

        tokenized_payload = self.pipe.tokenizer.tokenize(prompt)
        decoded_string = self.pipe.tokenizer.convert_tokens_to_string(
            tokenized_payload[: model_max_length - n_answer_tokens]
        )

        return decoded_string
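The guard in _ensure_token_limit is plain token arithmetic; a self-contained sketch of the same idea (the small limits are made up for illustration, assuming transformers and sentencepiece are installed):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model_max_length = 20   # pretend sequence limit
n_answer_tokens = 10    # tokens reserved for the generated answer

prompt = "This is a prompt " * 5
tokens = tokenizer.tokenize(prompt)
if len(tokens) + n_answer_tokens > model_max_length:
    # keep only as many prompt tokens as still fit alongside the answer
    tokens = tokens[: model_max_length - n_answer_tokens]
print(len(tokens), tokenizer.convert_tokens_to_string(tokens))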

View File

@@ -302,6 +302,30 @@ def test_stop_words(prompt_model):
    assert "capital" in r[0] or "Germany" in r[0]


@pytest.mark.unit
def test_prompt_node_model_max_length(caplog):
    prompt = "This is a prompt " * 5  # (26 tokens with t5 flan tokenizer)

    # test that model_max_length is set to 1024
    # test that model doesn't truncate the prompt if it is shorter than
    # the model max length minus the length of the output
    # no warning is raised
    node = PromptNode(model_kwargs={"model_max_length": 1024})
    assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 1024
    with caplog.at_level(logging.WARNING):
        node.prompt(prompt)
        assert len(caplog.text) <= 0

    # test that model_max_length is set to 10
    # test that model truncates the prompt if it is longer than the max length (10 tokens)
    # a warning is raised
    node = PromptNode(model_kwargs={"model_max_length": 10})
    assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 10
    with caplog.at_level(logging.WARNING):
        node.prompt(prompt)
        assert "The prompt has been truncated from 26 tokens to 0 tokens" in caplog.text
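The "0 tokens" in the last assertion falls out of the slice in _ensure_token_limit: with model_max_length=10 and an answer budget at least that large (assuming the node's default answer length of roughly 100 tokens), no prompt tokens fit. Rough arithmetic:

model_max_length = 10
n_answer_tokens = 100            # assumed default answer length of the node
tokens = ["tok"] * 26            # the 26-token prompt from the test
kept = tokens[: model_max_length - n_answer_tokens]   # negative bound -> nothing kept
print(len(kept))                 # 0 -> "truncated from 26 tokens to 0 tokens"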


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_node_streaming_handler_on_call(mock_model):