feat: Add Cohere PromptNode invocation layer (#4827)

* Add CohereInvocationLayer
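A minimal usage sketch (illustrative only; the model name and API key below are placeholders):

    from haystack.nodes import PromptNode

    # PromptNode routes to CohereInvocationLayer because supports() matches Cohere model names
    pn = PromptNode(model_name_or_path="command", api_key="<COHERE_API_KEY>", max_length=100)
    print(pn("What is the capital of Germany?"))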
---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
Vladimir Blagojevic 2023-05-12 17:50:09 +02:00 committed by GitHub
parent 7e2b824bea
commit 73380b194a
7 changed files with 515 additions and 3 deletions


@@ -1,7 +1,7 @@
loaders:
- type: python
search_path: [../../../haystack/nodes/prompt]
modules:
[
"prompt_node",
"prompt_model",


@@ -230,3 +230,10 @@ class AnthropicUnauthorizedError(AnthropicError):
def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
super().__init__(message=message, status_code=401, send_message_in_event=send_message_in_event)
class CohereInferenceLimitError(CohereError):
"""Exception for issues that occur in the Cohere inference node due to rate limiting"""
def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
super().__init__(message=message, status_code=429, send_message_in_event=send_message_in_event)


@@ -7,3 +7,4 @@ from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicClaudeInvocationLayer
from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer


@@ -0,0 +1,220 @@
import json
import os
from typing import Optional, Dict, Union, List, Any
import logging
import requests
from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
from haystack.errors import CohereInferenceLimitError, CohereUnauthorizedError, CohereError
from haystack.nodes.prompt.invocation_layer import (
PromptModelInvocationLayer,
TokenStreamingHandler,
DefaultTokenStreamingHandler,
)
from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler
from haystack.utils.requests import request_with_retry
logger = logging.getLogger(__name__)
TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
class CohereInvocationLayer(PromptModelInvocationLayer):
"""
PromptModelInvocationLayer implementation for Cohere's command models. Invocations are made via the Cohere REST API.
"""
def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[int] = 100, **kwargs):
"""
Creates an instance of CohereInvocationLayer for the specified Cohere model
:param api_key: Cohere API key
:param model_name_or_path: Cohere model name
:param max_length: The maximum length of the output text.
"""
super().__init__(model_name_or_path)
valid_api_key = isinstance(api_key, str) and api_key
if not valid_api_key:
raise ValueError(
f"api_key {api_key} must be a valid Cohere token. "
f"Your token is available in your Cohere settings page."
)
valid_model_name_or_path = isinstance(model_name_or_path, str) and model_name_or_path
if not valid_model_name_or_path:
raise ValueError(f"model_name_or_path {model_name_or_path} must be a valid Cohere model name")
self.api_key = api_key
self.max_length = max_length
# See https://docs.cohere.com/reference/generate
# for a list of supported parameters
self.model_input_kwargs = {
key: kwargs[key]
for key in [
"end_sequences",
"frequency_penalty",
"k",
"logit_bias",
"max_tokens",
"model",
"num_generations",
"p",
"presence_penalty",
"return_likelihoods",
"stream",
"stream_handler",
"temperature",
"truncate",
]
if key in kwargs
}
# cohere uses BPE tokenizer
# the tokenization lengths are very close to gpt2, in our experiments the differences were minimal
# See model info at https://docs.cohere.com/docs/models
model_max_length = 4096 if "command" in model_name_or_path else 2048
self.prompt_handler = DefaultPromptHandler(
model_name_or_path="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
)
@property
def url(self) -> str:
return "https://api.cohere.ai/v1/generate"
@property
def headers(self) -> Dict[str, str]:
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Request-Source": "python-sdk",
}
def invoke(self, *args, **kwargs):
"""
Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation.
:return: A list of generated text responses.
"""
prompt = kwargs.get("prompt")
if not prompt:
raise ValueError(
f"No prompt provided. Model {self.model_name_or_path} requires prompt."
f"Make sure to provide prompt in kwargs."
)
stop_words = kwargs.pop("stop_words", None)
# copy the defaults so that per-call kwargs don't mutate the shared model_input_kwargs
kwargs_with_defaults = self.model_input_kwargs.copy()
kwargs_with_defaults.update(kwargs)
# either stream is True (will use default handler) or stream_handler is provided
stream = (
kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
)
# see https://docs.cohere.com/reference/generate
params = {
"end_sequences": kwargs_with_defaults.get("end_sequences", stop_words),
"frequency_penalty": kwargs_with_defaults.get("frequency_penalty", None),
"k": kwargs_with_defaults.get("k", None),
"max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
"model": kwargs_with_defaults.get("model", self.model_name_or_path),
"num_generations": kwargs_with_defaults.get("num_generations", None),
"p": kwargs_with_defaults.get("p", None),
"presence_penalty": kwargs_with_defaults.get("presence_penalty", None),
"prompt": prompt,
"return_likelihoods": kwargs_with_defaults.get("return_likelihoods", None),
"stream": stream,
"temperature": kwargs_with_defaults.get("temperature", None),
"truncate": kwargs_with_defaults.get("truncate", None),
}
response = self._post(params, stream=stream)
if not stream:
output = json.loads(response.text)
generated_texts = [o["text"] for o in output["generations"] if "text" in o]
else:
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
generated_texts = self._process_streaming_response(response=response, stream_handler=handler)
return generated_texts
def _process_streaming_response(self, response, stream_handler: TokenStreamingHandler):
# sseclient doesn't work with Cohere streaming API
# let's do it manually
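# each non-empty line is expected to be a standalone JSON object with a "text" field,
# e.g. {"text": " Hello", "is_finished": false} (format assumed from Cohere's streaming responses)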
tokens = []
for line in response.iter_lines():
if line:
streaming_item = json.loads(line)
text = streaming_item.get("text")
if text:
tokens.append(stream_handler(text))
return ["".join(tokens)] # return a list of strings just like non-streaming
def _post(
self,
data: Dict[str, Any],
stream: bool = False,
attempts: int = RETRIES,
status_codes: Optional[List[int]] = None,
timeout: float = TIMEOUT,
**kwargs,
) -> requests.Response:
"""
Post data to the Cohere inference endpoint. It takes the request payload and returns the raw response from a REST
invocation.
:param data: The data to be sent to the model.
:param stream: Whether to stream the response.
:param attempts: The number of attempts to make.
:param status_codes: The status codes to retry on.
:param timeout: The timeout for the request.
:return: The response from the model as a requests.Response object.
"""
response: requests.Response
if status_codes is None:
status_codes = [429]
try:
response = request_with_retry(
method="POST",
status_codes=status_codes,
attempts=attempts,
url=self.url,
headers=self.headers,
json=data,
timeout=timeout,
stream=stream,
)
except requests.HTTPError as err:
res = err.response
if res.status_code == 429:
raise CohereInferenceLimitError(f"API rate limit exceeded: {res.text}")
if res.status_code == 401:
raise CohereUnauthorizedError(f"API key is invalid: {res.text}")
raise CohereError(
f"Cohere model returned an error.\nStatus code: {res.status_code}\nResponse body: {res.text}",
status_code=res.status_code,
)
return response
def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
# the prompt for this model will be of the type str
resize_info = self.prompt_handler(prompt) # type: ignore
if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
logger.warning(
"The prompt has been truncated from %s tokens to %s tokens so that the prompt length and "
"answer length (%s tokens) fit within the max token limit (%s tokens). "
"Shorten the prompt to prevent it from being cut off",
resize_info["prompt_length"],
max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore
resize_info["max_length"],
resize_info["model_max_length"],
)
return prompt
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
"""
Ensures CohereInvocationLayer is selected only when Cohere models are specified in
the model name.
"""
is_inference_api = "api_key" in kwargs
return (
model_name_or_path is not None
and is_inference_api
and any(token == model_name_or_path for token in ["command", "command-light", "base", "base-light"])
)


@@ -1,7 +1,7 @@
from abc import abstractmethod, ABC
from typing import Union, Dict
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer, AutoTokenizer
class TokenStreamingHandler(ABC):
@@ -46,3 +46,47 @@ class HFTokenStreamingHandler(TextStreamer):
def on_finalized_text(self, token: str, stream_end: bool = False):
token_to_send = token + "\n" if stream_end else token
self.token_handler(token_received=token_to_send, **{})
class DefaultPromptHandler:
"""
DefaultPromptHandler resizes the prompt to ensure that the prompt and answer token lengths together
are within the model_max_length.
"""
def __init__(self, model_name_or_path: str, model_max_length: int, max_length: int = 100):
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.model_max_length = model_max_length
self.max_length = max_length
def __call__(self, prompt: str, **kwargs) -> Dict[str, Union[str, int]]:
"""
Resizes the prompt to ensure that the prompt and answer lengths together stay within the model_max_length.
:param prompt: the prompt to be sent to the model.
:param kwargs: Additional keyword arguments passed to the handler.
:return: A dictionary containing the resized prompt and additional information.
"""
resized_prompt = prompt
prompt_length = 0
new_prompt_length = 0
if prompt:
prompt_length = len(self.tokenizer.tokenize(prompt))
if (prompt_length + self.max_length) <= self.model_max_length:
resized_prompt = prompt
new_prompt_length = prompt_length
else:
tokenized_payload = self.tokenizer.tokenize(prompt)
resized_prompt = self.tokenizer.convert_tokens_to_string(
tokenized_payload[: self.model_max_length - self.max_length]
)
new_prompt_length = len(tokenized_payload[: self.model_max_length - self.max_length])
return {
"resized_prompt": resized_prompt,
"prompt_length": prompt_length,
"new_prompt_length": new_prompt_length,
"model_max_length": self.model_max_length,
"max_length": self.max_length,
}


@@ -0,0 +1,180 @@
import unittest
from unittest.mock import patch, MagicMock
import pytest
from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer
@pytest.mark.unit
def test_default_constructor():
"""
Test that the default constructor sets the correct values
"""
layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")
assert layer.api_key == "some_fake_key"
assert layer.max_length == 100
assert layer.model_input_kwargs == {}
assert layer.prompt_handler.model_max_length == 4096
layer = CohereInvocationLayer(model_name_or_path="base", api_key="some_fake_key")
assert layer.api_key == "some_fake_key"
assert layer.max_length == 100
assert layer.model_input_kwargs == {}
assert layer.prompt_handler.model_max_length == 2048
@pytest.mark.unit
def test_constructor_with_model_kwargs():
"""
Test that model_kwargs are correctly set in the constructor
and that model_kwargs_rejected are correctly filtered out
"""
model_kwargs = {"temperature": 0.7, "end_sequences": ["end"], "stream": True}
model_kwargs_rejected = {"fake_param": 0.7, "another_fake_param": 1}
layer = CohereInvocationLayer(
model_name_or_path="command", api_key="some_fake_key", **model_kwargs, **model_kwargs_rejected
)
assert layer.model_input_kwargs == model_kwargs
assert len(model_kwargs_rejected) == 2
@pytest.mark.unit
def test_invoke_with_no_kwargs():
"""
Test that invoke raises an error if no prompt is provided
"""
layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")
with pytest.raises(ValueError) as e:
layer.invoke()
assert e.match("No prompt provided.")
@pytest.mark.unit
def test_invoke_with_stop_words():
"""
Test that stop words are correctly passed from PromptNode through to the request in CohereInvocationLayer
"""
stop_words = ["but", "not", "bye"]
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
# Mock the response; invoke() parses response.text as JSON with a "generations" list
mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
layer.invoke(prompt="Tell me hello", stop_words=stop_words)
assert mock_post.called
# Check that stop_words are passed to _post as the end_sequences parameter
called_args, _ = mock_post.call_args
assert "end_sequences" in called_args[0]
assert called_args[0]["end_sequences"] == stop_words
@pytest.mark.unit
@pytest.mark.parametrize("using_constructor", [True, False])
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param(using_constructor, stream):
"""
Test that the stream parameter is correctly passed from PromptNode through to the request in CohereInvocationLayer
"""
if using_constructor:
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=stream)
else:
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
# Mock the response; invoke() parses response.text as JSON with a "generations" list
mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
if using_constructor:
layer.invoke(prompt="Tell me hello")
else:
layer.invoke(prompt="Tell me hello", stream=stream)
assert mock_post.called
# Check how the stream parameter is passed to _post
called_args, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
# Check if stream is True, then stream is passed as True to _post
if stream:
assert called_kwargs["stream"]
# Check if stream is False, then stream is passed as False to _post
else:
assert not called_kwargs["stream"]
@pytest.mark.unit
@pytest.mark.parametrize("using_constructor", [True, False])
@pytest.mark.parametrize("stream_handler", [DefaultTokenStreamingHandler(), None])
def test_streaming_stream_handler_param(using_constructor, stream_handler):
"""
Test that the stream_handler parameter is correctly passed from PromptNode through to CohereInvocationLayer
"""
if using_constructor:
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream_handler=stream_handler)
else:
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post"
) as mock_post, unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._process_streaming_response"
) as mock_post_stream:
# Mock the response; invoke() parses response.text as JSON with a "generations" list
mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
if using_constructor:
layer.invoke(prompt="Tell me hello")
else:
layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
assert mock_post.called
# Check how the stream parameter is passed to _post
called_args, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
# if stream_handler is used then stream is always True
if stream_handler:
assert called_kwargs["stream"]
# and stream_handler is passed as an instance of TokenStreamingHandler
called_args, called_kwargs = mock_post_stream.call_args
assert "stream_handler" in called_kwargs
assert isinstance(called_kwargs["stream_handler"], TokenStreamingHandler)
# if stream_handler is not used then stream is always False
else:
assert not called_kwargs["stream"]
@pytest.mark.unit
def test_supports():
"""
Test that supports returns True correctly for CohereInvocationLayer
"""
# See command and generate models at https://docs.cohere.com/docs/models
# doesn't support fake model
assert not CohereInvocationLayer.supports("fake_model", api_key="fake_key")
# supports cohere command with api_key
assert CohereInvocationLayer.supports("command", api_key="fake_key")
# supports cohere command-light with api_key
assert CohereInvocationLayer.supports("command-light", api_key="fake_key")
# supports cohere base with api_key
assert CohereInvocationLayer.supports("base", api_key="fake_key")
assert CohereInvocationLayer.supports("base-light", api_key="fake_key")
# doesn't support other models that merely contain "base" as a substring, e.g. google/flan-t5-base
assert not CohereInvocationLayer.supports("google/flan-t5-base")


@@ -0,0 +1,60 @@
import pytest
from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler
@pytest.mark.integration
def test_prompt_handler_basics():
handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
assert callable(handler)
handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20)
assert handler.max_length == 100
@pytest.mark.integration
def test_gpt2_prompt_handler():
# test gpt2 BPE based tokenizer
handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
# test no resize
assert handler("This is a test") == {
"prompt_length": 4,
"resized_prompt": "This is a test",
"max_length": 10,
"model_max_length": 20,
"new_prompt_length": 4,
}
# test resize
assert handler("This is a prompt that will be resized because it is longer than allowed") == {
"prompt_length": 15,
"resized_prompt": "This is a prompt that will be resized because",
"max_length": 10,
"model_max_length": 20,
"new_prompt_length": 10,
}
@pytest.mark.integration
def test_flan_prompt_handler():
# test google/flan-t5-xxl tokenizer
handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
# test no resize
assert handler("This is a test") == {
"prompt_length": 5,
"resized_prompt": "This is a test",
"max_length": 10,
"model_max_length": 20,
"new_prompt_length": 5,
}
# test resize
assert handler("This is a prompt that will be resized because it is longer than allowed") == {
"prompt_length": 17,
"resized_prompt": "This is a prompt that will be re",
"max_length": 10,
"model_max_length": 20,
"new_prompt_length": 10,
}