chore: Simplify DefaultPromptHandler logic and add tests (#4979)

* Simplify DefaultPromptHandler logic and add tests

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>

* Remove commented code

* Split single unit test into multiple tests

---------

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
This commit is contained in:
Silvano Cerza 2023-05-29 12:13:32 +02:00 committed by GitHub
parent 7001aee3fe
commit 37518c8b8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 11 deletions

View File

@ -72,16 +72,16 @@ class DefaultPromptHandler:
new_prompt_length = 0
if prompt:
prompt_length = len(self.tokenizer.tokenize(prompt))
tokenized_prompt = self.tokenizer.tokenize(prompt)
prompt_length = len(tokenized_prompt)
if (prompt_length + self.max_length) <= self.model_max_length:
resized_prompt = prompt
new_prompt_length = prompt_length
else:
tokenized_payload = self.tokenizer.tokenize(prompt)
resized_prompt = self.tokenizer.convert_tokens_to_string(
tokenized_payload[: self.model_max_length - self.max_length]
tokenized_prompt[: self.model_max_length - self.max_length]
)
new_prompt_length = len(tokenized_payload[: self.model_max_length - self.max_length])
new_prompt_length = len(tokenized_prompt[: self.model_max_length - self.max_length])
return {
"resized_prompt": resized_prompt,

View File

@ -37,11 +37,8 @@ def test_gpt2_prompt_handler():
@pytest.mark.integration
def test_flan_prompt_handler():
# test google/flan-t5-xxl tokenizer
def test_flan_prompt_handler_no_resize():
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
# test no resize
assert handler("This is a test") == {
"prompt_length": 5,
"resized_prompt": "This is a test",
@ -50,7 +47,10 @@ def test_flan_prompt_handler():
"new_prompt_length": 5,
}
# test resize
@pytest.mark.integration
def test_flan_prompt_handler_resize():
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
assert handler("This is a prompt that will be resized because it is longer than allowed") == {
"prompt_length": 17,
"resized_prompt": "This is a prompt that will be re",
@ -59,7 +59,10 @@ def test_flan_prompt_handler():
"new_prompt_length": 10,
}
# test corner cases
@pytest.mark.integration
def test_flan_prompt_handler_empty_string():
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
assert handler("") == {
"prompt_length": 0,
"resized_prompt": "",
@ -68,7 +71,10 @@ def test_flan_prompt_handler():
"new_prompt_length": 0,
}
# test corner case
@pytest.mark.integration
def test_flan_prompt_handler_none():
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
assert handler(None) == {
"prompt_length": 0,
"resized_prompt": None,