haystack/test/prompt/test_handlers.py
import pytest

from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler


@pytest.mark.integration
def test_prompt_handler_basics():
    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
    assert callable(handler)

    # max_length defaults to 100 when it is not passed explicitly.
    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20)
    assert handler.max_length == 100


@pytest.mark.integration
def test_gpt2_prompt_handler():
    # Test the GPT-2 BPE-based tokenizer.
    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)

    # Short prompt: no resize needed.
    assert handler("This is a test") == {
        "prompt_length": 4,
        "resized_prompt": "This is a test",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 4,
    }

    # Long prompt: resized to model_max_length - max_length = 10 tokens.
    assert handler("This is a prompt that will be resized because it is longer than allowed") == {
        "prompt_length": 15,
        "resized_prompt": "This is a prompt that will be resized because",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 10,
    }


@pytest.mark.integration
def test_flan_prompt_handler():
    # Test the google/flan-t5-xxl (SentencePiece) tokenizer.
    handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)

    # Short prompt: no resize needed.
    assert handler("This is a test") == {
        "prompt_length": 5,
        "resized_prompt": "This is a test",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 5,
    }

    # Long prompt: resized to model_max_length - max_length = 10 tokens.
    assert handler("This is a prompt that will be resized because it is longer than allowed") == {
        "prompt_length": 17,
        "resized_prompt": "This is a prompt that will be re",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 10,
    }

    # Corner case: an empty prompt is returned unchanged.
    assert handler("") == {
        "prompt_length": 0,
        "resized_prompt": "",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 0,
    }

    # Corner case: None is passed through untouched.
    assert handler(None) == {
        "prompt_length": 0,
        "resized_prompt": None,
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 0,
    }
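

# Illustrative sketch (not part of the original test file): the assertions above
# imply that DefaultPromptHandler reserves `max_length` tokens for generation and
# truncates the prompt to at most `model_max_length - max_length` tokens. The
# test below restates that invariant directly; its name is hypothetical.
@pytest.mark.integration
def test_prompt_handler_truncation_invariant_sketch():
    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
    result = handler("This is a prompt that will be resized because it is longer than allowed")
    # The resized prompt never exceeds the token budget left over for generation.
    assert result["new_prompt_length"] <= result["model_max_length"] - result["max_length"]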