From 11440395f4f1055f52f483d0adcec9aff13e635b Mon Sep 17 00:00:00 2001
From: bogdankostic
Date: Fri, 1 Sep 2023 11:48:41 +0200
Subject: [PATCH] fix: Set model_max_length in the Tokenizer of
 `DefaultPromptHandler` (#5596)

* Set model_max_length in tokenizer in prompt handler

* Add release note
---
 haystack/nodes/prompt/invocation_layer/handlers.py     |  1 +
 ...del_max_length-prompt_handler-7f34c40c62a8c55b.yaml |  4 ++++
 test/prompt/test_handlers.py                           | 10 ++++++++++
 3 files changed, 15 insertions(+)
 create mode 100644 releasenotes/notes/fix-model_max_length-prompt_handler-7f34c40c62a8c55b.yaml

diff --git a/haystack/nodes/prompt/invocation_layer/handlers.py b/haystack/nodes/prompt/invocation_layer/handlers.py
index 446ddedf0..073561b3f 100644
--- a/haystack/nodes/prompt/invocation_layer/handlers.py
+++ b/haystack/nodes/prompt/invocation_layer/handlers.py
@@ -63,6 +63,7 @@ class DefaultPromptHandler:
 
     def __init__(self, model_name_or_path: str, model_max_length: int, max_length: int = 100):
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        self.tokenizer.model_max_length = model_max_length
         self.model_max_length = model_max_length
         self.max_length = max_length
 
diff --git a/releasenotes/notes/fix-model_max_length-prompt_handler-7f34c40c62a8c55b.yaml b/releasenotes/notes/fix-model_max_length-prompt_handler-7f34c40c62a8c55b.yaml
new file mode 100644
index 000000000..d2fe46866
--- /dev/null
+++ b/releasenotes/notes/fix-model_max_length-prompt_handler-7f34c40c62a8c55b.yaml
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Fix model_max_length not being set in the Tokenizer in DefaultPromptHandler.
diff --git a/test/prompt/test_handlers.py b/test/prompt/test_handlers.py
index 2d6da1318..f7b8a3e9b 100644
--- a/test/prompt/test_handlers.py
+++ b/test/prompt/test_handlers.py
@@ -57,6 +57,13 @@ def test_prompt_handler_negative():
     }
 
 
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained")
+def test_prompt_handler_model_max_length_set_in_tokenizer(mock_tokenizer):
+    prompt_handler = DefaultPromptHandler(model_name_or_path="model_path", model_max_length=10, max_length=3)
+    assert prompt_handler.tokenizer.model_max_length == 10
+
+
 @pytest.mark.integration
 def test_prompt_handler_basics():
     handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
@@ -65,6 +72,9 @@
     handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20)
     assert handler.max_length == 100
 
+    # test model_max_length is set in tokenizer
+    assert handler.tokenizer.model_max_length == 20
+
 
 @pytest.mark.integration
 def test_gpt2_prompt_handler():
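
Note on why the one-line change matters: AutoTokenizer.from_pretrained() initializes
tokenizer.model_max_length from the model's config on the Hub, and for models that
ship no value it falls back to transformers' very large sentinel (int(1e30)). Before
this patch, DefaultPromptHandler stored its own model_max_length but the tokenizer
kept the loaded value, so tokenizer-side truncation did not respect the handler's
limit. A minimal standalone sketch of the behavior, assuming only that
`transformers` is installed (the "gpt2" model and the prompt text are illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    print(tokenizer.model_max_length)  # 1024, loaded from the gpt2 config

    # The fix overrides the loaded value with the handler's model_max_length,
    # so truncation inside the tokenizer respects the handler's limit:
    tokenizer.model_max_length = 20
    encoded = tokenizer("a long prompt " * 50, truncation=True)
    assert len(encoded["input_ids"]) <= 20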
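
With the fix applied, the handler's tokenizer stays consistent with the handler's
own limit, which the new integration assertion above exercises. The same check can
be run standalone (a sketch, assuming a Haystack install that provides the module
path used in the patched test file):

    from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler

    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
    assert handler.tokenizer.model_max_length == 20  # previously stayed at 1024 for gpt2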