mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-27 15:59:14 +00:00
chore: Simplify DefaultPromptHandler logic and add tests (#4979)
* Simplify DefaultPromptHandler logic and add tests Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com> * Remove commented code * Split single unit test into multiple tests --------- Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
This commit is contained in:
parent
7001aee3fe
commit
37518c8b8c
@ -72,16 +72,16 @@ class DefaultPromptHandler:
|
||||
new_prompt_length = 0
|
||||
|
||||
if prompt:
|
||||
prompt_length = len(self.tokenizer.tokenize(prompt))
|
||||
tokenized_prompt = self.tokenizer.tokenize(prompt)
|
||||
prompt_length = len(tokenized_prompt)
|
||||
if (prompt_length + self.max_length) <= self.model_max_length:
|
||||
resized_prompt = prompt
|
||||
new_prompt_length = prompt_length
|
||||
else:
|
||||
tokenized_payload = self.tokenizer.tokenize(prompt)
|
||||
resized_prompt = self.tokenizer.convert_tokens_to_string(
|
||||
tokenized_payload[: self.model_max_length - self.max_length]
|
||||
tokenized_prompt[: self.model_max_length - self.max_length]
|
||||
)
|
||||
new_prompt_length = len(tokenized_payload[: self.model_max_length - self.max_length])
|
||||
new_prompt_length = len(tokenized_prompt[: self.model_max_length - self.max_length])
|
||||
|
||||
return {
|
||||
"resized_prompt": resized_prompt,
|
||||
|
||||
@ -37,11 +37,8 @@ def test_gpt2_prompt_handler():
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_flan_prompt_handler():
|
||||
# test google/flan-t5-xxl tokenizer
|
||||
def test_flan_prompt_handler_no_resize():
|
||||
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
|
||||
|
||||
# test no resize
|
||||
assert handler("This is a test") == {
|
||||
"prompt_length": 5,
|
||||
"resized_prompt": "This is a test",
|
||||
@ -50,7 +47,10 @@ def test_flan_prompt_handler():
|
||||
"new_prompt_length": 5,
|
||||
}
|
||||
|
||||
# test resize
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_flan_prompt_handler_resize():
|
||||
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
|
||||
assert handler("This is a prompt that will be resized because it is longer than allowed") == {
|
||||
"prompt_length": 17,
|
||||
"resized_prompt": "This is a prompt that will be re",
|
||||
@ -59,7 +59,10 @@ def test_flan_prompt_handler():
|
||||
"new_prompt_length": 10,
|
||||
}
|
||||
|
||||
# test corner cases
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_flan_prompt_handler_empty_string():
|
||||
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
|
||||
assert handler("") == {
|
||||
"prompt_length": 0,
|
||||
"resized_prompt": "",
|
||||
@ -68,7 +71,10 @@ def test_flan_prompt_handler():
|
||||
"new_prompt_length": 0,
|
||||
}
|
||||
|
||||
# test corner case
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_flan_prompt_handler_none():
|
||||
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
|
||||
assert handler(None) == {
|
||||
"prompt_length": 0,
|
||||
"resized_prompt": None,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user