fix: Add model_max_length model_kwargs parameter to HF PromptNode (#4651)

Vladimir Blagojevic 2023-04-14 15:40:42 +02:00 committed by GitHub
parent d8ac30fa47
commit 1dd6158244
2 changed files with 53 additions and 5 deletions


@@ -48,8 +48,16 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
For more details about these kwargs, see
For more details about pipeline kwargs in general, see
Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
This layer supports two additional kwargs: generation_kwargs and model_max_length.
The generation_kwargs are used to customize text generation for the underlying pipeline. See Hugging
Face [docs](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
for more details.
The model_max_length is used to specify a custom maximum sequence length for the underlying pipeline's tokenizer.
"""
super().__init__(model_name_or_path)
self.use_auth_token = use_auth_token
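A minimal usage sketch of the kwargs described in the docstring above (the model name and the generation settings are illustrative, not part of this commit):

from haystack.nodes import PromptNode

# model_max_length and generation_kwargs travel through model_kwargs;
# HFLocalInvocationLayer keeps only the kwargs it supports and ignores the rest.
node = PromptNode(
    "google/flan-t5-base",        # illustrative local HF model
    max_length=100,               # number of answer tokens to generate
    model_kwargs={
        "model_max_length": 1024,  # override the tokenizer's default sequence length
        "generation_kwargs": {"do_sample": True, "top_k": 10},  # forwarded to the HF pipeline
    },
)
print(node("What is the capital of Germany?"))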
@@ -79,6 +87,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
"torch_dtype",
"device_map",
"generation_kwargs",
"model_max_length",
]
if key in kwargs
}
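The hunk above whitelists the supported kwargs out of whatever the invocation layer receives; a self-contained sketch of that pattern (the helper name is hypothetical):

SUPPORTED_KWARGS = [
    "trust_remote_code", "revision", "feature_extractor", "tokenizer", "config",
    "use_fast", "torch_dtype", "device_map", "generation_kwargs", "model_max_length",
]

def filter_model_kwargs(kwargs: dict) -> dict:
    # Keep only the kwargs the invocation layer knows how to handle; silently drop the rest.
    return {key: kwargs[key] for key in SUPPORTED_KWARGS if key in kwargs}

model_input_kwargs = filter_model_kwargs({"model_max_length": 1024, "unrelated": 42})
# -> {"model_max_length": 1024}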
@@ -89,6 +98,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
# save generation_kwargs for pipeline invocation
self.generation_kwargs = model_input_kwargs.pop("generation_kwargs", {})
model_max_length = model_input_kwargs.pop("model_max_length", None)
torch_dtype = model_input_kwargs.get("torch_dtype")
if torch_dtype is not None:
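The torch_dtype handling is cut off by the hunk boundary; a generic sketch of resolving a torch_dtype passed either as a string or as a torch.dtype (an illustration, not necessarily the exact logic used here) could look like this:

import torch

def resolve_torch_dtype(torch_dtype):
    # Accept either an actual torch.dtype or a string such as "torch.float16" / "float16".
    if torch_dtype is None or isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    if isinstance(torch_dtype, str):
        name = torch_dtype.split("torch.")[-1]  # tolerate an optional "torch." prefix
        return getattr(torch, name)             # e.g. torch.float16, torch.bfloat16
    raise ValueError(f"Invalid torch_dtype value {torch_dtype}")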
@@ -121,6 +131,19 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
# max_length must be set, otherwise HFLocalInvocationLayer._ensure_token_limit will fail.
self.max_length = max_length or self.pipe.model.config.max_length
# we allow users to override the tokenizer's model_max_length because models like T5 have relative positional
# embeddings and can accept sequences of more than 512 tokens
if model_max_length is not None:
self.pipe.tokenizer.model_max_length = model_max_length
if self.max_length > self.pipe.tokenizer.model_max_length:
logger.warning(
"The max_length %s is greater than model_max_length %s. This might result in truncation of the "
"generated text. Please lower the max_length (number of answer tokens) parameter!",
self.max_length,
self.pipe.tokenizer.model_max_length,
)
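The override and the consistency check above can be illustrated in isolation with a transformers pipeline (the model name is illustrative; flan-t5 uses relative position embeddings, so its tokenizer's default limit of 512 can safely be raised):

from transformers import pipeline

pipe = pipeline("text2text-generation", model="google/flan-t5-small")  # illustrative model
max_length = 100  # number of answer tokens we plan to generate

# Override the tokenizer's default model_max_length (512 for flan-t5).
pipe.tokenizer.model_max_length = 1024

if max_length > pipe.tokenizer.model_max_length:
    # Mirrors the warning above: the requested answer length alone exceeds the sequence budget.
    print("max_length is greater than model_max_length; generated text may be truncated")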
def invoke(self, *args, **kwargs):
"""
Takes a prompt and returns a list of generated texts using the local Hugging Face transformers model.
@@ -194,9 +217,10 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
:param prompt: Prompt text to be sent to the generative model.
"""
model_max_length = self.pipe.tokenizer.model_max_length
n_prompt_tokens = len(self.pipe.tokenizer.tokenize(prompt))
n_answer_tokens = self.max_length
if (n_prompt_tokens + n_answer_tokens) <= self.pipe.tokenizer.model_max_length:
if (n_prompt_tokens + n_answer_tokens) <= model_max_length:
return prompt
logger.warning(
@@ -204,14 +228,14 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
"answer length (%s tokens) fit within the max token limit (%s tokens). "
"Shorten the prompt to prevent it from being cut off",
n_prompt_tokens,
self.pipe.tokenizer.model_max_length - n_answer_tokens,
max(0, model_max_length - n_answer_tokens),
n_answer_tokens,
self.pipe.tokenizer.model_max_length,
model_max_length,
)
tokenized_payload = self.pipe.tokenizer.tokenize(prompt)
decoded_string = self.pipe.tokenizer.convert_tokens_to_string(
tokenized_payload[: self.pipe.tokenizer.model_max_length - n_answer_tokens]
tokenized_payload[: model_max_length - n_answer_tokens]
)
return decoded_string
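The truncation step above can be reproduced in isolation with any Hugging Face tokenizer; a minimal sketch (the tokenizer choice and the function name are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")  # illustrative tokenizer
n_answer_tokens = 100                        # corresponds to self.max_length above
model_max_length = tokenizer.model_max_length

def ensure_token_limit(prompt: str) -> str:
    tokens = tokenizer.tokenize(prompt)
    if len(tokens) + n_answer_tokens <= model_max_length:
        return prompt  # prompt plus answer fits; nothing to do
    # Keep only as many prompt tokens as the budget allows, then detokenize.
    return tokenizer.convert_tokens_to_string(tokens[: max(0, model_max_length - n_answer_tokens)])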


@@ -302,6 +302,30 @@ def test_stop_words(prompt_model):
assert "capital" in r[0] or "Germany" in r[0]
@pytest.mark.unit
def test_prompt_node_model_max_length(caplog):
prompt = "This is a prompt " * 5 # (26 tokens with t5 flan tokenizer)
# test that model_max_length is set to 1024
# test that model doesn't truncate the prompt if it is shorter than
# the model max length minus the length of the output
# no warning is raised
node = PromptNode(model_kwargs={"model_max_length": 1024})
assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 1024
with caplog.at_level(logging.WARNING):
node.prompt(prompt)
assert len(caplog.text) <= 0
# test that model_max_length is set to 10
# test that model truncates the prompt if it is longer than the max length (10 tokens)
# a warning is raised
node = PromptNode(model_kwargs={"model_max_length": 10})
assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 10
with caplog.at_level(logging.WARNING):
node.prompt(prompt)
assert "The prompt has been truncated from 26 tokens to 0 tokens" in caplog.text
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_node_streaming_handler_on_call(mock_model):