fix: enable passing max_length for text2text-generation task (#5420)

* bug fix

* add unit test

* reformatting

* add release note

* add release note

* Update releasenotes/notes/enable-set-max-length-during-runtime-097d65e537bf800b.yaml

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* bug fix

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
This commit is contained in:
Fanli Lin 2023-08-02 20:13:30 +08:00 committed by GitHub
parent 40a2e9b56a
commit 73fa796735
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 1 deletions

View File

@ -266,7 +266,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
if is_text_generation:
model_input_kwargs["max_new_tokens"] = model_input_kwargs.pop("max_length", self.max_length)
else:
model_input_kwargs["max_length"] = self.max_length
model_input_kwargs["max_length"] = model_input_kwargs.pop("max_length", self.max_length)
if stream:
stream_handler: TokenStreamingHandler = stream_handler or DefaultTokenStreamingHandler()

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Enable setting the `max_length` value when running PromptNodes using local HF text2text-generation models.

View File

@ -523,6 +523,23 @@ def test_generation_kwargs_from_invoke():
mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
@pytest.mark.unit
def test_max_length_from_invoke(mock_auto_tokenizer, mock_pipeline, mock_get_task):
"""
Test that max_length passed to invoke are passed to the underlying HF model
"""
query = "What does 42 mean?"
# test that generation_kwargs are passed to the underlying HF model
layer = HFLocalInvocationLayer()
layer.invoke(prompt=query, generation_kwargs={"max_length": 200})
# find the call to pipeline invocation, and check that the kwargs are correct
assert any((call.kwargs == {"max_length": 200}) and (query in call.args) for call in mock_pipeline.mock_calls)
layer = HFLocalInvocationLayer()
layer.invoke(prompt=query, generation_kwargs=GenerationConfig(max_length=235))
assert any((call.kwargs == {"max_length": 235}) and (query in call.args) for call in mock_pipeline.mock_calls)
@pytest.mark.unit
def test_ensure_token_limit_positive_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
# prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize