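"""Tests for the PromptNode invocation-layer handlers.

DefaultPromptHandler resizes a prompt so that the tokenized prompt plus
max_length (the tokens reserved for the generated answer) fits within the
model's model_max_length. AnthropicTokenStreamingHandler converts Anthropic's
cumulative stream chunks into incremental deltas.
"""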
from unittest.mock import patch

import pytest

from haystack.nodes.prompt.invocation_layer.handlers import (
    DefaultTokenStreamingHandler,
    DefaultPromptHandler,
    AnthropicTokenStreamingHandler,
)


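# The tokenizer is mocked out below, so this unit test exercises only the
# resizing arithmetic of DefaultPromptHandler (5 prompt tokens + max_length 3
# fits within model_max_length 10), not real tokenization.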
@pytest.mark.unit
def test_prompt_handler_positive():
    # prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
    mock_tokens = ["I", "am", "a", "tokenized", "prompt"]
    mock_prompt = "I am a tokenized prompt"

    with patch(
        "haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained", autospec=True
    ) as mock_tokenizer:
        tokenizer_instance = mock_tokenizer.return_value
        tokenizer_instance.tokenize.return_value = mock_tokens
        tokenizer_instance.convert_tokens_to_string.return_value = mock_prompt

        prompt_handler = DefaultPromptHandler("model_path", 10, 3)

        # Test with a prompt that does not exceed model_max_length when tokenized
        result = prompt_handler(mock_prompt)

        assert result == {
            "resized_prompt": mock_prompt,
            "prompt_length": 5,
            "new_prompt_length": 5,
            "model_max_length": 10,
            "max_length": 3,
        }


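# Here the prompt is too long: the resized length is
# model_max_length - max_length = 10 - 3 = 7 tokens, and the mocked
# convert_tokens_to_string returns the already-shortened prompt.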
@pytest.mark.unit
def test_prompt_handler_negative():
    # prompt of length 8 + max_length of 3 = 11, which is more than model_max_length of 10, so we resize to 7
    mock_tokens = ["I", "am", "a", "tokenized", "prompt", "of", "length", "eight"]
    mock_prompt = "I am a tokenized prompt of length"

    with patch(
        "haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained", autospec=True
    ) as mock_tokenizer:
        tokenizer_instance = mock_tokenizer.return_value
        tokenizer_instance.tokenize.return_value = mock_tokens
        tokenizer_instance.convert_tokens_to_string.return_value = mock_prompt

        prompt_handler = DefaultPromptHandler("model_path", 10, 3)
        result = prompt_handler(mock_prompt)

        assert result == {
            "resized_prompt": mock_prompt,
            "prompt_length": 8,
            "new_prompt_length": 7,
            "model_max_length": 10,
            "max_length": 3,
        }


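# When max_length is omitted, DefaultPromptHandler falls back to a default of
# 100, as asserted below.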
@pytest.mark.integration
def test_prompt_handler_basics():
    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
    assert callable(handler)

    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20)
    assert handler.max_length == 100


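# With the real gpt2 tokenizer, the 15-token prompt is cut down to 10 tokens
# (model_max_length 20 - max_length 10), dropping "it is longer than allowed".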
@pytest.mark.integration
def test_gpt2_prompt_handler():
    # test gpt2 BPE-based tokenizer
    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)

    # test no resize
    assert handler("This is a test") == {
        "prompt_length": 4,
        "resized_prompt": "This is a test",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 4,
    }

    # test resize
    assert handler("This is a prompt that will be resized because it is longer than allowed") == {
        "prompt_length": 15,
        "resized_prompt": "This is a prompt that will be resized because",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 10,
    }


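# flan-t5 uses a SentencePiece tokenizer, so "This is a test" tokenizes to
# 5 tokens here instead of the 4 produced by gpt2's BPE above.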
@pytest.mark.integration
def test_flan_prompt_handler_no_resize():
    handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
    assert handler("This is a test") == {
        "prompt_length": 5,
        "resized_prompt": "This is a test",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 5,
    }


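# SentencePiece splits "resized" into sub-word pieces, so the 10-token cut
# lands mid-word and the resized prompt ends in "will be re".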
@pytest.mark.integration
def test_flan_prompt_handler_resize():
    handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
    assert handler("This is a prompt that will be resized because it is longer than allowed") == {
        "prompt_length": 17,
        "resized_prompt": "This is a prompt that will be re",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 10,
    }


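# Edge case: an empty prompt passes through unchanged with zero lengths.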
@pytest.mark.integration
def test_flan_prompt_handler_empty_string():
    handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
    assert handler("") == {
        "prompt_length": 0,
        "resized_prompt": "",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 0,
    }


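# Edge case: None is passed through as-is instead of raising.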
@pytest.mark.integration
def test_flan_prompt_handler_none():
    handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
    assert handler(None) == {
        "prompt_length": 0,
        "resized_prompt": None,
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 0,
    }


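# Anthropic's completion stream sends the full text generated so far with
# each event, so the handler must track what it already emitted and forward
# only the new suffix (" This" -> " This is a new" emits " is a new"). A chunk
# that does not extend the previous text, such as "And now", is treated as the
# start of a new stream and emitted whole.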
@pytest.mark.unit
@patch("builtins.print")
def test_anthropic_token_streaming_handler(mock_print):
    handler = AnthropicTokenStreamingHandler(DefaultTokenStreamingHandler())

    res = handler(" This")
    assert res == " This"
    mock_print.assert_called_with(" This", flush=True, end="")

    res = handler(" This is a new")
    assert res == " is a new"
    mock_print.assert_called_with(" is a new", flush=True, end="")

    res = handler(" This is a new token")
    assert res == " token"
    mock_print.assert_called_with(" token", flush=True, end="")

    res = handler("And now")
    assert res == "And now"
    mock_print.assert_called_with("And now", flush=True, end="")

    res = handler("And now something completely different")
    assert res == " something completely different"
    mock_print.assert_called_with(" something completely different", flush=True, end="")