import logging
import unittest
from unittest.mock import patch, MagicMock

import pytest

from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
from haystack.nodes.prompt.invocation_layer import HFInferenceEndpointInvocationLayer
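
# This module (haystack/test/prompt/invocation_layer/test_hugging_face_inference.py) unit-tests
# HFInferenceEndpointInvocationLayer without network access: the Hugging Face `get_task` lookup
# is replaced by the fixtures below where needed, and the layer's `_post` method is patched
# inside each test. The `mock_auto_tokenizer` argument is assumed to be a shared fixture
# (e.g. defined in a conftest.py) that stubs out tokenizer loading.
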
@pytest.fixture
def mock_get_task():
# mock get_task function
with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
mock_get_task.return_value = "text2text-generation"
yield mock_get_task

@pytest.fixture
def mock_get_task_invalid():
with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
mock_get_task.return_value = "some-nonexistent-type"
yield mock_get_task

@pytest.mark.unit
def test_default_constructor(mock_auto_tokenizer):
"""
Test that the default constructor sets the correct values
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")
assert layer.api_key == "some_fake_key"
assert layer.max_length == 100
assert layer.model_input_kwargs == {}

@pytest.mark.unit
def test_constructor_with_model_kwargs(mock_auto_tokenizer):
"""
Test that model_kwargs are correctly set in the constructor
and that model_kwargs_rejected are correctly filtered out
"""
model_kwargs = {"temperature": 0.7, "do_sample": True, "stream": True}
model_kwargs_rejected = {"fake_param": 0.7, "another_fake_param": 1}
layer = HFInferenceEndpointInvocationLayer(
model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key", **model_kwargs, **model_kwargs_rejected
)
assert "temperature" in layer.model_input_kwargs
assert "do_sample" in layer.model_input_kwargs
assert "stream" in layer.model_input_kwargs
assert "fake_param" not in layer.model_input_kwargs
assert "another_fake_param" not in layer.model_input_kwargs
@pytest.mark.unit
def test_set_model_max_length(mock_auto_tokenizer):
"""
Test that model max length is set correctly
"""
layer = HFInferenceEndpointInvocationLayer(
model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key", model_max_length=2048
)
assert layer.prompt_handler.model_max_length == 2048

@pytest.mark.unit
def test_url(mock_auto_tokenizer):
"""
Test that the url is correctly set in the constructor
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")
assert layer.url == "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
layer = HFInferenceEndpointInvocationLayer(
model_name_or_path="https://23445.us-east-1.aws.endpoints.huggingface.cloud", api_key="some_fake_key"
)
assert layer.url == "https://23445.us-east-1.aws.endpoints.huggingface.cloud"
@pytest.mark.unit
def test_invoke_with_no_kwargs(mock_auto_tokenizer):
"""
Test that invoke raises an error if no prompt is provided
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")
with pytest.raises(ValueError) as e:
layer.invoke()
assert e.match("No prompt provided.")

@pytest.mark.unit
def test_invoke_with_stop_words(mock_auto_tokenizer):
"""
Test stop words are correctly passed to HTTP POST request
"""
stop_words = ["but", "not", "bye"]
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello", stop_words=stop_words)
assert mock_post.called
# Check if stop_words are passed to _post as stop parameter
_, called_kwargs = mock_post.call_args
assert "stop" in called_kwargs["data"]["parameters"]
assert called_kwargs["data"]["parameters"]["stop"] == stop_words
@pytest.mark.unit
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param_in_constructor(mock_auto_tokenizer, stream):
"""
Test stream parameter is correctly passed to HTTP POST request via constructor
"""
layer = HFInferenceEndpointInvocationLayer(
model_name_or_path="google/flan-t5-xxl", api_key="fake_key", stream=stream
)
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello")
assert mock_post.called
_, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
assert called_kwargs["stream"] == stream

@pytest.mark.unit
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param_in_method(mock_auto_tokenizer, stream):
"""
Test stream parameter is correctly passed to HTTP POST request via method
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello", stream=stream)
assert mock_post.called
_, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
# Assert that the 'stream' parameter passed to _post is the same as the one used in layer.invoke()
assert called_kwargs["stream"] == stream

@pytest.mark.unit
def test_streaming_stream_handler_param_in_constructor(mock_auto_tokenizer):
"""
Test stream_handler parameter is correctly passed to HTTP POST request via constructor
"""
stream_handler = DefaultTokenStreamingHandler()
layer = HFInferenceEndpointInvocationLayer(
model_name_or_path="google/flan-t5-xxl", api_key="fake_key", stream_handler=stream_handler
)
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post, unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._process_streaming_response"
) as mock_post_stream:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello")
assert mock_post.called
_, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
assert called_kwargs["stream"]
# stream_handler is passed as an instance of TokenStreamingHandler
called_args, _ = mock_post_stream.call_args
assert isinstance(called_args[1], TokenStreamingHandler)
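
# Streaming flow exercised above: with a stream handler configured, `_post` is called with
# stream=True and its response is handed to `_process_streaming_response` together with a
# TokenStreamingHandler instance, which is expected to receive tokens as they arrive.
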
@pytest.mark.unit
def test_streaming_no_stream_handler_param_in_constructor(mock_auto_tokenizer):
"""
    Test that streaming stays disabled in the HTTP POST request when no stream_handler is set in the constructor
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello")
assert mock_post.called
_, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
# but it is False if stream_handler is None
assert not called_kwargs["stream"]

@pytest.mark.unit
def test_streaming_stream_handler_param_in_method(mock_auto_tokenizer):
"""
Test stream_handler parameter is correctly passed to HTTP POST request via method
"""
stream_handler = DefaultTokenStreamingHandler()
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post, unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._process_streaming_response"
) as mock_post_stream:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
# stream is correctly passed to _post
assert "stream" in called_kwargs
assert called_kwargs["stream"]
called_args, called_kwargs = mock_post_stream.call_args
assert isinstance(called_args[1], TokenStreamingHandler)

@pytest.mark.unit
def test_streaming_no_stream_handler_param_in_method(mock_auto_tokenizer):
"""
    Test that streaming stays disabled in the HTTP POST request when no stream_handler is passed to invoke
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello", stream_handler=None)
assert mock_post.called
_, called_kwargs = mock_post.call_args
# stream is always correctly passed to _post
assert "stream" in called_kwargs
assert not called_kwargs["stream"]

@pytest.mark.integration
@pytest.mark.parametrize(
"model_name_or_path", ["google/flan-t5-xxl", "OpenAssistant/oasst-sft-1-pythia-12b", "bigscience/bloomz"]
)
def test_ensure_token_limit_no_resize(model_name_or_path):
    # In this test we assume that no prompt resizing is needed for any of the models
handler = HFInferenceEndpointInvocationLayer("fake_api_key", model_name_or_path, max_length=100)
# Define prompt and expected results
prompt = "This is a test prompt."
resized_prompt = handler._ensure_token_limit(prompt)
# Verify the results
assert resized_prompt == prompt

@pytest.mark.integration
@pytest.mark.parametrize(
"model_name_or_path", ["google/flan-t5-xxl", "OpenAssistant/oasst-sft-1-pythia-12b", "bigscience/bloomz"]
)
def test_ensure_token_limit_resize(caplog, model_name_or_path):
# In this test case we assume prompt resizing is needed for all models
handler = HFInferenceEndpointInvocationLayer("fake_api_key", model_name_or_path, max_length=5, model_max_length=10)
# Define prompt and expected results
prompt = "This is a test prompt that will be resized because model_max_length is 10 and max_length is 5."
with caplog.at_level(logging.WARN):
resized_prompt = handler._ensure_token_limit(prompt)
assert "The prompt has been truncated" in caplog.text
# Verify the results
assert resized_prompt != prompt
assert (
"This is a test" in resized_prompt
and "because model_max_length is 10 and max_length is 5" not in resized_prompt
)
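
# Rough arithmetic behind the truncation test above: with model_max_length=10 and max_length=5
# reserved for the generated answer, only about 10 - 5 = 5 prompt tokens fit, so the long
# prompt is expected to be cut down to roughly "This is a test ..." and a truncation warning
# is logged.
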
@pytest.mark.unit
def test_oasst_prompt_preprocessing(mock_auto_tokenizer):
model_name = "OpenAssistant/oasst-sft-1-pythia-12b"
layer = HFInferenceEndpointInvocationLayer("fake_api_key", model_name)
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
result = layer.invoke(prompt="Tell me hello")
assert result == ["Hello"]
assert mock_post.called
_, called_kwargs = mock_post.call_args
        # OpenAssistant/oasst-sft-1-pythia-12b prompts are preprocessed and wrapped in the special tokens asserted below
assert called_kwargs["data"]["inputs"] == "<|prompter|>Tell me hello<|endoftext|><|assistant|>"
@pytest.mark.unit
def test_invalid_key():
with pytest.raises(ValueError, match="must be a valid Hugging Face token"):
HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")

@pytest.mark.unit
def test_invalid_model():
with pytest.raises(ValueError, match="cannot be None or empty string"):
HFInferenceEndpointInvocationLayer("fake_api", "")

@pytest.mark.unit
def test_supports(mock_get_task):
"""
Test that supports returns True correctly for HFInferenceEndpointInvocationLayer
"""
# supports google/flan-t5-xxl with api_key
assert HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="fake_key")
# supports HF Inference Endpoint with api_key
assert HFInferenceEndpointInvocationLayer.supports(
"https://<your-unique-deployment-id>.us-east-1.aws.endpoints.huggingface.cloud", api_key="fake_key"
)
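
# `supports` is assumed to behave roughly as follows: it requires a non-empty api_key and
# either a model whose `get_task` lookup resolves to a task this layer can serve (the fixture
# returns "text2text-generation" above) or a model_name_or_path that is already a full
# Inference Endpoint URL.
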
@pytest.mark.unit
def test_supports_not(mock_get_task_invalid):
assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")
# doesn't support google/flan-t5-xxl without api_key
assert not HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl")
# doesn't support HF Inference Endpoint without proper api_key
assert not HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="")
assert not HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key=None)
# doesn't support model if hf server timeout
    mock_get_task_invalid.side_effect = RuntimeError
assert not HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="fake_key")