Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-01 10:19:23 +00:00)
feat: HFInferenceEndpointInvocationLayer streaming support (#4819)
* HFInferenceEndpointInvocationLayer streaming support
* Small fixes
* Add unit test
* PR feedback
* Alphabetically sort params
* Convert PromptNode tests to HFInferenceEndpointInvocationLayer invoke tests
* Rewrite streaming with sseclient
* More PR updates
* Implement and test _ensure_token_limit
* Further optimize DefaultPromptHandler
* Fix CohereInvocationLayer mistypes
* PR feedback
* Break up unit tests, simplify
* Simplify unit tests even further
* PR feedback on unit test simplification
* Proper code indentation under patch context manager
* More unit tests, slight adjustments
* Remove unrelated CohereInvocationLayer change

  This reverts commit 82337151e8328d982f738e5da9129ff99350ea0c.

* Revert "Further optimize DefaultPromptHandler"

  This reverts commit 606a761b6e3333f27df51a304cfbd1906c806e05.

* lg update: mostly full stops at the end of docstrings

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
Co-authored-by: Darja Fokina <daria.f93@gmail.com>
This commit is contained in:
parent 9398183447
commit 068a967e5b
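For orientation before the diff: the sketch below is not part of the commit, but shows how the streaming support added here can be used from the invocation layer. The model name, prompt, and API token are placeholders; the stream and stream_handler parameters are the ones exercised by the new unit tests further down.

from haystack.nodes.prompt.invocation_layer import HFInferenceEndpointInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler

layer = HFInferenceEndpointInvocationLayer(
    model_name_or_path="google/flan-t5-xxl",  # or an https://... Inference Endpoints URL
    api_key="<your-hf-token>",  # placeholder
    max_length=100,
)

# Stream with the default handler...
answers = layer.invoke(prompt="Tell me a joke", stream=True)

# ...or pass a TokenStreamingHandler instance explicitly.
answers = layer.invoke(prompt="Tell me a joke", stream_handler=DefaultTokenStreamingHandler())

# In both cases invoke() still returns a list with the assembled completion,
# the same shape as the non-streaming path.
print(answers)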
@@ -1,9 +1,10 @@
import json
import os
from typing import Optional, Dict, Union, List, Any
from typing import Optional, Dict, Union, List, Any, Callable
import logging

import requests
import sseclient
from transformers.pipelines import get_task

from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
@@ -12,7 +13,12 @@ from haystack.errors import (
    HuggingFaceInferenceUnauthorizedError,
    HuggingFaceInferenceError,
)
from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer import (
    PromptModelInvocationLayer,
    TokenStreamingHandler,
    DefaultTokenStreamingHandler,
)
from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler
from haystack.utils.requests import request_with_retry

logger = logging.getLogger(__name__)
@@ -44,17 +50,13 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
        be found in your Hugging Face account [settings](https://huggingface.co/settings/tokens)
        """
        super().__init__(model_name_or_path)
        self.prompt_preprocessors: Dict[str, Callable] = {}
        valid_api_key = isinstance(api_key, str) and api_key
        if not valid_api_key:
            raise ValueError(
                f"api_key {api_key} must be a valid Hugging Face token. "
                f"Your token is available in your Hugging Face settings page."
            )
        valid_model_name_or_path = isinstance(model_name_or_path, str) and model_name_or_path
        if not valid_model_name_or_path:
            raise ValueError(
                f"model_name_or_path {model_name_or_path} must be a valid Hugging Face inference endpoint URL."
            )
        self.api_key = api_key
        self.max_length = max_length
@@ -63,18 +65,51 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
        self.model_input_kwargs = {
            key: kwargs[key]
            for key in [
                "top_k",
                "top_p",
                "temperature",
                "repetition_penalty",
                "best_of",
                "details",
                "do_sample",
                "max_new_tokens",
                "max_time",
                "return_full_text",
                "model_max_length",
                "num_return_sequences",
                "do_sample",
                "repetition_penalty",
                "return_full_text",
                "seed",
                "stream",
                "stream_handler",
                "temperature",
                "top_k",
                "top_p",
                "truncate",
                "typical_p",
                "watermark",
            ]
            if key in kwargs
        }
        self.prompt_preprocessors["oasst"] = lambda prompt: f"<|prompter|>{prompt}<|endoftext|><|assistant|>"

        # we pop the model_max_length from the model_input_kwargs as it is not sent to the model
        # but used to truncate the prompt if needed
        model_max_length = self.model_input_kwargs.pop("model_max_length", 1024)

        if HFInferenceEndpointInvocationLayer.is_inference_endpoint(model_name_or_path):
            # as we are using the deployed HF inference endpoint, we don't know the model name
            # we'll use gpt2 BPE tokenizer for prompt length calculation
            self.prompt_handler = DefaultPromptHandler(
                model_name_or_path="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
            )
        else:
            self.prompt_handler = DefaultPromptHandler(
                model_name_or_path=model_name_or_path,
                model_max_length=model_max_length,
                max_length=self.max_length or 100,
            )

    def preprocess_prompt(self, prompt: str):
        for key, prompt_preprocessor in self.prompt_preprocessors.items():
            if key in self.model_name_or_path:
                return prompt_preprocessor(prompt)
        return prompt

    @property
    def url(self) -> str:
@@ -102,58 +137,99 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
                f"No prompt provided. Model {self.model_name_or_path} requires prompt."
                f"Make sure to provide prompt in kwargs."
            )
        stop_words = kwargs.pop("stop_words", None)

        prompt = self.preprocess_prompt(prompt)
        stop_words = kwargs.pop("stop_words", None) or []
        kwargs_with_defaults = self.model_input_kwargs

        if "max_new_tokens" not in kwargs_with_defaults:
            kwargs_with_defaults["max_new_tokens"] = self.max_length

        if "top_k" in kwargs:
            top_k = kwargs.pop("top_k")
            kwargs["num_return_sequences"] = top_k
        kwargs_with_defaults.update(kwargs)
        # see https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
        accepted_params = [
            "top_p",
            "top_k",
            "temperature",
            "repetition_penalty",
            "max_new_tokens",
            "max_time",
            "return_full_text",
            "num_return_sequences",
            "do_sample",
        ]
        params = {key: kwargs_with_defaults.get(key) for key in accepted_params if key in kwargs_with_defaults}
        generated_texts = self._post(data={"inputs": prompt, "parameters": params}, **kwargs)
        if stop_words:
            for idx, _ in enumerate(generated_texts):
                earliest_stop_word_idx = len(generated_texts[idx])
                for stop_word in stop_words:
                    stop_word_idx = generated_texts[idx].find(stop_word)
                    if stop_word_idx != -1:
                        earliest_stop_word_idx = min(earliest_stop_word_idx, stop_word_idx)
                generated_texts[idx] = generated_texts[idx][:earliest_stop_word_idx]

        # either stream is True (will use default handler) or stream_handler is provided
        stream = (
            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
        )
        # see https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
        params = {
            "best_of": kwargs_with_defaults.get("best_of", None),
            "details": kwargs_with_defaults.get("details", True),
            "do_sample": kwargs_with_defaults.get("do_sample", False),
            "max_new_tokens": kwargs_with_defaults.get("max_new_tokens", self.max_length),
            "max_time": kwargs_with_defaults.get("max_time", None),
            "num_return_sequences": kwargs_with_defaults.get("num_return_sequences", None),
            "repetition_penalty": kwargs_with_defaults.get("repetition_penalty", None),
            "return_full_text": kwargs_with_defaults.get("return_full_text", False),
            "seed": kwargs_with_defaults.get("seed", None),
            "stop": kwargs_with_defaults.get("stop", stop_words),
            "temperature": kwargs_with_defaults.get("temperature", None),
            "top_k": kwargs_with_defaults.get("top_k", None),
            "top_p": kwargs_with_defaults.get("top_p", None),
            "truncate": kwargs_with_defaults.get("truncate", None),
            "typical_p": kwargs_with_defaults.get("typical_p", None),
            "watermark": kwargs_with_defaults.get("watermark", False),
        }
        response: requests.Response = self._post(
            data={"inputs": prompt, "parameters": params, "stream": stream}, stream=stream
        )
        if stream:
            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
            generated_texts = self._process_streaming_response(response, handler, stop_words)
        else:
            output = json.loads(response.text)
            generated_texts = [o["generated_text"] for o in output if "generated_text" in o]
        return generated_texts
    def _process_streaming_response(
        self, response: requests.Response, stream_handler: TokenStreamingHandler, stop_words: List[str]
    ) -> List[str]:
        """
        Stream the response and invoke the stream_handler on each token.

        :param response: The response object from the server.
        :param stream_handler: The handler to invoke on each token.
        :param stop_words: The stop words to ignore.
        """
        client = sseclient.SSEClient(response)
        tokens: List[str] = []
        try:
            for event in client.events():
                if event.data != TokenStreamingHandler.DONE_MARKER:
                    event_data = json.loads(event.data)
                    token: Optional[str] = self._extract_token(event_data)
                    # if valid token and not a stop word (we don't want to return stop words)
                    if token and token.strip() not in stop_words:
                        tokens.append(stream_handler(token, event_data=event_data))
        finally:
            client.close()
        return ["".join(tokens)]  # return a list of strings just like non-streaming

    def _extract_token(self, event_data: Dict[str, Any]) -> Optional[str]:
        """
        Extract the token from the event data. If the token is a special token, return None.
        :param event_data: Event data from the streaming response.
        """
        # extract token from event data and only consider non-special tokens
        return event_data["token"]["text"] if not event_data["token"]["special"] else None
    def _post(
        self,
        data: Dict[str, Any],
        stream: bool = False,
        attempts: int = HF_RETRIES,
        status_codes: Optional[List[int]] = None,
        timeout: float = HF_TIMEOUT,
        **kwargs,
    ) -> List[str]:
    ) -> requests.Response:
        """
        Post data to the HF inference model. It takes in a prompt and returns a list of responses using a REST invocation.
        :param data: The data to be sent to the model.
        :param stream: Whether to stream the response.
        :param attempts: The number of attempts to make.
        :param status_codes: The status codes to retry on.
        :param timeout: The timeout for the request.
        :return: The responses are being returned.
        """
        generated_texts: List[str] = []
        response: requests.Response
        if status_codes is None:
            status_codes = [429]
        try:
@@ -165,9 +241,8 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
                headers=self.headers,
                json=data,
                timeout=timeout,
                stream=stream,
            )
            output = json.loads(response.text)
            generated_texts = [o["generated_text"] for o in output if "generated_text" in o]
        except requests.HTTPError as err:
            res = err.response
            if res.status_code == 429:
@@ -179,11 +254,22 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
                f"HuggingFace Inference returned an error.\nStatus code: {res.status_code}\nResponse body: {res.text}",
                status_code=res.status_code,
            )
        return generated_texts
        return response

    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
        # TODO: new implementation incoming for all layers, let's omit this for now
        return prompt
        # the prompt for this model will be of the type str
        resize_info = self.prompt_handler(prompt)  # type: ignore
        if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
            logger.warning(
                "The prompt has been truncated from %s tokens to %s tokens so that the prompt length and "
                "answer length (%s tokens) fit within the max token limit (%s tokens). "
                "Shorten the prompt to prevent it from being cut off.",
                resize_info["prompt_length"],
                max(0, resize_info["model_max_length"] - resize_info["max_length"]),  # type: ignore
                resize_info["max_length"],
                resize_info["model_max_length"],
            )
        return str(resize_info["resized_prompt"])

    @staticmethod
    def is_inference_endpoint(model_name_or_path: str) -> bool:
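As an aside, not part of the diff: _process_streaming_response above consumes server-sent events, and _extract_token only reads the token.text and token.special fields of each event. A minimal sketch of that parsing with a hypothetical event payload (real endpoints may include additional fields such as generation details):

import json

# Hypothetical SSE event body; only the "token" fields below are read by _extract_token.
event_data = json.loads('{"token": {"text": " world", "special": false}}')

token = event_data["token"]["text"] if not event_data["token"]["special"] else None
assert token == " world"  # a special token would yield None and be skipped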
test/prompt/invocation_layer/test_hugging_face_inference.py (new file, 355 lines)
@@ -0,0 +1,355 @@
import logging
import unittest
from unittest.mock import patch, MagicMock

import pytest

from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
from haystack.nodes.prompt.invocation_layer import HFInferenceEndpointInvocationLayer


@pytest.mark.unit
def test_default_constructor():
    """
    Test that the default constructor sets the correct values
    """

    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")

    assert layer.api_key == "some_fake_key"
    assert layer.max_length == 100
    assert layer.model_input_kwargs == {}

@pytest.mark.unit
def test_constructor_with_model_kwargs():
    """
    Test that model_kwargs are correctly set in the constructor
    and that model_kwargs_rejected are correctly filtered out
    """
    model_kwargs = {"temperature": 0.7, "do_sample": True, "stream": True}
    model_kwargs_rejected = {"fake_param": 0.7, "another_fake_param": 1}

    layer = HFInferenceEndpointInvocationLayer(
        model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key", **model_kwargs, **model_kwargs_rejected
    )
    assert "temperature" in layer.model_input_kwargs
    assert "do_sample" in layer.model_input_kwargs
    assert "stream" in layer.model_input_kwargs
    assert "fake_param" not in layer.model_input_kwargs
    assert "another_fake_param" not in layer.model_input_kwargs


@pytest.mark.unit
def test_set_model_max_length():
    """
    Test that model max length is set correctly
    """
    layer = HFInferenceEndpointInvocationLayer(
        model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key", model_max_length=2048
    )
    assert layer.prompt_handler.model_max_length == 2048

@pytest.mark.unit
def test_url():
    """
    Test that the url is correctly set in the constructor
    """
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")
    assert layer.url == "https://api-inference.huggingface.co/models/google/flan-t5-xxl"

    layer = HFInferenceEndpointInvocationLayer(
        model_name_or_path="https://23445.us-east-1.aws.endpoints.huggingface.cloud", api_key="some_fake_key"
    )

    assert layer.url == "https://23445.us-east-1.aws.endpoints.huggingface.cloud"


@pytest.mark.unit
def test_invoke_with_no_kwargs():
    """
    Test that invoke raises an error if no prompt is provided
    """
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")
    with pytest.raises(ValueError) as e:
        layer.invoke()
    assert e.match("No prompt provided.")


@pytest.mark.unit
def test_invoke_with_stop_words():
    """
    Test stop words are correctly passed to HTTP POST request
    """
    stop_words = ["but", "not", "bye"]
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')

        layer.invoke(prompt="Tell me hello", stop_words=stop_words)

        assert mock_post.called

        # Check if stop_words are passed to _post as stop parameter
        called_args, called_kwargs = mock_post.call_args
        assert "stop" in called_kwargs["data"]["parameters"]
        assert called_kwargs["data"]["parameters"]["stop"] == stop_words

@pytest.mark.unit
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param_in_constructor(stream):
    """
    Test stream parameter is correctly passed to HTTP POST request via constructor
    """
    layer = HFInferenceEndpointInvocationLayer(
        model_name_or_path="google/flan-t5-xxl", api_key="fake_key", stream=stream
    )
    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
        layer.invoke(prompt="Tell me hello")

        assert mock_post.called
        called_args, called_kwargs = mock_post.call_args

        # stream is always passed to _post
        assert "stream" in called_kwargs

        assert called_kwargs["stream"] == stream


@pytest.mark.unit
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param_in_method(stream):
    """
    Test stream parameter is correctly passed to HTTP POST request via method
    """
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
        layer.invoke(prompt="Tell me hello", stream=stream)

        assert mock_post.called
        called_args, called_kwargs = mock_post.call_args

        # stream is always passed to _post
        assert "stream" in called_kwargs

        # Check if stop_words are passed to _post as stop parameter
        called_args, called_kwargs = mock_post.call_args

        # stream is always passed to _post
        assert "stream" in called_kwargs

        # Assert that the 'stream' parameter passed to _post is the same as the one used in layer.invoke()
        assert called_kwargs["stream"] == stream

@pytest.mark.unit
def test_streaming_stream_handler_param_in_constructor():
    """
    Test stream_handler parameter is correctly passed to HTTP POST request via constructor
    """
    stream_handler = DefaultTokenStreamingHandler()
    layer = HFInferenceEndpointInvocationLayer(
        model_name_or_path="google/flan-t5-xxl", api_key="fake_key", stream_handler=stream_handler
    )

    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post, unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._process_streaming_response"
    ) as mock_post_stream:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
        layer.invoke(prompt="Tell me hello")

        assert mock_post.called
        called_args, called_kwargs = mock_post.call_args

        # stream is always passed to _post
        assert "stream" in called_kwargs

        assert called_kwargs["stream"]

        # stream_handler is passed as an instance of TokenStreamingHandler
        called_args, called_kwargs = mock_post_stream.call_args
        assert isinstance(called_args[1], TokenStreamingHandler)


@pytest.mark.unit
def test_streaming_no_stream_handler_param_in_constructor():
    """
    Test stream_handler parameter is correctly passed to HTTP POST request via constructor
    """
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")

    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
        layer.invoke(prompt="Tell me hello")

        assert mock_post.called
        called_args, called_kwargs = mock_post.call_args

        # stream is always passed to _post
        assert "stream" in called_kwargs

        # but it is False if stream_handler is None
        assert not called_kwargs["stream"]

@pytest.mark.unit
def test_streaming_stream_handler_param_in_method():
    """
    Test stream_handler parameter is correctly passed to HTTP POST request via method
    """
    stream_handler = DefaultTokenStreamingHandler()
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")

    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post, unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._process_streaming_response"
    ) as mock_post_stream:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')

        layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)

        assert mock_post.called
        called_args, called_kwargs = mock_post.call_args

        # stream is correctly passed to _post
        assert "stream" in called_kwargs
        assert called_kwargs["stream"]

        called_args, called_kwargs = mock_post_stream.call_args
        assert isinstance(called_args[1], TokenStreamingHandler)


@pytest.mark.unit
def test_streaming_no_stream_handler_param_in_method():
    """
    Test stream_handler parameter is correctly passed to HTTP POST request via method
    """
    layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")

    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')

        layer.invoke(prompt="Tell me hello", stream_handler=None)

        assert mock_post.called

        called_args, called_kwargs = mock_post.call_args

        # stream is always correctly passed to _post
        assert "stream" in called_kwargs
        assert not called_kwargs["stream"]

@pytest.mark.integration
@pytest.mark.parametrize(
    "model_name_or_path", ["google/flan-t5-xxl", "OpenAssistant/oasst-sft-1-pythia-12b", "bigscience/bloomz"]
)
def test_ensure_token_limit_no_resize(model_name_or_path):
    # In this test case we assume that no prompt resizing is needed for all models
    handler = HFInferenceEndpointInvocationLayer("fake_api_key", model_name_or_path, max_length=100)

    # Define prompt and expected results
    prompt = "This is a test prompt."

    resized_prompt = handler._ensure_token_limit(prompt)

    # Verify the results
    assert resized_prompt == prompt


@pytest.mark.integration
@pytest.mark.parametrize(
    "model_name_or_path", ["google/flan-t5-xxl", "OpenAssistant/oasst-sft-1-pythia-12b", "bigscience/bloomz"]
)
def test_ensure_token_limit_resize(caplog, model_name_or_path):
    # In this test case we assume prompt resizing is needed for all models
    handler = HFInferenceEndpointInvocationLayer("fake_api_key", model_name_or_path, max_length=5, model_max_length=10)

    # Define prompt and expected results
    prompt = "This is a test prompt that will be resized because model_max_length is 10 and max_length is 5."
    with caplog.at_level(logging.WARN):
        resized_prompt = handler._ensure_token_limit(prompt)
        assert "The prompt has been truncated" in caplog.text

    # Verify the results
    assert resized_prompt != prompt
    assert (
        "This is a test" in resized_prompt
        and "because model_max_length is 10 and max_length is 5" not in resized_prompt
    )

@pytest.mark.unit
def test_oasst_prompt_preprocessing():
    model_name = "OpenAssistant/oasst-sft-1-pythia-12b"

    layer = HFInferenceEndpointInvocationLayer("fake_api_key", model_name)
    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
    ) as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
        result = layer.invoke(prompt="Tell me hello")

        assert result == ["Hello"]
        assert mock_post.called

        called_args, called_kwargs = mock_post.call_args
        # OpenAssistant/oasst-sft-1-pythia-12b prompts are preprocessed and wrapped in tokens below
        assert called_kwargs["data"]["inputs"] == "<|prompter|>Tell me hello<|endoftext|><|assistant|>"


@pytest.mark.unit
def test_invalid_key():
    with pytest.raises(ValueError, match="must be a valid Hugging Face token"):
        layer = HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")


@pytest.mark.unit
def test_invalid_model():
    with pytest.raises(ValueError, match="cannot be None or empty string"):
        layer = HFInferenceEndpointInvocationLayer("fake_api", "")

@pytest.mark.unit
def test_supports():
    """
    Test that supports returns True correctly for HFInferenceEndpointInvocationLayer
    """
    # doesn't support fake model
    assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")

    # supports google/flan-t5-xxl with api_key
    assert HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="fake_key")

    # doesn't support google/flan-t5-xxl without api_key
    assert not HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl")

    # supports HF Inference Endpoint with api_key
    assert HFInferenceEndpointInvocationLayer.supports(
        "https://<your-unique-deployment-id>.us-east-1.aws.endpoints.huggingface.cloud", api_key="fake_key"
    )
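The streaming tests above always use DefaultTokenStreamingHandler. As a side note, not part of the commit, a custom handler can be plugged in through the same stream_handler parameter. The sketch below assumes the TokenStreamingHandler callable interface implied by the stream_handler(token, event_data=event_data) call in the diff; the class name and the token_received parameter name are made up for illustration.

from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler


class CollectingTokenStreamingHandler(TokenStreamingHandler):
    """Hypothetical handler that collects streamed tokens instead of printing them."""

    def __init__(self):
        self.tokens = []

    def __call__(self, token_received: str, **kwargs) -> str:
        # Called once per streamed token; the returned string is what
        # _process_streaming_response joins into the final answer.
        self.tokens.append(token_received)
        return token_received


# handler = CollectingTokenStreamingHandler()
# layer.invoke(prompt="Tell me hello", stream_handler=handler)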
@@ -58,3 +58,21 @@ def test_flan_prompt_handler():
        "model_max_length": 20,
        "new_prompt_length": 10,
    }

    # test corner cases
    assert handler("") == {
        "prompt_length": 0,
        "resized_prompt": "",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 0,
    }

    # test corner case
    assert handler(None) == {
        "prompt_length": 0,
        "resized_prompt": None,
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 0,
    }
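For context on the corner cases above (not part of the commit): DefaultPromptHandler is the same helper the invocation layer uses inside _ensure_token_limit. A small sketch of how it is constructed and what it returns, mirroring the gpt2 fallback arguments from the diff; the prompt text is arbitrary.

from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler

# Same style of arguments as the gpt2 fallback in the diff; values here are illustrative.
handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)

resize_info = handler("This is a prompt that may get truncated.")
# resize_info carries the keys asserted in the test above: prompt_length,
# resized_prompt, max_length, model_max_length, new_prompt_length.
print(resize_info["resized_prompt"], resize_info["new_prompt_length"])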