chore: Simplify DefaultPromptHandler logic and add tests (#4979)

* Simplify DefaultPromptHandler logic and add tests

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>

* Remove commented code

* Split single unit test into multiple tests

---------

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
This commit is contained in:
Silvano Cerza 2023-05-29 12:13:32 +02:00 committed by GitHub
parent 7001aee3fe
commit 37518c8b8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 11 deletions

View File

@@ -72,16 +72,16 @@ class DefaultPromptHandler:
new_prompt_length = 0 new_prompt_length = 0
if prompt: if prompt:
prompt_length = len(self.tokenizer.tokenize(prompt)) tokenized_prompt = self.tokenizer.tokenize(prompt)
prompt_length = len(tokenized_prompt)
if (prompt_length + self.max_length) <= self.model_max_length: if (prompt_length + self.max_length) <= self.model_max_length:
resized_prompt = prompt resized_prompt = prompt
new_prompt_length = prompt_length new_prompt_length = prompt_length
else: else:
tokenized_payload = self.tokenizer.tokenize(prompt)
resized_prompt = self.tokenizer.convert_tokens_to_string( resized_prompt = self.tokenizer.convert_tokens_to_string(
tokenized_payload[: self.model_max_length - self.max_length] tokenized_prompt[: self.model_max_length - self.max_length]
) )
new_prompt_length = len(tokenized_payload[: self.model_max_length - self.max_length]) new_prompt_length = len(tokenized_prompt[: self.model_max_length - self.max_length])
return { return {
"resized_prompt": resized_prompt, "resized_prompt": resized_prompt,

View File

@@ -37,11 +37,8 @@ def test_gpt2_prompt_handler():
@pytest.mark.integration @pytest.mark.integration
def test_flan_prompt_handler(): def test_flan_prompt_handler_no_resize():
# test google/flan-t5-xxl tokenizer
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10) handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
# test no resize
assert handler("This is a test") == { assert handler("This is a test") == {
"prompt_length": 5, "prompt_length": 5,
"resized_prompt": "This is a test", "resized_prompt": "This is a test",
@@ -50,7 +47,10 @@ def test_flan_prompt_handler():
"new_prompt_length": 5, "new_prompt_length": 5,
} }
# test resize
@pytest.mark.integration
def test_flan_prompt_handler_resize():
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
assert handler("This is a prompt that will be resized because it is longer than allowed") == { assert handler("This is a prompt that will be resized because it is longer than allowed") == {
"prompt_length": 17, "prompt_length": 17,
"resized_prompt": "This is a prompt that will be re", "resized_prompt": "This is a prompt that will be re",
@@ -59,7 +59,10 @@ def test_flan_prompt_handler():
"new_prompt_length": 10, "new_prompt_length": 10,
} }
# test corner cases
@pytest.mark.integration
def test_flan_prompt_handler_empty_string():
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
assert handler("") == { assert handler("") == {
"prompt_length": 0, "prompt_length": 0,
"resized_prompt": "", "resized_prompt": "",
@@ -68,7 +71,10 @@ def test_flan_prompt_handler():
"new_prompt_length": 0, "new_prompt_length": 0,
} }
# test corner case
@pytest.mark.integration
def test_flan_prompt_handler_none():
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
assert handler(None) == { assert handler(None) == {
"prompt_length": 0, "prompt_length": 0,
"resized_prompt": None, "resized_prompt": None,