mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-30 01:09:43 +00:00
fix: Add model_max_length model_kwargs parameter to HF PromptNode (#4651)
This commit is contained in:
parent
d8ac30fa47
commit
1dd6158244
@ -48,8 +48,16 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
|
||||
kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
|
||||
includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
|
||||
For more details about these kwargs, see
|
||||
For more details about pipeline kwargs in general, see
|
||||
Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
|
||||
|
||||
This layer supports two additional kwargs: generation_kwargs and model_max_length.
|
||||
|
||||
The generation_kwargs are used to customize text generation for the underlying pipeline. See Hugging
|
||||
Face [docs](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
|
||||
for more details.
|
||||
|
||||
The model_max_length is used to specify the custom sequence length for the underlying pipeline.
|
||||
"""
|
||||
super().__init__(model_name_or_path)
|
||||
self.use_auth_token = use_auth_token
|
||||
@ -79,6 +87,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
"torch_dtype",
|
||||
"device_map",
|
||||
"generation_kwargs",
|
||||
"model_max_length",
|
||||
]
|
||||
if key in kwargs
|
||||
}
|
||||
@ -89,6 +98,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
|
||||
# save generation_kwargs for pipeline invocation
|
||||
self.generation_kwargs = model_input_kwargs.pop("generation_kwargs", {})
|
||||
model_max_length = model_input_kwargs.pop("model_max_length", None)
|
||||
|
||||
torch_dtype = model_input_kwargs.get("torch_dtype")
|
||||
if torch_dtype is not None:
|
||||
@ -121,6 +131,19 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
# max_length must be set otherwise HFLocalInvocationLayer._ensure_token_limit will fail.
|
||||
self.max_length = max_length or self.pipe.model.config.max_length
|
||||
|
||||
# we allow users to override the tokenizer's model_max_length because models like T5 have relative positional
|
||||
# embeddings and can accept sequences of more than 512 tokens
|
||||
if model_max_length is not None:
|
||||
self.pipe.tokenizer.model_max_length = model_max_length
|
||||
|
||||
if self.max_length > self.pipe.tokenizer.model_max_length:
|
||||
logger.warning(
|
||||
"The max_length %s is greater than model_max_length %s. This might result in truncation of the "
|
||||
"generated text. Please lower the max_length (number of answer tokens) parameter!",
|
||||
self.max_length,
|
||||
self.pipe.tokenizer.model_max_length,
|
||||
)
|
||||
|
||||
def invoke(self, *args, **kwargs):
|
||||
"""
|
||||
It takes a prompt and returns a list of generated texts using the local Hugging Face transformers model
|
||||
@ -194,9 +217,10 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
|
||||
:param prompt: Prompt text to be sent to the generative model.
|
||||
"""
|
||||
model_max_length = self.pipe.tokenizer.model_max_length
|
||||
n_prompt_tokens = len(self.pipe.tokenizer.tokenize(prompt))
|
||||
n_answer_tokens = self.max_length
|
||||
if (n_prompt_tokens + n_answer_tokens) <= self.pipe.tokenizer.model_max_length:
|
||||
if (n_prompt_tokens + n_answer_tokens) <= model_max_length:
|
||||
return prompt
|
||||
|
||||
logger.warning(
|
||||
@ -204,14 +228,14 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
"answer length (%s tokens) fit within the max token limit (%s tokens). "
|
||||
"Shorten the prompt to prevent it from being cut off",
|
||||
n_prompt_tokens,
|
||||
self.pipe.tokenizer.model_max_length - n_answer_tokens,
|
||||
max(0, model_max_length - n_answer_tokens),
|
||||
n_answer_tokens,
|
||||
self.pipe.tokenizer.model_max_length,
|
||||
model_max_length,
|
||||
)
|
||||
|
||||
tokenized_payload = self.pipe.tokenizer.tokenize(prompt)
|
||||
decoded_string = self.pipe.tokenizer.convert_tokens_to_string(
|
||||
tokenized_payload[: self.pipe.tokenizer.model_max_length - n_answer_tokens]
|
||||
tokenized_payload[: model_max_length - n_answer_tokens]
|
||||
)
|
||||
return decoded_string
|
||||
|
||||
|
||||
@ -302,6 +302,30 @@ def test_stop_words(prompt_model):
|
||||
assert "capital" in r[0] or "Germany" in r[0]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_prompt_node_model_max_length(caplog):
|
||||
prompt = "This is a prompt " * 5 # (26 tokens with t5 flan tokenizer)
|
||||
|
||||
# test that model_max_length is set to 1024
|
||||
# test that model doesn't truncate the prompt if it is shorter than
|
||||
# the model max length minus the length of the output
|
||||
# no warning is raised
|
||||
node = PromptNode(model_kwargs={"model_max_length": 1024})
|
||||
assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 1024
|
||||
with caplog.at_level(logging.WARNING):
|
||||
node.prompt(prompt)
|
||||
assert len(caplog.text) <= 0
|
||||
|
||||
# test that model_max_length is set to 10
|
||||
# test that model truncates the prompt if it is longer than the max length (10 tokens)
|
||||
# a warning is raised
|
||||
node = PromptNode(model_kwargs={"model_max_length": 10})
|
||||
assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 10
|
||||
with caplog.at_level(logging.WARNING):
|
||||
node.prompt(prompt)
|
||||
assert "The prompt has been truncated from 26 tokens to 0 tokens" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
|
||||
def test_prompt_node_streaming_handler_on_call(mock_model):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user