From 73fa7967351db597e3ffb154ca7ee6c11b8098d0 Mon Sep 17 00:00:00 2001
From: Fanli Lin
Date: Wed, 2 Aug 2023 20:13:30 +0800
Subject: [PATCH] fix: enable passing `max_length` for text2text-generation task (#5420)

* bug fix

* add unit test

* reformatting

* add release note

* add release note

* Update releasenotes/notes/enable-set-max-length-during-runtime-097d65e537bf800b.yaml

Co-authored-by: bogdankostic

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic

* bug fix

---------

Co-authored-by: bogdankostic
---
 .../prompt/invocation_layer/hugging_face.py      |  2 +-
 ...-length-during-runtime-097d65e537bf800b.yaml  |  4 ++++
 .../invocation_layer/test_hugging_face.py        | 17 +++++++++++++++++
 3 files changed, 22 insertions(+), 1 deletion(-)
 create mode 100644 releasenotes/notes/enable-set-max-length-during-runtime-097d65e537bf800b.yaml

diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py
index b96af10ad..278e83e1a 100644
--- a/haystack/nodes/prompt/invocation_layer/hugging_face.py
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -266,7 +266,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         if is_text_generation:
             model_input_kwargs["max_new_tokens"] = model_input_kwargs.pop("max_length", self.max_length)
         else:
-            model_input_kwargs["max_length"] = self.max_length
+            model_input_kwargs["max_length"] = model_input_kwargs.pop("max_length", self.max_length)
 
         if stream:
             stream_handler: TokenStreamingHandler = stream_handler or DefaultTokenStreamingHandler()
diff --git a/releasenotes/notes/enable-set-max-length-during-runtime-097d65e537bf800b.yaml b/releasenotes/notes/enable-set-max-length-during-runtime-097d65e537bf800b.yaml
new file mode 100644
index 000000000..32c5219fd
--- /dev/null
+++ b/releasenotes/notes/enable-set-max-length-during-runtime-097d65e537bf800b.yaml
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Enable setting the `max_length` value when running PromptNodes using local HF text2text-generation models.
diff --git a/test/prompt/invocation_layer/test_hugging_face.py b/test/prompt/invocation_layer/test_hugging_face.py
index 5e45d39d0..e2e721a7c 100644
--- a/test/prompt/invocation_layer/test_hugging_face.py
+++ b/test/prompt/invocation_layer/test_hugging_face.py
@@ -523,6 +523,23 @@ def test_generation_kwargs_from_invoke():
     mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
 
 
+@pytest.mark.unit
+def test_max_length_from_invoke(mock_auto_tokenizer, mock_pipeline, mock_get_task):
+    """
+    Test that max_length passed to invoke is passed to the underlying HF model
+    """
+    query = "What does 42 mean?"
+    # test that the max_length generation kwarg is passed to the underlying HF model
+    layer = HFLocalInvocationLayer()
+    layer.invoke(prompt=query, generation_kwargs={"max_length": 200})
+    # find the call to pipeline invocation, and check that the kwargs are correct
+    assert any((call.kwargs == {"max_length": 200}) and (query in call.args) for call in mock_pipeline.mock_calls)
+
+    layer = HFLocalInvocationLayer()
+    layer.invoke(prompt=query, generation_kwargs=GenerationConfig(max_length=235))
+    assert any((call.kwargs == {"max_length": 235}) and (query in call.args) for call in mock_pipeline.mock_calls)
+
+
 @pytest.mark.unit
 def test_ensure_token_limit_positive_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
     # prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
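
For reference, a minimal usage sketch of the behavior this patch enables, assuming the Haystack v1.x invocation-layer API exercised by the test above (HFLocalInvocationLayer loads a local HF text2text-generation model by default); the prompt and max_length values are illustrative:

    from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer

    # Build the layer with its defaults. Before this patch, a max_length passed at
    # invoke time was ignored for text2text-generation models and the value set at
    # construction time was always used.
    layer = HFLocalInvocationLayer()

    # With the patch, the runtime value below takes precedence over the default.
    output = layer.invoke(
        prompt="What does 42 mean?",
        generation_kwargs={"max_length": 200},
    )
    print(output)  # list containing the generated text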