diff --git a/haystack/nodes/prompt/invocation_layer/handlers.py b/haystack/nodes/prompt/invocation_layer/handlers.py
index 0872fc53e..5cce03617 100644
--- a/haystack/nodes/prompt/invocation_layer/handlers.py
+++ b/haystack/nodes/prompt/invocation_layer/handlers.py
@@ -72,16 +72,16 @@ class DefaultPromptHandler:
         new_prompt_length = 0
 
         if prompt:
-            prompt_length = len(self.tokenizer.tokenize(prompt))
+            tokenized_prompt = self.tokenizer.tokenize(prompt)
+            prompt_length = len(tokenized_prompt)
             if (prompt_length + self.max_length) <= self.model_max_length:
                 resized_prompt = prompt
                 new_prompt_length = prompt_length
             else:
-                tokenized_payload = self.tokenizer.tokenize(prompt)
                 resized_prompt = self.tokenizer.convert_tokens_to_string(
-                    tokenized_payload[: self.model_max_length - self.max_length]
+                    tokenized_prompt[: self.model_max_length - self.max_length]
                 )
-                new_prompt_length = len(tokenized_payload[: self.model_max_length - self.max_length])
+                new_prompt_length = len(tokenized_prompt[: self.model_max_length - self.max_length])
 
         return {
             "resized_prompt": resized_prompt,
diff --git a/test/prompt/test_handlers.py b/test/prompt/test_handlers.py
index 2f90eb4c4..4bf24d486 100644
--- a/test/prompt/test_handlers.py
+++ b/test/prompt/test_handlers.py
@@ -37,11 +37,8 @@ def test_gpt2_prompt_handler():
 
 
 @pytest.mark.integration
-def test_flan_prompt_handler():
-    # test google/flan-t5-xxl tokenizer
+def test_flan_prompt_handler_no_resize():
     handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
-
-    # test no resize
     assert handler("This is a test") == {
         "prompt_length": 5,
         "resized_prompt": "This is a test",
@@ -50,7 +47,10 @@
         "new_prompt_length": 5,
     }
 
-    # test resize
+
+@pytest.mark.integration
+def test_flan_prompt_handler_resize():
+    handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
     assert handler("This is a prompt that will be resized because it is longer than allowed") == {
         "prompt_length": 17,
         "resized_prompt": "This is a prompt that will be re",
@@ -59,7 +59,10 @@
         "new_prompt_length": 10,
     }
 
-    # test corner cases
+
+@pytest.mark.integration
+def test_flan_prompt_handler_empty_string():
+    handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
     assert handler("") == {
         "prompt_length": 0,
         "resized_prompt": "",
@@ -68,7 +71,10 @@
         "new_prompt_length": 0,
     }
 
-    # test corner case
+
+@pytest.mark.integration
+def test_flan_prompt_handler_none():
+    handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
     assert handler(None) == {
         "prompt_length": 0,
         "resized_prompt": None,