Mirror of https://github.com/deepset-ai/haystack.git
fix: enable passing max_length for text2text-generation task (#5420)
* bug fix
* add unit test
* reformatting
* add release note
* add release note
* Update releasenotes/notes/enable-set-max-length-during-runtime-097d65e537bf800b.yaml
  Co-authored-by: bogdankostic <bogdankostic@web.de>
* Update test/prompt/invocation_layer/test_hugging_face.py
  Co-authored-by: bogdankostic <bogdankostic@web.de>
* Update test/prompt/invocation_layer/test_hugging_face.py
  Co-authored-by: bogdankostic <bogdankostic@web.de>
* Update test/prompt/invocation_layer/test_hugging_face.py
  Co-authored-by: bogdankostic <bogdankostic@web.de>
* Update test/prompt/invocation_layer/test_hugging_face.py
  Co-authored-by: bogdankostic <bogdankostic@web.de>
* bug fix

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
parent 40a2e9b56a
commit 73fa796735
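What the fix enables, in practice: before this change, a `max_length` supplied at invocation time was overwritten with the layer's constructor default on the text2text-generation path; after it, the runtime value is respected. A minimal sketch of the new behavior, modelled on the unit test added below — the import path, constructor arguments, and default flan-t5 model name are assumptions about the Haystack v1 API, not part of this diff:

from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

# Layer-level default: used whenever no max_length is given at invocation time.
# (model name and max_length default are assumptions for illustration)
layer = HFLocalInvocationLayer(model_name_or_path="google/flan-t5-base", max_length=100)

# Runtime override: with this fix, 200 reaches the underlying HF pipeline
# instead of being replaced by the constructor default of 100.
output = layer.invoke(
    prompt="What does 42 mean?",
    generation_kwargs={"max_length": 200},
)
print(output)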
@@ -266,7 +266,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         if is_text_generation:
             model_input_kwargs["max_new_tokens"] = model_input_kwargs.pop("max_length", self.max_length)
         else:
-            model_input_kwargs["max_length"] = self.max_length
+            model_input_kwargs["max_length"] = model_input_kwargs.pop("max_length", self.max_length)
 
         if stream:
             stream_handler: TokenStreamingHandler = stream_handler or DefaultTokenStreamingHandler()
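The one-line change above mirrors what the text-generation branch already does for `max_new_tokens`: `dict.pop(key, default)` returns the caller-supplied value when one is present and falls back to `self.max_length` otherwise. A self-contained sketch of that precedence rule (plain dicts and a hypothetical helper name; the real code operates on `model_input_kwargs` inline):

# Stand-alone illustration of the pop-with-default pattern used in the fix.
def resolve_max_length(model_input_kwargs: dict, layer_default: int) -> int:
    # The runtime value wins if the caller supplied one; otherwise fall back
    # to the layer's constructor default, just like the fixed else-branch.
    return model_input_kwargs.pop("max_length", layer_default)

print(resolve_max_length({"max_length": 200}, 100))  # 200 -> runtime override respected
print(resolve_max_length({}, 100))                   # 100 -> constructor default used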
releasenotes/notes/enable-set-max-length-during-runtime-097d65e537bf800b.yaml (new file)
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Enable setting the `max_length` value when running PromptNodes using local HF text2text-generation models.
test/prompt/invocation_layer/test_hugging_face.py
@@ -523,6 +523,23 @@ def test_generation_kwargs_from_invoke():
     mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
 
 
+@pytest.mark.unit
+def test_max_length_from_invoke(mock_auto_tokenizer, mock_pipeline, mock_get_task):
+    """
+    Test that max_length passed to invoke is passed to the underlying HF model
+    """
+    query = "What does 42 mean?"
+    # test that generation_kwargs are passed to the underlying HF model
+    layer = HFLocalInvocationLayer()
+    layer.invoke(prompt=query, generation_kwargs={"max_length": 200})
+    # find the call to pipeline invocation, and check that the kwargs are correct
+    assert any((call.kwargs == {"max_length": 200}) and (query in call.args) for call in mock_pipeline.mock_calls)
+
+    layer = HFLocalInvocationLayer()
+    layer.invoke(prompt=query, generation_kwargs=GenerationConfig(max_length=235))
+    assert any((call.kwargs == {"max_length": 235}) and (query in call.args) for call in mock_pipeline.mock_calls)
+
+
 @pytest.mark.unit
 def test_ensure_token_limit_positive_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
     # prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
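The new test asserts against `mock_pipeline.mock_calls` rather than `assert_called_with`, scanning every recorded call for one with the expected prompt and kwargs; the second case also implies that a `transformers` `GenerationConfig` passed as `generation_kwargs` is flattened into plain keyword arguments before the pipeline is invoked. A small standard-library sketch of that inspection pattern, independent of Haystack:

from unittest.mock import MagicMock

pipe = MagicMock()
pipe("What does 42 mean?", max_length=200)  # every call is recorded on the mock

# Same inspection pattern as the new test: look through all recorded calls
# for one whose keyword and positional arguments match what we expect.
assert any(
    call.kwargs == {"max_length": 200} and "What does 42 mean?" in call.args
    for call in pipe.mock_calls
)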