2023-05-12 17:50:09 +02:00
|
|
|
import pytest
|
|
|
|
|
|
|
|
from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.integration
def test_prompt_handler_basics():
    """DefaultPromptHandler instances are callable and default ``max_length`` to 100."""
    explicit_handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
    assert callable(explicit_handler)

    # Leaving out max_length must fall back to the handler's default of 100.
    defaulted_handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20)
    assert defaulted_handler.max_length == 100
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.integration
def test_gpt2_prompt_handler():
    """Check length accounting and truncation using the GPT-2 (BPE-based) tokenizer."""
    gpt2_handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)

    # Short prompt: 4 tokens, passed through without any resizing.
    short_prompt = "This is a test"
    expected_short = dict(
        prompt_length=4,
        resized_prompt=short_prompt,
        max_length=10,
        model_max_length=20,
        new_prompt_length=4,
    )
    assert gpt2_handler(short_prompt) == expected_short

    # Long prompt: 15 tokens, truncated down to 10 tokens.
    long_prompt = "This is a prompt that will be resized because it is longer than allowed"
    expected_long = dict(
        prompt_length=15,
        resized_prompt="This is a prompt that will be resized because",
        max_length=10,
        model_max_length=20,
        new_prompt_length=10,
    )
    assert gpt2_handler(long_prompt) == expected_long
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.integration
def test_flan_prompt_handler_no_resize():
    """A short prompt is returned unchanged when using the Flan-T5 tokenizer."""
    flan_handler = DefaultPromptHandler(
        model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10
    )
    prompt = "This is a test"

    outcome = flan_handler(prompt)

    # This tokenizer counts 5 tokens for the prompt; nothing is cut.
    expected = {
        "prompt_length": 5,
        "new_prompt_length": 5,
        "resized_prompt": prompt,
        "max_length": 10,
        "model_max_length": 20,
    }
    assert outcome == expected
|
|
|
|
|
2023-05-29 12:13:32 +02:00
|
|
|
|
|
|
|
@pytest.mark.integration
def test_flan_prompt_handler_resize():
    """An over-long prompt is truncated when using the Flan-T5 tokenizer."""
    flan_handler = DefaultPromptHandler(
        model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10
    )
    long_prompt = "This is a prompt that will be resized because it is longer than allowed"

    outcome = flan_handler(long_prompt)

    # 17 tokens shrink to 10; the cut lands mid-word ("re"), presumably because
    # truncation is token-based rather than word-based.
    expected = {
        "prompt_length": 17,
        "new_prompt_length": 10,
        "resized_prompt": "This is a prompt that will be re",
        "max_length": 10,
        "model_max_length": 20,
    }
    assert outcome == expected
|
2023-05-22 14:45:53 +02:00
|
|
|
|
2023-05-29 12:13:32 +02:00
|
|
|
|
|
|
|
@pytest.mark.integration
def test_flan_prompt_handler_empty_string():
    """An empty prompt reports zero lengths and comes back as an empty string."""
    flan_handler = DefaultPromptHandler(
        model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10
    )
    expected = dict(
        resized_prompt="",
        prompt_length=0,
        new_prompt_length=0,
        max_length=10,
        model_max_length=20,
    )
    assert flan_handler("") == expected
|
|
|
|
|
2023-05-29 12:13:32 +02:00
|
|
|
|
|
|
|
@pytest.mark.integration
def test_flan_prompt_handler_none():
    """Passing None as the prompt reports zero lengths and echoes None back."""
    flan_handler = DefaultPromptHandler(
        model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10
    )
    expected = dict(
        resized_prompt=None,
        prompt_length=0,
        new_prompt_length=0,
        max_length=10,
        model_max_length=20,
    )
    assert flan_handler(None) == expected
|