haystack/test/prompt/test_handlers.py

117 lines
3.6 KiB
Python
Raw Normal View History

from unittest.mock import patch
import pytest
from haystack.nodes.prompt.invocation_layer.handlers import (
DefaultTokenStreamingHandler,
DefaultPromptHandler,
AnthropicTokenStreamingHandler,
)
@pytest.mark.integration
def test_prompt_handler_basics():
    """Smoke-test construction of ``DefaultPromptHandler`` for the gpt2 model.

    Asserts that an instance built with an explicit ``max_length`` is callable,
    and that an instance built without ``max_length`` reports
    ``max_length == 100``.
    """
    shared_kwargs = {"model_name_or_path": "gpt2", "model_max_length": 20}

    explicit_handler = DefaultPromptHandler(max_length=10, **shared_kwargs)
    assert callable(explicit_handler)

    # No max_length supplied: the test expects the handler to report 100.
    default_handler = DefaultPromptHandler(**shared_kwargs)
    assert default_handler.max_length == 100
@pytest.mark.integration
def test_gpt2_prompt_handler():
    """Check the dict returned by ``DefaultPromptHandler`` for gpt2 prompts.

    Two prompts are passed to the same handler (``model_max_length=20``,
    ``max_length=10``): a short one that is expected to come back unchanged,
    and a longer one that is expected to come back shortened.
    """
    # test gpt2 BPE based tokenizer
    gpt2_handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)

    # test no resize
    short_prompt = "This is a test"
    expected_unchanged = {
        "prompt_length": 4,
        "resized_prompt": "This is a test",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 4,
    }
    assert gpt2_handler(short_prompt) == expected_unchanged

    # test resize
    long_prompt = "This is a prompt that will be resized because it is longer than allowed"
    expected_resized = {
        "prompt_length": 15,
        "resized_prompt": "This is a prompt that will be resized because",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 10,
    }
    assert gpt2_handler(long_prompt) == expected_resized
@pytest.mark.integration
def test_flan_prompt_handler_no_resize():
    """Check that a short prompt comes back unchanged for google/flan-t5-xxl.

    The expected dict reports ``prompt_length`` and ``new_prompt_length`` both
    equal to 5 and ``resized_prompt`` identical to the input.
    """
    flan_handler = DefaultPromptHandler(
        model_name_or_path="google/flan-t5-xxl",
        model_max_length=20,
        max_length=10,
    )
    expected_report = {
        "prompt_length": 5,
        "resized_prompt": "This is a test",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 5,
    }
    assert flan_handler("This is a test") == expected_report
@pytest.mark.integration
def test_flan_prompt_handler_resize():
    """Check that an over-long prompt comes back shortened for google/flan-t5-xxl.

    The expected dict reports ``prompt_length`` 17 going down to
    ``new_prompt_length`` 10, with ``resized_prompt`` cut to
    ``"This is a prompt that will be re"``.
    """
    flan_handler = DefaultPromptHandler(
        model_name_or_path="google/flan-t5-xxl",
        model_max_length=20,
        max_length=10,
    )
    report = flan_handler("This is a prompt that will be resized because it is longer than allowed")
    assert report == {
        "prompt_length": 17,
        "resized_prompt": "This is a prompt that will be re",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 10,
    }
@pytest.mark.integration
def test_flan_prompt_handler_empty_string():
    """Check the handler's result for an empty-string prompt (google/flan-t5-xxl).

    The expected dict reports zero for both length fields and an empty
    ``resized_prompt``.
    """
    flan_handler = DefaultPromptHandler(
        model_name_or_path="google/flan-t5-xxl",
        model_max_length=20,
        max_length=10,
    )
    empty_prompt = ""
    expected_report = {
        "prompt_length": 0,
        "resized_prompt": "",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 0,
    }
    assert flan_handler(empty_prompt) == expected_report
@pytest.mark.integration
def test_flan_prompt_handler_none():
    """Check the handler's result when the prompt is ``None`` (google/flan-t5-xxl).

    The expected dict reports zero for both length fields and ``None`` as the
    ``resized_prompt``.
    """
    flan_handler = DefaultPromptHandler(
        model_name_or_path="google/flan-t5-xxl",
        model_max_length=20,
        max_length=10,
    )
    report = flan_handler(None)
    assert report == {
        "prompt_length": 0,
        "resized_prompt": None,
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 0,
    }
@pytest.mark.unit
@patch("builtins.print")
def test_anthropic_token_streaming_handler(mock_print):
    """Check the chunks returned and printed by ``AnthropicTokenStreamingHandler``.

    A single handler instance (wrapping a ``DefaultTokenStreamingHandler``) is
    called with a sequence of texts, in order. For each call the test asserts
    the returned chunk and that the patched ``print`` was last called with that
    same chunk, ``flush=True`` and ``end=""``.

    :param mock_print: Mock injected by ``@patch("builtins.print")``.
    """
    streaming_handler = AnthropicTokenStreamingHandler(DefaultTokenStreamingHandler())

    # (text passed to the handler, chunk expected back). Order matters: the
    # same handler instance is reused across all calls.
    exchanges = [
        (" This", " This"),
        (" This is a new", " is a new"),
        (" This is a new token", " token"),
        ("And now", "And now"),
        ("And now something completely different", " something completely different"),
    ]

    for received_text, expected_chunk in exchanges:
        chunk = streaming_handler(received_text)
        assert chunk == expected_chunk
        mock_print.assert_called_with(expected_chunk, flush=True, end="")