feat: Add Cohere PromptNode invocation layer (#4827)
* Add CohereInvocationLayer

Co-authored-by: bogdankostic <bogdankostic@web.de>
parent 7e2b824bea
commit 73380b194a
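Reviewer note: a minimal usage sketch of what this PR enables through PromptNode. This is not code from the diff; the api_key value is a placeholder and the call pattern is the standard PromptNode API.

from haystack.nodes import PromptNode

# Sketch only: PromptNode is expected to route to the new CohereInvocationLayer,
# since supports() accepts the "command" model name when an api_key is supplied.
prompt_node = PromptNode(model_name_or_path="command", api_key="YOUR_COHERE_API_KEY", max_length=100)
print(prompt_node("What is the capital of Germany?"))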
@@ -1,7 +1,7 @@
loaders:
  - type: python
    search_path: [../../../haystack/nodes/prompt]
    modules:
      [
        "prompt_node",
        "prompt_model",
@@ -230,3 +230,10 @@ class AnthropicUnauthorizedError(AnthropicError):
    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
        super().__init__(message=message, status_code=401, send_message_in_event=send_message_in_event)


class CohereInferenceLimitError(CohereError):
    """Exception for issues that occur in the Cohere inference node due to rate limiting"""

    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
        super().__init__(message=message, status_code=429, send_message_in_event=send_message_in_event)
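The hunk above only shows CohereInferenceLimitError; CohereError and CohereUnauthorizedError, which the new cohere.py imports below, live in the same haystack.errors module. A hedged sketch of how calling code could tell the failure modes apart (placeholder key; performs a real API call if run):

from haystack.errors import CohereError, CohereInferenceLimitError, CohereUnauthorizedError
from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer

layer = CohereInvocationLayer(api_key="YOUR_COHERE_API_KEY", model_name_or_path="command")

try:
    responses = layer.invoke(prompt="Hello")
except CohereInferenceLimitError:
    print("HTTP 429: rate limited, back off and retry later")
except CohereUnauthorizedError:
    print("HTTP 401: the Cohere API key is invalid")
except CohereError as err:
    print(f"Other error returned by the Cohere API: {err}")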
@@ -7,3 +7,4 @@ from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicClaudeInvocationLayer
from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
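With the import added to the package __init__, PromptModel's layer discovery can pick the class up automatically. A hedged sketch of that path (placeholder key; PromptModel's internals are outside this diff):

from haystack.nodes import PromptModel

# Sketch: PromptModel iterates over the registered invocation layers and selects the one
# whose supports() returns True, here CohereInvocationLayer for "command" plus an api_key.
model = PromptModel(model_name_or_path="command", api_key="YOUR_COHERE_API_KEY")
print(model.invoke(prompt="Say hello in one word."))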
haystack/nodes/prompt/invocation_layer/cohere.py (new file, 220 lines added)
@@ -0,0 +1,220 @@
import json
import os
from typing import Optional, Dict, Union, List, Any
import logging

import requests

from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
from haystack.errors import CohereInferenceLimitError, CohereUnauthorizedError, CohereError
from haystack.nodes.prompt.invocation_layer import (
    PromptModelInvocationLayer,
    TokenStreamingHandler,
    DefaultTokenStreamingHandler,
)
from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler
from haystack.utils.requests import request_with_retry

logger = logging.getLogger(__name__)
TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))


class CohereInvocationLayer(PromptModelInvocationLayer):
    """
    PromptModelInvocationLayer implementation for Cohere's command models. Invocations are made using the REST API.
    """

    def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[int] = 100, **kwargs):
        """
        Creates an instance of CohereInvocationLayer for the specified Cohere model.

        :param api_key: Cohere API key
        :param model_name_or_path: Cohere model name
        :param max_length: The maximum length of the output text.
        """
        super().__init__(model_name_or_path)
        valid_api_key = isinstance(api_key, str) and api_key
        if not valid_api_key:
            raise ValueError(
                f"api_key {api_key} must be a valid Cohere token. "
                f"Your token is available in your Cohere settings page."
            )
        valid_model_name_or_path = isinstance(model_name_or_path, str) and model_name_or_path
        if not valid_model_name_or_path:
            raise ValueError(f"model_name_or_path {model_name_or_path} must be a valid Cohere model name")
        self.api_key = api_key
        self.max_length = max_length

        # See https://docs.cohere.com/reference/generate
        # for a list of supported parameters
        self.model_input_kwargs = {
            key: kwargs[key]
            for key in [
                "end_sequences",
                "frequency_penalty",
                "k",
                "logit_bias",
                "max_tokens",
                "model",
                "num_generations",
                "p",
                "presence_penalty",
                "return_likelihoods",
                "stream",
                "stream_handler",
                "temperature",
                "truncate",
            ]
            if key in kwargs
        }
        # Cohere uses a BPE tokenizer; its tokenization lengths are very close to gpt2
        # (in our experiments the differences were minimal).
        # See model info at https://docs.cohere.com/docs/models
        model_max_length = 4096 if "command" in model_name_or_path else 2048
        self.prompt_handler = DefaultPromptHandler(
            model_name_or_path="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
        )

    @property
    def url(self) -> str:
        return "https://api.cohere.ai/v1/generate"

    @property
    def headers(self) -> Dict[str, str]:
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Request-Source": "python-sdk",
        }

    def invoke(self, *args, **kwargs):
        """
        Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation.

        :return: A list of generated responses.
        """
        prompt = kwargs.get("prompt")
        if not prompt:
            raise ValueError(
                f"No prompt provided. Model {self.model_name_or_path} requires a prompt. "
                f"Make sure to provide the prompt in kwargs."
            )
        stop_words = kwargs.pop("stop_words", None)
        kwargs_with_defaults = self.model_input_kwargs
        kwargs_with_defaults.update(kwargs)

        # either stream is True (which uses the default handler) or a stream_handler is provided
        stream = (
            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
        )

        # see https://docs.cohere.com/reference/generate
        params = {
            "end_sequences": kwargs_with_defaults.get("end_sequences", stop_words),
            "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", None),
            "k": kwargs_with_defaults.get("k", None),
            "max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
            "model": kwargs_with_defaults.get("model", self.model_name_or_path),
            "num_generations": kwargs_with_defaults.get("num_generations", None),
            "p": kwargs_with_defaults.get("p", None),
            "presence_penalty": kwargs_with_defaults.get("presence_penalty", None),
            "prompt": prompt,
            "return_likelihoods": kwargs_with_defaults.get("return_likelihoods", None),
            "stream": stream,
            "temperature": kwargs_with_defaults.get("temperature", None),
            "truncate": kwargs_with_defaults.get("truncate", None),
        }
        response = self._post(params, stream=stream)
        if not stream:
            output = json.loads(response.text)
            generated_texts = [o["text"] for o in output["generations"] if "text" in o]
        else:
            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
            generated_texts = self._process_streaming_response(response=response, stream_handler=handler)
        return generated_texts

    def _process_streaming_response(self, response, stream_handler: TokenStreamingHandler):
        # sseclient doesn't work with the Cohere streaming API, so we parse the stream manually
        tokens = []
        for line in response.iter_lines():
            if line:
                streaming_item = json.loads(line)
                text = streaming_item.get("text")
                if text:
                    tokens.append(stream_handler(text))
        return ["".join(tokens)]  # return a list of strings, just like the non-streaming case

    def _post(
        self,
        data: Dict[str, Any],
        stream: bool = False,
        attempts: int = RETRIES,
        status_codes: Optional[List[int]] = None,
        timeout: float = TIMEOUT,
        **kwargs,
    ) -> requests.Response:
        """
        Post data to the Cohere inference endpoint. It takes in a prompt and returns a list of responses using a REST
        invocation.

        :param data: The data to be sent to the model.
        :param stream: Whether to stream the response.
        :param attempts: The number of attempts to make.
        :param status_codes: The status codes to retry on.
        :param timeout: The timeout for the request.
        :return: The response from the model as a requests.Response object.
        """
        response: requests.Response
        if status_codes is None:
            status_codes = [429]
        try:
            response = request_with_retry(
                method="POST",
                status_codes=status_codes,
                attempts=attempts,
                url=self.url,
                headers=self.headers,
                json=data,
                timeout=timeout,
                stream=stream,
            )
        except requests.HTTPError as err:
            res = err.response
            if res.status_code == 429:
                raise CohereInferenceLimitError(f"API rate limit exceeded: {res.text}")
            if res.status_code == 401:
                raise CohereUnauthorizedError(f"API key is invalid: {res.text}")

            raise CohereError(
                f"Cohere model returned an error.\nStatus code: {res.status_code}\nResponse body: {res.text}",
                status_code=res.status_code,
            )
        return response

    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
        # the prompt for this model will be of type str
        resize_info = self.prompt_handler(prompt)  # type: ignore
        if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
            logger.warning(
                "The prompt has been truncated from %s tokens to %s tokens so that the prompt length and "
                "answer length (%s tokens) fit within the max token limit (%s tokens). "
                "Shorten the prompt to prevent it from being cut off.",
                resize_info["prompt_length"],
                max(0, resize_info["model_max_length"] - resize_info["max_length"]),  # type: ignore
                resize_info["max_length"],
                resize_info["model_max_length"],
            )
        return prompt

    @classmethod
    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
        """
        Ensures CohereInvocationLayer is selected only when Cohere models are specified in the model name.
        """
        is_inference_api = "api_key" in kwargs
        return (
            model_name_or_path is not None
            and is_inference_api
            and any(token == model_name_or_path for token in ["command", "command-light", "base", "base-light"])
        )
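For reference, a direct-invocation sketch of the layer defined above, with and without streaming (placeholder key; performs a real API call if run):

from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler

layer = CohereInvocationLayer(api_key="YOUR_COHERE_API_KEY", model_name_or_path="command", max_length=50)

# Non-streaming: invoke() POSTs to https://api.cohere.ai/v1/generate and returns a list of strings.
print(layer.invoke(prompt="Write a one-line greeting."))

# Streaming: passing a stream_handler switches stream=True; each token is pushed to the handler
# as it arrives and the joined text comes back as a single-element list.
print(layer.invoke(prompt="Write a one-line greeting.", stream_handler=DefaultTokenStreamingHandler()))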
@@ -1,7 +1,7 @@
from abc import abstractmethod, ABC
-from typing import Union
+from typing import Union, Dict

-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer, AutoTokenizer


class TokenStreamingHandler(ABC):
@@ -46,3 +46,47 @@ class HFTokenStreamingHandler(TextStreamer):
    def on_finalized_text(self, token: str, stream_end: bool = False):
        token_to_send = token + "\n" if stream_end else token
        self.token_handler(token_received=token_to_send, **{})


class DefaultPromptHandler:
    """
    DefaultPromptHandler resizes the prompt to ensure that the prompt and answer token lengths together
    are within the model_max_length.
    """

    def __init__(self, model_name_or_path: str, model_max_length: int, max_length: int = 100):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model_max_length = model_max_length
        self.max_length = max_length

    def __call__(self, prompt: str, **kwargs) -> Dict[str, Union[str, int]]:
        """
        Resizes the prompt to ensure that the prompt and answer fit within the model_max_length.

        :param prompt: The prompt to be sent to the model.
        :param kwargs: Additional keyword arguments passed to the handler.
        :return: A dictionary containing the resized prompt and additional information.
        """
        resized_prompt = prompt
        prompt_length = 0
        new_prompt_length = 0

        if prompt:
            prompt_length = len(self.tokenizer.tokenize(prompt))
            if (prompt_length + self.max_length) <= self.model_max_length:
                resized_prompt = prompt
                new_prompt_length = prompt_length
            else:
                tokenized_payload = self.tokenizer.tokenize(prompt)
                resized_prompt = self.tokenizer.convert_tokens_to_string(
                    tokenized_payload[: self.model_max_length - self.max_length]
                )
                new_prompt_length = len(tokenized_payload[: self.model_max_length - self.max_length])

        return {
            "resized_prompt": resized_prompt,
            "prompt_length": prompt_length,
            "new_prompt_length": new_prompt_length,
            "model_max_length": self.model_max_length,
            "max_length": self.max_length,
        }
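A small sketch of how DefaultPromptHandler reports truncation; the values mirror the gpt2 integration test added further down (requires the gpt2 tokenizer to be downloadable):

from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler

handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
info = handler("This is a prompt that will be resized because it is longer than allowed")

# prompt_length (15) differs from new_prompt_length (10), which is the signal
# CohereInvocationLayer._ensure_token_limit uses to decide whether to log a warning.
print(info["resized_prompt"], info["prompt_length"], info["new_prompt_length"])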
test/prompt/invocation_layer/test_cohere.py (new file, 180 lines added)
@@ -0,0 +1,180 @@
import unittest
from unittest.mock import patch, MagicMock

import pytest

from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer


@pytest.mark.unit
def test_default_constructor():
    """
    Test that the default constructor sets the correct values
    """

    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")

    assert layer.api_key == "some_fake_key"
    assert layer.max_length == 100
    assert layer.model_input_kwargs == {}
    assert layer.prompt_handler.model_max_length == 4096

    layer = CohereInvocationLayer(model_name_or_path="base", api_key="some_fake_key")
    assert layer.api_key == "some_fake_key"
    assert layer.max_length == 100
    assert layer.model_input_kwargs == {}
    assert layer.prompt_handler.model_max_length == 2048


@pytest.mark.unit
def test_constructor_with_model_kwargs():
    """
    Test that model_kwargs are correctly set in the constructor
    and that model_kwargs_rejected are correctly filtered out
    """
    model_kwargs = {"temperature": 0.7, "end_sequences": ["end"], "stream": True}
    model_kwargs_rejected = {"fake_param": 0.7, "another_fake_param": 1}
    layer = CohereInvocationLayer(
        model_name_or_path="command", api_key="some_fake_key", **model_kwargs, **model_kwargs_rejected
    )
    assert layer.model_input_kwargs == model_kwargs
    assert len(model_kwargs_rejected) == 2


@pytest.mark.unit
def test_invoke_with_no_kwargs():
    """
    Test that invoke raises an error if no prompt is provided
    """
    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")
    with pytest.raises(ValueError) as e:
        layer.invoke()
    assert e.match("No prompt provided.")


@pytest.mark.unit
def test_invoke_with_stop_words():
    """
    Test that stop words are correctly passed from PromptNode through to CohereInvocationLayer
    """
    stop_words = ["but", "not", "bye"]
    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')

        layer.invoke(prompt="Tell me hello", stop_words=stop_words)

        assert mock_post.called

        # Check that stop_words are passed to _post as the end_sequences parameter
        called_args, _ = mock_post.call_args
        assert "end_sequences" in called_args[0]
        assert called_args[0]["end_sequences"] == stop_words


@pytest.mark.unit
@pytest.mark.parametrize("using_constructor", [True, False])
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param(using_constructor, stream):
    """
    Test that the stream parameter is correctly passed from PromptNode through to CohereInvocationLayer
    """
    if using_constructor:
        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=stream)
    else:
        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")

    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')

        if using_constructor:
            layer.invoke(prompt="Tell me hello")
        else:
            layer.invoke(prompt="Tell me hello", stream=stream)

        assert mock_post.called

        # Inspect the arguments that _post was called with
        called_args, called_kwargs = mock_post.call_args

        # stream is always passed to _post
        assert "stream" in called_kwargs

        # if stream is True, then stream is passed as True to _post
        if stream:
            assert called_kwargs["stream"]
        # if stream is False, then stream is passed as False to _post
        else:
            assert not called_kwargs["stream"]


@pytest.mark.unit
@pytest.mark.parametrize("using_constructor", [True, False])
@pytest.mark.parametrize("stream_handler", [DefaultTokenStreamingHandler(), None])
def test_streaming_stream_handler_param(using_constructor, stream_handler):
    """
    Test that the stream_handler parameter is correctly passed from PromptNode through to CohereInvocationLayer
    """
    if using_constructor:
        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream_handler=stream_handler)
    else:
        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")

    with unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post"
    ) as mock_post, unittest.mock.patch(
        "haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._process_streaming_response"
    ) as mock_post_stream:
        # Mock the response, need to return a list of dicts
        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')

        if using_constructor:
            layer.invoke(prompt="Tell me hello")
        else:
            layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)

        assert mock_post.called

        # Inspect the arguments that _post was called with
        called_args, called_kwargs = mock_post.call_args

        # stream is always passed to _post
        assert "stream" in called_kwargs

        # if a stream_handler is used, then stream is always True
        if stream_handler:
            assert called_kwargs["stream"]
            # and stream_handler is passed as an instance of TokenStreamingHandler
            called_args, called_kwargs = mock_post_stream.call_args
            assert "stream_handler" in called_kwargs
            assert isinstance(called_kwargs["stream_handler"], TokenStreamingHandler)
        # if no stream_handler is used, then stream is always False
        else:
            assert not called_kwargs["stream"]


@pytest.mark.unit
def test_supports():
    """
    Test that supports returns True correctly for CohereInvocationLayer
    """
    # See command and generate models at https://docs.cohere.com/docs/models

    # doesn't support fake model
    assert not CohereInvocationLayer.supports("fake_model", api_key="fake_key")

    # supports cohere command with api_key
    assert CohereInvocationLayer.supports("command", api_key="fake_key")

    # supports cohere command-light with api_key
    assert CohereInvocationLayer.supports("command-light", api_key="fake_key")

    # supports cohere base with api_key
    assert CohereInvocationLayer.supports("base", api_key="fake_key")

    assert CohereInvocationLayer.supports("base-light", api_key="fake_key")

    # doesn't support other models that only contain "base" as a substring, i.e. google/flan-t5-base
    assert not CohereInvocationLayer.supports("google/flan-t5-base")
test/prompt/test_handlers.py (new file, 60 lines added)
@@ -0,0 +1,60 @@
import pytest

from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler


@pytest.mark.integration
def test_prompt_handler_basics():
    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
    assert callable(handler)

    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20)
    assert handler.max_length == 100


@pytest.mark.integration
def test_gpt2_prompt_handler():
    # test gpt2 BPE based tokenizer
    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)

    # test no resize
    assert handler("This is a test") == {
        "prompt_length": 4,
        "resized_prompt": "This is a test",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 4,
    }

    # test resize
    assert handler("This is a prompt that will be resized because it is longer than allowed") == {
        "prompt_length": 15,
        "resized_prompt": "This is a prompt that will be resized because",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 10,
    }


@pytest.mark.integration
def test_flan_prompt_handler():
    # test the google/flan-t5-xxl tokenizer
    handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)

    # test no resize
    assert handler("This is a test") == {
        "prompt_length": 5,
        "resized_prompt": "This is a test",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 5,
    }

    # test resize
    assert handler("This is a prompt that will be resized because it is longer than allowed") == {
        "prompt_length": 17,
        "resized_prompt": "This is a prompt that will be re",
        "max_length": 10,
        "model_max_length": 20,
        "new_prompt_length": 10,
    }