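"""
Unit and integration tests for HFInferenceEndpointInvocationLayer (Hugging Face Inference API and Inference Endpoints).
"""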
import logging
import unittest
from unittest.mock import patch, MagicMock

import pytest

from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
from haystack.nodes.prompt.invocation_layer import HFInferenceEndpointInvocationLayer


@pytest.fixture
def mock_get_task():
    # mock get_task function
    with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
        mock_get_task.return_value = "text2text-generation"
        yield mock_get_task


@pytest.fixture
def mock_get_task_invalid():
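    # mock get_task function so that it reports an unsupported task type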
    with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
        mock_get_task.return_value = "some-nonexistent-type"
        yield mock_get_task


@pytest.mark.unit
def test_default_constructor(mock_auto_tokenizer):
    """
    Test that the default constructor sets the correct values
    """

    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")

    assert layer.api_key == "some_fake_key"
    assert layer.max_length == 100
    assert layer.model_input_kwargs == {}


@pytest.mark.unit
def test_constructor_with_model_kwargs(mock_auto_tokenizer):
    """
    Test that model_kwargs are correctly set in the constructor
    and that model_kwargs_rejected are correctly filtered out
    """
    model_kwargs = {"temperature": 0.7, "do_sample": True, "stream": True}
    model_kwargs_rejected = {"fake_param": 0.7, "another_fake_param": 1}

    layer = HFInferenceEndpointInvocationLayer(
        model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key", **model_kwargs, **model_kwargs_rejected
    )
    assert "temperature" in layer.model_input_kwargs
    assert "do_sample" in layer.model_input_kwargs
    assert "stream" in layer.model_input_kwargs
    assert "fake_param" not in layer.model_input_kwargs
    assert "another_fake_param" not in layer.model_input_kwargs


@pytest.mark.unit
def test_set_model_max_length(mock_auto_tokenizer):
    """
    Test that model max length is set correctly
    """
    layer = HFInferenceEndpointInvocationLayer(
        model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key", model_max_length=2048
    )
    assert layer.prompt_handler.model_max_length == 2048


@pytest.mark.unit
def test_url(mock_auto_tokenizer):
    """
    Test that the url is correctly set in the constructor
    """
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")
    assert layer.url == "https://api-inference.huggingface.co/models/google/flan-t5-xxl"

    layer = HFInferenceEndpointInvocationLayer(
        model_name_or_path="https://23445.us-east-1.aws.endpoints.huggingface.cloud", api_key="some_fake_key"
    )

    assert layer.url == "https://23445.us-east-1.aws.endpoints.huggingface.cloud"


@pytest.mark.unit
def test_invoke_with_no_kwargs(mock_auto_tokenizer):
    """
    Test that invoke raises an error if no prompt is provided
    """
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")
    with pytest.raises(ValueError) as e:
        layer.invoke()
    assert e.match("No prompt provided.")


@pytest.mark.unit
def test_invoke_with_stop_words(mock_auto_tokenizer):
    """
    Test stop words are correctly passed to HTTP POST request
    """
    stop_words = ["but", "not", "bye"]
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')

        layer.invoke(prompt="Tell me hello", stop_words=stop_words)

        assert mock_post.called

        # Check if stop_words are passed to _post as stop parameter
        _, called_kwargs = mock_post.call_args
        assert "stop" in called_kwargs["data"]["parameters"]
        assert called_kwargs["data"]["parameters"]["stop"] == stop_words


@pytest.mark.unit
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param_in_constructor(mock_auto_tokenizer, stream):
    """
    Test stream parameter is correctly passed to HTTP POST request via constructor
    """
    layer = HFInferenceEndpointInvocationLayer(
        model_name_or_path="google/flan-t5-xxl", api_key="fake_key", stream=stream
    )
    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
        layer.invoke(prompt="Tell me hello")

        assert mock_post.called
        _, called_kwargs = mock_post.call_args

        # stream is always passed to _post
        assert "stream" in called_kwargs

        assert called_kwargs["stream"] == stream


@pytest.mark.unit
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param_in_method(mock_auto_tokenizer, stream):
    """
    Test stream parameter is correctly passed to HTTP POST request via method
    """
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
        layer.invoke(prompt="Tell me hello", stream=stream)

        assert mock_post.called
        _, called_kwargs = mock_post.call_args

        # stream is always passed to _post
        assert "stream" in called_kwargs

        # Assert that the 'stream' parameter passed to _post is the same as the one used in layer.invoke()
        assert called_kwargs["stream"] == stream


@pytest.mark.unit
def test_streaming_stream_handler_param_in_constructor(mock_auto_tokenizer):
    """
    Test stream_handler parameter is correctly passed to HTTP POST request via constructor
    """
    stream_handler = DefaultTokenStreamingHandler()
    layer = HFInferenceEndpointInvocationLayer(
        model_name_or_path="google/flan-t5-xxl", api_key="fake_key", stream_handler=stream_handler
    )

    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post, unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._process_streaming_response"
    ) as mock_post_stream:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
        layer.invoke(prompt="Tell me hello")

        assert mock_post.called
        _, called_kwargs = mock_post.call_args

        # stream is always passed to _post
        assert "stream" in called_kwargs

        assert called_kwargs["stream"]

        # stream_handler is passed as an instance of TokenStreamingHandler
        called_args, _ = mock_post_stream.call_args
        assert isinstance(called_args[1], TokenStreamingHandler)


@pytest.mark.unit
def test_streaming_no_stream_handler_param_in_constructor(mock_auto_tokenizer):
    """
    Test that streaming is disabled in the HTTP POST request when no stream_handler is passed to the constructor
    """
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")

    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
        layer.invoke(prompt="Tell me hello")

        assert mock_post.called
        _, called_kwargs = mock_post.call_args

        # stream is always passed to _post
        assert "stream" in called_kwargs

        # but it is False if stream_handler is None
        assert not called_kwargs["stream"]


@pytest.mark.unit
def test_streaming_stream_handler_param_in_method(mock_auto_tokenizer):
    """
    Test stream_handler parameter is correctly passed to HTTP POST request via method
    """
    stream_handler = DefaultTokenStreamingHandler()
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")

    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post, unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._process_streaming_response"
    ) as mock_post_stream:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')

        layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)

        assert mock_post.called
        called_args, called_kwargs = mock_post.call_args

        # stream is correctly passed to _post
        assert "stream" in called_kwargs
        assert called_kwargs["stream"]

        called_args, called_kwargs = mock_post_stream.call_args
        assert isinstance(called_args[1], TokenStreamingHandler)


@pytest.mark.unit
def test_streaming_no_stream_handler_param_in_method(mock_auto_tokenizer):
    """
    Test that streaming is disabled in the HTTP POST request when stream_handler is None in the invoke method
    """
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")

    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')

        layer.invoke(prompt="Tell me hello", stream_handler=None)

        assert mock_post.called

        _, called_kwargs = mock_post.call_args

        # stream is always correctly passed to _post
        assert "stream" in called_kwargs
        assert not called_kwargs["stream"]


@pytest.mark.integration
@pytest.mark.parametrize(
    "model_name_or_path", ["google/flan-t5-xxl", "OpenAssistant/oasst-sft-1-pythia-12b", "bigscience/bloomz"]
)
def test_ensure_token_limit_no_resize(model_name_or_path):
    # In this test case we assume that no prompt resizing is needed for all models
    handler = HFInferenceEndpointInvocationLayer("fake_api_key", model_name_or_path, max_length=100)

    # Define prompt and expected results
    prompt = "This is a test prompt."

    resized_prompt = handler._ensure_token_limit(prompt)

    # Verify the results
    assert resized_prompt == prompt


@pytest.mark.integration
@pytest.mark.parametrize(
    "model_name_or_path", ["google/flan-t5-xxl", "OpenAssistant/oasst-sft-1-pythia-12b", "bigscience/bloomz"]
)
def test_ensure_token_limit_resize(caplog, model_name_or_path):
    # In this test case we assume prompt resizing is needed for all models
    handler = HFInferenceEndpointInvocationLayer("fake_api_key", model_name_or_path, max_length=5, model_max_length=10)

    # Define prompt and expected results
    prompt = "This is a test prompt that will be resized because model_max_length is 10 and max_length is 5."
    with caplog.at_level(logging.WARN):
        resized_prompt = handler._ensure_token_limit(prompt)
        assert "The prompt has been truncated" in caplog.text

    # Verify the results
    assert resized_prompt != prompt
    assert (
        "This is a test" in resized_prompt
        and "because model_max_length is 10 and max_length is 5" not in resized_prompt
    )


@pytest.mark.unit
def test_oasst_prompt_preprocessing(mock_auto_tokenizer):
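    """
    Test that OpenAssistant prompts are wrapped in the expected prompter/assistant special tokens
    """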
    model_name = "OpenAssistant/oasst-sft-1-pythia-12b"

    layer = HFInferenceEndpointInvocationLayer("fake_api_key", model_name)
    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
        result = layer.invoke(prompt="Tell me hello")

        assert result == ["Hello"]
        assert mock_post.called

        _, called_kwargs = mock_post.call_args
        # OpenAssistant/oasst-sft-1-pythia-12b prompts are preprocessed and wrapped in tokens below
        assert called_kwargs["data"]["inputs"] == "<|prompter|>Tell me hello<|endoftext|><|assistant|>"


@pytest.mark.unit
def test_invalid_key():
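    """
    Test that an empty api_key raises a ValueError
    """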
    with pytest.raises(ValueError, match="must be a valid Hugging Face token"):
        HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")


@pytest.mark.unit
def test_invalid_model():
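    """
    Test that an empty model_name_or_path raises a ValueError
    """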
    with pytest.raises(ValueError, match="cannot be None or empty string"):
        HFInferenceEndpointInvocationLayer("fake_api", "")


@pytest.mark.unit
def test_supports(mock_get_task):
    """
    Test that supports returns True correctly for HFInferenceEndpointInvocationLayer
    """

    # supports google/flan-t5-xxl with api_key
    assert HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="fake_key")

    # supports HF Inference Endpoint with api_key
    assert HFInferenceEndpointInvocationLayer.supports(
        "https://<your-unique-deployment-id>.us-east-1.aws.endpoints.huggingface.cloud", api_key="fake_key"
    )


@pytest.mark.unit
def test_supports_not(mock_get_task_invalid):
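    """
    Test that supports returns False for unsupported task types, missing or invalid api_key, and HF server errors
    """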
    assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")

    # doesn't support google/flan-t5-xxl without api_key
    assert not HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl")

    # doesn't support HF Inference Endpoint without proper api_key
    assert not HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="")
    assert not HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key=None)

    # doesn't support the model if the HF server times out; make the mocked get_task raise to simulate it
    mock_get_task_invalid.side_effect = RuntimeError
    assert not HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="fake_key")