mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-28 00:08:41 +00:00
chore: Simplify DefaultPromptHandler logic and add tests (#4979)
* Simplify DefaultPromptHandler logic and add tests
  Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
* Remove commented code
* Split single unit test into multiple tests
---------
Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
This commit is contained in:
parent
7001aee3fe
commit
37518c8b8c
@ -72,16 +72,16 @@ class DefaultPromptHandler:
|
|||||||
new_prompt_length = 0
|
new_prompt_length = 0
|
||||||
|
|
||||||
if prompt:
|
if prompt:
|
||||||
prompt_length = len(self.tokenizer.tokenize(prompt))
|
tokenized_prompt = self.tokenizer.tokenize(prompt)
|
||||||
|
prompt_length = len(tokenized_prompt)
|
||||||
if (prompt_length + self.max_length) <= self.model_max_length:
|
if (prompt_length + self.max_length) <= self.model_max_length:
|
||||||
resized_prompt = prompt
|
resized_prompt = prompt
|
||||||
new_prompt_length = prompt_length
|
new_prompt_length = prompt_length
|
||||||
else:
|
else:
|
||||||
tokenized_payload = self.tokenizer.tokenize(prompt)
|
|
||||||
resized_prompt = self.tokenizer.convert_tokens_to_string(
|
resized_prompt = self.tokenizer.convert_tokens_to_string(
|
||||||
tokenized_payload[: self.model_max_length - self.max_length]
|
tokenized_prompt[: self.model_max_length - self.max_length]
|
||||||
)
|
)
|
||||||
new_prompt_length = len(tokenized_payload[: self.model_max_length - self.max_length])
|
new_prompt_length = len(tokenized_prompt[: self.model_max_length - self.max_length])
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"resized_prompt": resized_prompt,
|
"resized_prompt": resized_prompt,
|
||||||
|
|||||||
@ -37,11 +37,8 @@ def test_gpt2_prompt_handler():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_flan_prompt_handler():
|
def test_flan_prompt_handler_no_resize():
|
||||||
# test google/flan-t5-xxl tokenizer
|
|
||||||
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
|
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
|
||||||
|
|
||||||
# test no resize
|
|
||||||
assert handler("This is a test") == {
|
assert handler("This is a test") == {
|
||||||
"prompt_length": 5,
|
"prompt_length": 5,
|
||||||
"resized_prompt": "This is a test",
|
"resized_prompt": "This is a test",
|
||||||
@ -50,7 +47,10 @@ def test_flan_prompt_handler():
|
|||||||
"new_prompt_length": 5,
|
"new_prompt_length": 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
# test resize
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_flan_prompt_handler_resize():
|
||||||
|
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
|
||||||
assert handler("This is a prompt that will be resized because it is longer than allowed") == {
|
assert handler("This is a prompt that will be resized because it is longer than allowed") == {
|
||||||
"prompt_length": 17,
|
"prompt_length": 17,
|
||||||
"resized_prompt": "This is a prompt that will be re",
|
"resized_prompt": "This is a prompt that will be re",
|
||||||
@ -59,7 +59,10 @@ def test_flan_prompt_handler():
|
|||||||
"new_prompt_length": 10,
|
"new_prompt_length": 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
# test corner cases
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_flan_prompt_handler_empty_string():
|
||||||
|
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
|
||||||
assert handler("") == {
|
assert handler("") == {
|
||||||
"prompt_length": 0,
|
"prompt_length": 0,
|
||||||
"resized_prompt": "",
|
"resized_prompt": "",
|
||||||
@ -68,7 +71,10 @@ def test_flan_prompt_handler():
|
|||||||
"new_prompt_length": 0,
|
"new_prompt_length": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
# test corner case
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_flan_prompt_handler_none():
|
||||||
|
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
|
||||||
assert handler(None) == {
|
assert handler(None) == {
|
||||||
"prompt_length": 0,
|
"prompt_length": 0,
|
||||||
"resized_prompt": None,
|
"resized_prompt": None,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user