feat: HFInferenceEndpointInvocationLayer streaming support (#4819)

* HFInferenceEndpointInvocationLayer streaming support

* Small fixes

* Add unit test

* PR feedback

* Alphabetically sort params

* Convert PromptNode tests to HFInferenceEndpointInvocationLayer invoke tests

* Rewrite streaming with sseclient

* More PR updates

* Implement and test _ensure_token_limit

* Further optimize DefaultPromptHandler

* Fix CohereInvocationLayer mistypes

* PR feedback

* Break up unit tests, simplify

* Simplify unit tests even further

* PR feedback on unit test simplification

* Proper code indentation under patch context manager

* More unit tests, slight adjustments

* Remove unrelated CohereInvocationLayer change

This reverts commit 82337151e8328d982f738e5da9129ff99350ea0c.

* Revert "Further optimize DefaultPromptHandler"

This reverts commit 606a761b6e3333f27df51a304cfbd1906c806e05.

* Language update

mostly adds full stops at the end of docstrings

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
Co-authored-by: Darja Fokina <daria.f93@gmail.com>
Vladimir Blagojevic 2023-05-22 14:45:53 +02:00 committed by GitHub
parent 9398183447
commit 068a967e5b
3 changed files with 508 additions and 49 deletions
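For orientation, a minimal usage sketch of the streaming path this commit adds (not part of the diff; the token, model name, and prompt below are placeholders):

from haystack.nodes.prompt.invocation_layer import (
    HFInferenceEndpointInvocationLayer,
    DefaultTokenStreamingHandler,
)

layer = HFInferenceEndpointInvocationLayer(
    api_key="hf_your_token_here",             # placeholder Hugging Face token
    model_name_or_path="google/flan-t5-xxl",  # hosted model or an Inference Endpoint URL
)
# Passing a stream_handler (or stream=True) switches invoke() to the SSE streaming path.
# DefaultTokenStreamingHandler emits tokens as they arrive, and the full generation is
# still returned as a single-element list, just like the non-streaming path.
result = layer.invoke(prompt="Tell me a joke", stream_handler=DefaultTokenStreamingHandler())
print(result)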


@@ -1,9 +1,10 @@
import json
import os
from typing import Optional, Dict, Union, List, Any
from typing import Optional, Dict, Union, List, Any, Callable
import logging
import requests
import sseclient
from transformers.pipelines import get_task
from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
@@ -12,7 +13,12 @@ from haystack.errors import (
HuggingFaceInferenceUnauthorizedError,
HuggingFaceInferenceError,
)
from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer import (
PromptModelInvocationLayer,
TokenStreamingHandler,
DefaultTokenStreamingHandler,
)
from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler
from haystack.utils.requests import request_with_retry
logger = logging.getLogger(__name__)
@@ -44,17 +50,13 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
be found in your Hugging Face account [settings](https://huggingface.co/settings/tokens)
"""
super().__init__(model_name_or_path)
self.prompt_preprocessors: Dict[str, Callable] = {}
valid_api_key = isinstance(api_key, str) and api_key
if not valid_api_key:
raise ValueError(
f"api_key {api_key} must be a valid Hugging Face token. "
f"Your token is available in your Hugging Face settings page."
)
valid_model_name_or_path = isinstance(model_name_or_path, str) and model_name_or_path
if not valid_model_name_or_path:
raise ValueError(
f"model_name_or_path {model_name_or_path} must be a valid Hugging Face inference endpoint URL."
)
self.api_key = api_key
self.max_length = max_length
@@ -63,18 +65,51 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
self.model_input_kwargs = {
key: kwargs[key]
for key in [
"top_k",
"top_p",
"temperature",
"repetition_penalty",
"best_of",
"details",
"do_sample",
"max_new_tokens",
"max_time",
"return_full_text",
"model_max_length",
"num_return_sequences",
"do_sample",
"repetition_penalty",
"return_full_text",
"seed",
"stream",
"stream_handler",
"temperature",
"top_k",
"top_p",
"truncate",
"typical_p",
"watermark",
]
if key in kwargs
}
self.prompt_preprocessors["oasst"] = lambda prompt: f"<|prompter|>{prompt}<|endoftext|><|assistant|>"
# we pop the model_max_length from the model_input_kwargs as it is not sent to the model
# but used to truncate the prompt if needed
model_max_length = self.model_input_kwargs.pop("model_max_length", 1024)
if HFInferenceEndpointInvocationLayer.is_inference_endpoint(model_name_or_path):
# as we are using the deployed HF inference endpoint, we don't know the model name
# we'll use gpt2 BPE tokenizer for prompt length calculation
self.prompt_handler = DefaultPromptHandler(
model_name_or_path="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
)
else:
self.prompt_handler = DefaultPromptHandler(
model_name_or_path=model_name_or_path,
model_max_length=model_max_length,
max_length=self.max_length or 100,
)
def preprocess_prompt(self, prompt: str):
for key, prompt_preprocessor in self.prompt_preprocessors.items():
if key in self.model_name_or_path:
return prompt_preprocessor(prompt)
return prompt
@property
def url(self) -> str:
@@ -102,58 +137,99 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
f"No prompt provided. Model {self.model_name_or_path} requires prompt."
f"Make sure to provide prompt in kwargs."
)
stop_words = kwargs.pop("stop_words", None)
prompt = self.preprocess_prompt(prompt)
stop_words = kwargs.pop("stop_words", None) or []
kwargs_with_defaults = self.model_input_kwargs
if "max_new_tokens" not in kwargs_with_defaults:
kwargs_with_defaults["max_new_tokens"] = self.max_length
if "top_k" in kwargs:
top_k = kwargs.pop("top_k")
kwargs["num_return_sequences"] = top_k
kwargs_with_defaults.update(kwargs)
# see https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
accepted_params = [
"top_p",
"top_k",
"temperature",
"repetition_penalty",
"max_new_tokens",
"max_time",
"return_full_text",
"num_return_sequences",
"do_sample",
]
params = {key: kwargs_with_defaults.get(key) for key in accepted_params if key in kwargs_with_defaults}
generated_texts = self._post(data={"inputs": prompt, "parameters": params}, **kwargs)
if stop_words:
for idx, _ in enumerate(generated_texts):
earliest_stop_word_idx = len(generated_texts[idx])
for stop_word in stop_words:
stop_word_idx = generated_texts[idx].find(stop_word)
if stop_word_idx != -1:
earliest_stop_word_idx = min(earliest_stop_word_idx, stop_word_idx)
generated_texts[idx] = generated_texts[idx][:earliest_stop_word_idx]
# either stream is True (will use default handler) or stream_handler is provided
stream = (
kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
)
# see https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
params = {
"best_of": kwargs_with_defaults.get("best_of", None),
"details": kwargs_with_defaults.get("details", True),
"do_sample": kwargs_with_defaults.get("do_sample", False),
"max_new_tokens": kwargs_with_defaults.get("max_new_tokens", self.max_length),
"max_time": kwargs_with_defaults.get("max_time", None),
"num_return_sequences": kwargs_with_defaults.get("num_return_sequences", None),
"repetition_penalty": kwargs_with_defaults.get("repetition_penalty", None),
"return_full_text": kwargs_with_defaults.get("return_full_text", False),
"seed": kwargs_with_defaults.get("seed", None),
"stop": kwargs_with_defaults.get("stop", stop_words),
"temperature": kwargs_with_defaults.get("temperature", None),
"top_k": kwargs_with_defaults.get("top_k", None),
"top_p": kwargs_with_defaults.get("top_p", None),
"truncate": kwargs_with_defaults.get("truncate", None),
"typical_p": kwargs_with_defaults.get("typical_p", None),
"watermark": kwargs_with_defaults.get("watermark", False),
}
response: requests.Response = self._post(
data={"inputs": prompt, "parameters": params, "stream": stream}, stream=stream
)
if stream:
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
generated_texts = self._process_streaming_response(response, handler, stop_words)
else:
output = json.loads(response.text)
generated_texts = [o["generated_text"] for o in output if "generated_text" in o]
return generated_texts
def _process_streaming_response(
self, response: requests.Response, stream_handler: TokenStreamingHandler, stop_words: List[str]
) -> List[str]:
"""
Stream the response and invoke the stream_handler on each token.
:param response: The response object from the server.
:param stream_handler: The handler to invoke on each token.
:param stop_words: The stop words to ignore.
"""
client = sseclient.SSEClient(response)
tokens: List[str] = []
try:
for event in client.events():
if event.data != TokenStreamingHandler.DONE_MARKER:
event_data = json.loads(event.data)
token: Optional[str] = self._extract_token(event_data)
# if the token is valid and not a stop word (we don't want to return stop words)
if token and token.strip() not in stop_words:
tokens.append(stream_handler(token, event_data=event_data))
finally:
client.close()
return ["".join(tokens)] # return a list of strings just like non-streaming
def _extract_token(self, event_data: Dict[str, Any]) -> Optional[str]:
"""
Extract the token from the event data. If the token is a special token, return None.
:param event_data: Event data from the streaming response.
"""
# extract token from event data and only consider non-special tokens
return event_data["token"]["text"] if not event_data["token"]["special"] else None
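# Illustrative only, not part of this diff: the approximate shape of a single SSE event
# from the HF text-generation endpoint that _extract_token and the stream handler consume
# above; field values are made-up examples.
# event_data = {
#     "token": {"id": 42, "text": "Hello", "logprob": -0.1, "special": False},
#     "generated_text": None,  # populated only on the final event
#     "details": None,
# }
# _extract_token(event_data) would return "Hello" here; for a special token (e.g. an
# end-of-sequence marker) it returns None and nothing is forwarded to the handler.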
def _post(
self,
data: Dict[str, Any],
stream: bool = False,
attempts: int = HF_RETRIES,
status_codes: Optional[List[int]] = None,
timeout: float = HF_TIMEOUT,
**kwargs,
) -> List[str]:
) -> requests.Response:
"""
Post data to the HF inference model. It sends the prompt and generation parameters in a single REST call and returns the raw response.
:param data: The data to be sent to the model.
:param stream: Whether to stream the response.
:param attempts: The number of attempts to make.
:param status_codes: The status codes to retry on.
:param timeout: The timeout for the request.
:return: The raw requests.Response returned by the endpoint.
"""
generated_texts: List[str] = []
response: requests.Response
if status_codes is None:
status_codes = [429]
try:
@@ -165,9 +241,8 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
headers=self.headers,
json=data,
timeout=timeout,
stream=stream,
)
output = json.loads(response.text)
generated_texts = [o["generated_text"] for o in output if "generated_text" in o]
except requests.HTTPError as err:
res = err.response
if res.status_code == 429:
@@ -179,11 +254,22 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
f"HuggingFace Inference returned an error.\nStatus code: {res.status_code}\nResponse body: {res.text}",
status_code=res.status_code,
)
return generated_texts
return response
def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
# TODO: new implementation incoming for all layers, let's omit this for now
return prompt
# the prompt for this model will be of the type str
resize_info = self.prompt_handler(prompt) # type: ignore
if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
logger.warning(
"The prompt has been truncated from %s tokens to %s tokens so that the prompt length and "
"answer length (%s tokens) fit within the max token limit (%s tokens). "
"Shorten the prompt to prevent it from being cut off.",
resize_info["prompt_length"],
max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore
resize_info["max_length"],
resize_info["model_max_length"],
)
return str(resize_info["resized_prompt"])
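# Illustrative only, not part of this diff: the dict returned by self.prompt_handler(prompt)
# that _ensure_token_limit inspects above. Values are made-up examples that mirror the
# DefaultPromptHandler tests at the end of this commit.
# resize_info = {
#     "prompt_length": 25,       # token count of the original prompt
#     "resized_prompt": "...",   # prompt cut down to fit the remaining budget
#     "max_length": 5,           # tokens reserved for the generated answer
#     "model_max_length": 10,    # total token budget of the model
#     "new_prompt_length": 5,    # tokens kept, i.e. max(0, model_max_length - max_length)
# }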
@staticmethod
def is_inference_endpoint(model_name_or_path: str) -> bool:


@@ -0,0 +1,355 @@
import logging
import unittest
from unittest.mock import patch, MagicMock
import pytest
from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
from haystack.nodes.prompt.invocation_layer import HFInferenceEndpointInvocationLayer
@pytest.mark.unit
def test_default_constructor():
"""
Test that the default constructor sets the correct values
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")
assert layer.api_key == "some_fake_key"
assert layer.max_length == 100
assert layer.model_input_kwargs == {}
@pytest.mark.unit
def test_constructor_with_model_kwargs():
"""
Test that model_kwargs are correctly set in the constructor
and that model_kwargs_rejected are correctly filtered out
"""
model_kwargs = {"temperature": 0.7, "do_sample": True, "stream": True}
model_kwargs_rejected = {"fake_param": 0.7, "another_fake_param": 1}
layer = HFInferenceEndpointInvocationLayer(
model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key", **model_kwargs, **model_kwargs_rejected
)
assert "temperature" in layer.model_input_kwargs
assert "do_sample" in layer.model_input_kwargs
assert "stream" in layer.model_input_kwargs
assert "fake_param" not in layer.model_input_kwargs
assert "another_fake_param" not in layer.model_input_kwargs
@pytest.mark.unit
def test_set_model_max_length():
"""
Test that model max length is set correctly
"""
layer = HFInferenceEndpointInvocationLayer(
model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key", model_max_length=2048
)
assert layer.prompt_handler.model_max_length == 2048
@pytest.mark.unit
def test_url():
"""
Test that the url is correctly set in the constructor
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")
assert layer.url == "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
layer = HFInferenceEndpointInvocationLayer(
model_name_or_path="https://23445.us-east-1.aws.endpoints.huggingface.cloud", api_key="some_fake_key"
)
assert layer.url == "https://23445.us-east-1.aws.endpoints.huggingface.cloud"
@pytest.mark.unit
def test_invoke_with_no_kwargs():
"""
Test that invoke raises an error if no prompt is provided
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="some_fake_key")
with pytest.raises(ValueError) as e:
layer.invoke()
assert e.match("No prompt provided.")
@pytest.mark.unit
def test_invoke_with_stop_words():
"""
Test stop words are correctly passed to HTTP POST request
"""
stop_words = ["but", "not", "bye"]
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello", stop_words=stop_words)
assert mock_post.called
# Check if stop_words are passed to _post as stop parameter
called_args, called_kwargs = mock_post.call_args
assert "stop" in called_kwargs["data"]["parameters"]
assert called_kwargs["data"]["parameters"]["stop"] == stop_words
@pytest.mark.unit
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param_in_constructor(stream):
"""
Test stream parameter is correctly passed to HTTP POST request via constructor
"""
layer = HFInferenceEndpointInvocationLayer(
model_name_or_path="google/flan-t5-xxl", api_key="fake_key", stream=stream
)
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello")
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
assert called_kwargs["stream"] == stream
@pytest.mark.unit
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param_in_method(stream):
"""
Test stream parameter is correctly passed to HTTP POST request via method
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello", stream=stream)
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
# Assert that the 'stream' parameter passed to _post is the same as the one used in layer.invoke()
assert called_kwargs["stream"] == stream
@pytest.mark.unit
def test_streaming_stream_handler_param_in_constructor():
"""
Test stream_handler parameter is correctly passed to HTTP POST request via constructor
"""
stream_handler = DefaultTokenStreamingHandler()
layer = HFInferenceEndpointInvocationLayer(
model_name_or_path="google/flan-t5-xxl", api_key="fake_key", stream_handler=stream_handler
)
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post, unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._process_streaming_response"
) as mock_post_stream:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello")
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
assert called_kwargs["stream"]
# stream_handler is passed as an instance of TokenStreamingHandler
called_args, called_kwargs = mock_post_stream.call_args
assert isinstance(called_args[1], TokenStreamingHandler)
@pytest.mark.unit
def test_streaming_no_stream_handler_param_in_constructor():
"""
Test that streaming is disabled when no stream_handler is passed in the constructor
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello")
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
# but it is False if stream_handler is None
assert not called_kwargs["stream"]
@pytest.mark.unit
def test_streaming_stream_handler_param_in_method():
"""
Test stream_handler parameter is correctly passed to HTTP POST request via method
"""
stream_handler = DefaultTokenStreamingHandler()
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post, unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._process_streaming_response"
) as mock_post_stream:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
# stream is correctly passed to _post
assert "stream" in called_kwargs
assert called_kwargs["stream"]
called_args, called_kwargs = mock_post_stream.call_args
assert isinstance(called_args[1], TokenStreamingHandler)
@pytest.mark.unit
def test_streaming_no_stream_handler_param_in_method():
"""
Test that streaming is disabled when no stream_handler is passed to invoke
"""
layer = HFInferenceEndpointInvocationLayer(model_name_or_path="google/flan-t5-xxl", api_key="fake_key")
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello", stream_handler=None)
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
# stream is always correctly passed to _post
assert "stream" in called_kwargs
assert not called_kwargs["stream"]
@pytest.mark.integration
@pytest.mark.parametrize(
"model_name_or_path", ["google/flan-t5-xxl", "OpenAssistant/oasst-sft-1-pythia-12b", "bigscience/bloomz"]
)
def test_ensure_token_limit_no_resize(model_name_or_path):
# In this test case we assume that no prompt resizing is needed for all models
handler = HFInferenceEndpointInvocationLayer("fake_api_key", model_name_or_path, max_length=100)
# Define prompt and expected results
prompt = "This is a test prompt."
resized_prompt = handler._ensure_token_limit(prompt)
# Verify the results
assert resized_prompt == prompt
@pytest.mark.integration
@pytest.mark.parametrize(
"model_name_or_path", ["google/flan-t5-xxl", "OpenAssistant/oasst-sft-1-pythia-12b", "bigscience/bloomz"]
)
def test_ensure_token_limit_resize(caplog, model_name_or_path):
# In this test case we assume prompt resizing is needed for all models
handler = HFInferenceEndpointInvocationLayer("fake_api_key", model_name_or_path, max_length=5, model_max_length=10)
# Define prompt and expected results
prompt = "This is a test prompt that will be resized because model_max_length is 10 and max_length is 5."
with caplog.at_level(logging.WARN):
resized_prompt = handler._ensure_token_limit(prompt)
assert "The prompt has been truncated" in caplog.text
# Verify the results
assert resized_prompt != prompt
assert (
"This is a test" in resized_prompt
and "because model_max_length is 10 and max_length is 5" not in resized_prompt
)
@pytest.mark.unit
def test_oasst_prompt_preprocessing():
model_name = "OpenAssistant/oasst-sft-1-pythia-12b"
layer = HFInferenceEndpointInvocationLayer("fake_api_key", model_name)
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.HFInferenceEndpointInvocationLayer._post"
) as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
result = layer.invoke(prompt="Tell me hello")
assert result == ["Hello"]
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
# OpenAssistant/oasst-sft-1-pythia-12b prompts are preprocessed and wrapped in tokens below
assert called_kwargs["data"]["inputs"] == "<|prompter|>Tell me hello<|endoftext|><|assistant|>"
@pytest.mark.unit
def test_invalid_key():
with pytest.raises(ValueError, match="must be a valid Hugging Face token"):
layer = HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")
@pytest.mark.unit
def test_invalid_model():
with pytest.raises(ValueError, match="cannot be None or empty string"):
layer = HFInferenceEndpointInvocationLayer("fake_api", "")
@pytest.mark.unit
def test_supports():
"""
Test that supports returns True correctly for HFInferenceEndpointInvocationLayer
"""
# doesn't support fake model
assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")
# supports google/flan-t5-xxl with api_key
assert HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="fake_key")
# doesn't support google/flan-t5-xxl without api_key
assert not HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl")
# supports HF Inference Endpoint with api_key
assert HFInferenceEndpointInvocationLayer.supports(
"https://<your-unique-deployment-id>.us-east-1.aws.endpoints.huggingface.cloud", api_key="fake_key"
)
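The tests above exercise DefaultTokenStreamingHandler; a custom handler is assumed to follow the TokenStreamingHandler call interface used by _process_streaming_response (a sketch, not part of this diff):

from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler


class CollectingTokenStreamingHandler(TokenStreamingHandler):
    """Hypothetical handler that collects streamed tokens instead of printing them."""

    def __init__(self):
        self.tokens = []

    def __call__(self, token_received: str, **kwargs) -> str:
        # Store the token; returning it lets the invocation layer rebuild the full text.
        self.tokens.append(token_received)
        return token_received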


@@ -58,3 +58,21 @@ def test_flan_prompt_handler():
"model_max_length": 20,
"new_prompt_length": 10,
}
# corner case: empty prompt
assert handler("") == {
"prompt_length": 0,
"resized_prompt": "",
"max_length": 10,
"model_max_length": 20,
"new_prompt_length": 0,
}
# corner case: None prompt
assert handler(None) == {
"prompt_length": 0,
"resized_prompt": None,
"max_length": 10,
"model_max_length": 20,
"new_prompt_length": 0,
}
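For reference, a minimal sketch of how DefaultPromptHandler produces the dicts asserted above (not part of this diff; "gpt2" mirrors the fallback tokenizer the invocation layer uses for deployed endpoints, and the 10/20 limits mirror the assertions above):

from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler

# 20-token model budget with 10 tokens reserved for the answer.
handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
info = handler("This is a test prompt that is long enough to be truncated by the handler.")
print(info["prompt_length"], info["new_prompt_length"], info["resized_prompt"])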