diff --git a/docs/pydoc/config/prompt-node.yml b/docs/pydoc/config/prompt-node.yml
index 4524db0df..bd67c933f 100644
--- a/docs/pydoc/config/prompt-node.yml
+++ b/docs/pydoc/config/prompt-node.yml
@@ -1,7 +1,7 @@
 loaders:
   - type: python
     search_path: [../../../haystack/nodes/prompt]
-    modules:
+    modules:
       [
         "prompt_node",
         "prompt_model",
diff --git a/haystack/errors.py b/haystack/errors.py
index fc52db51c..ef00c746d 100644
--- a/haystack/errors.py
+++ b/haystack/errors.py
@@ -230,3 +230,10 @@ class AnthropicUnauthorizedError(AnthropicError):
 
     def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
         super().__init__(message=message, status_code=401, send_message_in_event=send_message_in_event)
+
+
+class CohereInferenceLimitError(CohereError):
+    """Exception for issues that occur in the Cohere inference node due to rate limiting"""
+
+    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
+        super().__init__(message=message, status_code=429, send_message_in_event=send_message_in_event)
diff --git a/haystack/nodes/prompt/invocation_layer/__init__.py b/haystack/nodes/prompt/invocation_layer/__init__.py
index b3ecaa447..2036f1b1d 100644
--- a/haystack/nodes/prompt/invocation_layer/__init__.py
+++ b/haystack/nodes/prompt/invocation_layer/__init__.py
@@ -7,3 +7,4 @@ from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocatio
 from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
 from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicClaudeInvocationLayer
+from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
diff --git a/haystack/nodes/prompt/invocation_layer/cohere.py b/haystack/nodes/prompt/invocation_layer/cohere.py
new file mode 100644
index 000000000..a15c9b9f2
--- /dev/null
+++ b/haystack/nodes/prompt/invocation_layer/cohere.py
@@ -0,0 +1,220 @@
+import json
+import os
+from typing import Optional, Dict, Union, List, Any
+import logging
+
+import requests
+
+from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
+from haystack.errors import CohereInferenceLimitError, CohereUnauthorizedError, CohereError
+from haystack.nodes.prompt.invocation_layer import (
+    PromptModelInvocationLayer,
+    TokenStreamingHandler,
+    DefaultTokenStreamingHandler,
+)
+from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler
+from haystack.utils.requests import request_with_retry
+
+logger = logging.getLogger(__name__)
+TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
+RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
+
+
+class CohereInvocationLayer(PromptModelInvocationLayer):
+    """
+    PromptModelInvocationLayer implementation for Cohere's command models. Invocations are made using REST API.
+    """
+
+    def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[int] = 100, **kwargs):
+        """
+        Creates an instance of CohereInvocationLayer for the specified Cohere model
+
+        :param api_key: Cohere API key
+        :param model_name_or_path: Cohere model name
+        :param max_length: The maximum length of the output text.
+        """
+        super().__init__(model_name_or_path)
+        valid_api_key = isinstance(api_key, str) and api_key
+        if not valid_api_key:
+            raise ValueError(
+                f"api_key {api_key} must be a valid Cohere token. "
" + f"Your token is available in your Cohere settings page." + ) + valid_model_name_or_path = isinstance(model_name_or_path, str) and model_name_or_path + if not valid_model_name_or_path: + raise ValueError(f"model_name_or_path {model_name_or_path} must be a valid Cohere model name") + self.api_key = api_key + self.max_length = max_length + + # See https://docs.cohere.com/reference/generate + # for a list of supported parameters + self.model_input_kwargs = { + key: kwargs[key] + for key in [ + "end_sequences", + "frequency_penalty", + "k", + "logit_bias", + "max_tokens", + "model", + "num_generations", + "p", + "presence_penalty", + "return_likelihoods", + "stream", + "stream_handler", + "temperature", + "truncate", + ] + if key in kwargs + } + # cohere uses BPE tokenizer + # the tokenization lengths are very close to gpt2, in our experiments the differences were minimal + # See model info at https://docs.cohere.com/docs/models + model_max_length = 4096 if "command" in model_name_or_path else 2048 + self.prompt_handler = DefaultPromptHandler( + model_name_or_path="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100 + ) + + @property + def url(self) -> str: + return "https://api.cohere.ai/v1/generate" + + @property + def headers(self) -> Dict[str, str]: + return { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "Request-Source": "python-sdk", + } + + def invoke(self, *args, **kwargs): + """ + Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation. + :return: The responses are being returned. + """ + prompt = kwargs.get("prompt") + if not prompt: + raise ValueError( + f"No prompt provided. Model {self.model_name_or_path} requires prompt." + f"Make sure to provide prompt in kwargs." 
+            )
+        stop_words = kwargs.pop("stop_words", None)
+        # merge into a copy so that self.model_input_kwargs is not mutated across invoke() calls
+        kwargs_with_defaults = {**self.model_input_kwargs, **kwargs}
+
+        # either stream is True (will use default handler) or stream_handler is provided
+        stream = (
+            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
+        )
+
+        # see https://docs.cohere.com/reference/generate
+        params = {
+            "end_sequences": kwargs_with_defaults.get("end_sequences", stop_words),
+            "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", None),
+            "k": kwargs_with_defaults.get("k", None),
+            "max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
+            "model": kwargs_with_defaults.get("model", self.model_name_or_path),
+            "num_generations": kwargs_with_defaults.get("num_generations", None),
+            "p": kwargs_with_defaults.get("p", None),
+            "presence_penalty": kwargs_with_defaults.get("presence_penalty", None),
+            "prompt": prompt,
+            "return_likelihoods": kwargs_with_defaults.get("return_likelihoods", None),
+            "stream": stream,
+            "temperature": kwargs_with_defaults.get("temperature", None),
+            "truncate": kwargs_with_defaults.get("truncate", None),
+        }
+        response = self._post(params, stream=stream)
+        if not stream:
+            output = json.loads(response.text)
+            generated_texts = [o["text"] for o in output["generations"] if "text" in o]
+        else:
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            generated_texts = self._process_streaming_response(response=response, stream_handler=handler)
+        return generated_texts
+
+    def _process_streaming_response(self, response, stream_handler: TokenStreamingHandler):
+        # sseclient doesn't work with Cohere streaming API
+        # let's do it manually
+        tokens = []
+        for line in response.iter_lines():
+            if line:
+                streaming_item = json.loads(line)
+                text = streaming_item.get("text")
+                if text:
+                    tokens.append(stream_handler(text))
+        return ["".join(tokens)]  # return a list of strings just like non-streaming
+
+    def _post(
+        self,
+        data: Dict[str, Any],
+        stream: bool = False,
+        attempts: int = RETRIES,
+        status_codes: Optional[List[int]] = None,
+        timeout: float = TIMEOUT,
+        **kwargs,
+    ) -> requests.Response:
+        """
+        Post data to the Cohere inference model. It takes in a prompt and returns a list of responses using a REST
+        invocation.
+        :param data: The data to be sent to the model.
+        :param stream: Whether to stream the response.
+        :param attempts: The number of attempts to make.
+        :param status_codes: The status codes to retry on.
+        :param timeout: The timeout for the request.
+        :return: The response from the model as a requests.Response object.
+        """
+        response: requests.Response
+        if status_codes is None:
+            status_codes = [429]
+        try:
+            response = request_with_retry(
+                method="POST",
+                status_codes=status_codes,
+                attempts=attempts,
+                url=self.url,
+                headers=self.headers,
+                json=data,
+                timeout=timeout,
+                stream=stream,
+            )
+        except requests.HTTPError as err:
+            res = err.response
+            if res.status_code == 429:
+                raise CohereInferenceLimitError(f"API rate limit exceeded: {res.text}")
+            if res.status_code == 401:
+                raise CohereUnauthorizedError(f"API key is invalid: {res.text}")
+
+            raise CohereError(
+                f"Cohere model returned an error.\nStatus code: {res.status_code}\nResponse body: {res.text}",
+                status_code=res.status_code,
+            )
+        return response
+
+    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
+        # the prompt for this model will be of the type str
+        resize_info = self.prompt_handler(prompt)  # type: ignore
+        if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
+            logger.warning(
+                "The prompt has been truncated from %s tokens to %s tokens so that the prompt length and "
+                "answer length (%s tokens) fit within the max token limit (%s tokens). "
+                "Shorten the prompt to prevent it from being cut off",
+                resize_info["prompt_length"],
+                max(0, resize_info["model_max_length"] - resize_info["max_length"]),  # type: ignore
+                resize_info["max_length"],
+                resize_info["model_max_length"],
+            )
+        return prompt
+
+    @classmethod
+    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
+        """
+        Ensures CohereInvocationLayer is selected only when Cohere models are specified in
+        the model name.
+        """
+        is_inference_api = "api_key" in kwargs
+        return (
+            model_name_or_path is not None
+            and is_inference_api
+            and any(token == model_name_or_path for token in ["command", "command-light", "base", "base-light"])
+        )
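For reference, a minimal usage sketch of the layer added above (not part of the patch; the API key and prompt strings are placeholders, everything else follows the code in this diff):

from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer

# "command" is one of the model names accepted by CohereInvocationLayer.supports()
layer = CohereInvocationLayer(api_key="<COHERE_API_KEY>", model_name_or_path="command", max_tokens=50)

# invoke() POSTs to https://api.cohere.ai/v1/generate and returns a list of generated strings
responses = layer.invoke(prompt="Write a one-line greeting.")
print(responses[0])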
+ """ + response: requests.Response + if status_codes is None: + status_codes = [429] + try: + response = request_with_retry( + method="POST", + status_codes=status_codes, + attempts=attempts, + url=self.url, + headers=self.headers, + json=data, + timeout=timeout, + stream=stream, + ) + except requests.HTTPError as err: + res = err.response + if res.status_code == 429: + raise CohereInferenceLimitError(f"API rate limit exceeded: {res.text}") + if res.status_code == 401: + raise CohereUnauthorizedError(f"API key is invalid: {res.text}") + + raise CohereError( + f"Cohere model returned an error.\nStatus code: {res.status_code}\nResponse body: {res.text}", + status_code=res.status_code, + ) + return response + + def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]: + # the prompt for this model will be of the type str + resize_info = self.prompt_handler(prompt) # type: ignore + if resize_info["prompt_length"] != resize_info["new_prompt_length"]: + logger.warning( + "The prompt has been truncated from %s tokens to %s tokens so that the prompt length and " + "answer length (%s tokens) fit within the max token limit (%s tokens). " + "Shorten the prompt to prevent it from being cut off", + resize_info["prompt_length"], + max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore + resize_info["max_length"], + resize_info["model_max_length"], + ) + return prompt + + @classmethod + def supports(cls, model_name_or_path: str, **kwargs) -> bool: + """ + Ensures CohereInvocationLayer is selected only when Cohere models are specified in + the model name. + """ + is_inference_api = "api_key" in kwargs + return ( + model_name_or_path is not None + and is_inference_api + and any(token == model_name_or_path for token in ["command", "command-light", "base", "base-light"]) + ) diff --git a/haystack/nodes/prompt/invocation_layer/handlers.py b/haystack/nodes/prompt/invocation_layer/handlers.py index aa4405509..0872fc53e 100644 --- a/haystack/nodes/prompt/invocation_layer/handlers.py +++ b/haystack/nodes/prompt/invocation_layer/handlers.py @@ -1,7 +1,7 @@ from abc import abstractmethod, ABC -from typing import Union +from typing import Union, Dict -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer, AutoTokenizer class TokenStreamingHandler(ABC): @@ -46,3 +46,47 @@ class HFTokenStreamingHandler(TextStreamer): def on_finalized_text(self, token: str, stream_end: bool = False): token_to_send = token + "\n" if stream_end else token self.token_handler(token_received=token_to_send, **{}) + + +class DefaultPromptHandler: + """ + DefaultPromptHandler resizes the prompt to ensure that the prompt and answer token lengths together + are within the model_max_length. + """ + + def __init__(self, model_name_or_path: str, model_max_length: int, max_length: int = 100): + self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + self.model_max_length = model_max_length + self.max_length = max_length + + def __call__(self, prompt: str, **kwargs) -> Dict[str, Union[str, int]]: + """ + Resizes the prompt to ensure that the prompt and answer is within the model_max_length + + :param prompt: the prompt to be sent to the model. + :param kwargs: Additional keyword arguments passed to the handler. + :return: A dictionary containing the resized prompt and additional information. 
+ """ + resized_prompt = prompt + prompt_length = 0 + new_prompt_length = 0 + + if prompt: + prompt_length = len(self.tokenizer.tokenize(prompt)) + if (prompt_length + self.max_length) <= self.model_max_length: + resized_prompt = prompt + new_prompt_length = prompt_length + else: + tokenized_payload = self.tokenizer.tokenize(prompt) + resized_prompt = self.tokenizer.convert_tokens_to_string( + tokenized_payload[: self.model_max_length - self.max_length] + ) + new_prompt_length = len(tokenized_payload[: self.model_max_length - self.max_length]) + + return { + "resized_prompt": resized_prompt, + "prompt_length": prompt_length, + "new_prompt_length": new_prompt_length, + "model_max_length": self.model_max_length, + "max_length": self.max_length, + } diff --git a/test/prompt/invocation_layer/test_cohere.py b/test/prompt/invocation_layer/test_cohere.py new file mode 100644 index 000000000..6e481c755 --- /dev/null +++ b/test/prompt/invocation_layer/test_cohere.py @@ -0,0 +1,180 @@ +import unittest +from unittest.mock import patch, MagicMock + +import pytest + +from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler +from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer + + +@pytest.mark.unit +def test_default_constructor(): + """ + Test that the default constructor sets the correct values + """ + + layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key") + + assert layer.api_key == "some_fake_key" + assert layer.max_length == 100 + assert layer.model_input_kwargs == {} + assert layer.prompt_handler.model_max_length == 4096 + + layer = CohereInvocationLayer(model_name_or_path="base", api_key="some_fake_key") + assert layer.api_key == "some_fake_key" + assert layer.max_length == 100 + assert layer.model_input_kwargs == {} + assert layer.prompt_handler.model_max_length == 2048 + + +@pytest.mark.unit +def test_constructor_with_model_kwargs(): + """ + Test that model_kwargs are correctly set in the constructor + and that model_kwargs_rejected are correctly filtered out + """ + model_kwargs = {"temperature": 0.7, "end_sequences": ["end"], "stream": True} + model_kwargs_rejected = {"fake_param": 0.7, "another_fake_param": 1} + layer = CohereInvocationLayer( + model_name_or_path="command", api_key="some_fake_key", **model_kwargs, **model_kwargs_rejected + ) + assert layer.model_input_kwargs == model_kwargs + assert len(model_kwargs_rejected) == 2 + + +@pytest.mark.unit +def test_invoke_with_no_kwargs(): + """ + Test that invoke raises an error if no prompt is provided + """ + layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key") + with pytest.raises(ValueError) as e: + layer.invoke() + assert e.match("No prompt provided.") + + +@pytest.mark.unit +def test_invoke_with_stop_words(): + """ + Test stop words are correctly passed from PromptNode to wire in CohereInvocationLayer + """ + stop_words = ["but", "not", "bye"] + layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key") + with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post: + # Mock the response, need to return a list of dicts + mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}') + + layer.invoke(prompt="Tell me hello", stop_words=stop_words) + + assert mock_post.called + + # Check if stop_words are passed to _post as stop parameter + called_args, _ = mock_post.call_args + assert "end_sequences" in called_args[0] + 
+        assert called_args[0]["end_sequences"] == stop_words
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("using_constructor", [True, False])
+@pytest.mark.parametrize("stream", [True, False])
+def test_streaming_stream_param(using_constructor, stream):
+    """
+    Test that the stream parameter is correctly passed through CohereInvocationLayer to the request
+    """
+    if using_constructor:
+        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=stream)
+    else:
+        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response, need to return a list of dicts
+        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
+
+        if using_constructor:
+            layer.invoke(prompt="Tell me hello")
+        else:
+            layer.invoke(prompt="Tell me hello", stream=stream)
+
+        assert mock_post.called
+
+        # Check the arguments that were passed to _post
+        called_args, called_kwargs = mock_post.call_args
+
+        # stream is always passed to _post
+        assert "stream" in called_kwargs
+
+        # Check if stream is True, then stream is passed as True to _post
+        if stream:
+            assert called_kwargs["stream"]
+        # Check if stream is False, then stream is passed as False to _post
+        else:
+            assert not called_kwargs["stream"]
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("using_constructor", [True, False])
+@pytest.mark.parametrize("stream_handler", [DefaultTokenStreamingHandler(), None])
+def test_streaming_stream_handler_param(using_constructor, stream_handler):
+    """
+    Test that the stream_handler parameter is correctly passed through CohereInvocationLayer
+    """
+    if using_constructor:
+        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream_handler=stream_handler)
+    else:
+        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+
+    with unittest.mock.patch(
+        "haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post"
+    ) as mock_post, unittest.mock.patch(
+        "haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._process_streaming_response"
+    ) as mock_post_stream:
+        # Mock the response, need to return a list of dicts
+        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
+
+        if using_constructor:
+            layer.invoke(prompt="Tell me hello")
+        else:
+            layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
+
+        assert mock_post.called
+
+        # Check the arguments that were passed to _post
+        called_args, called_kwargs = mock_post.call_args
+
+        # stream is always passed to _post
+        assert "stream" in called_kwargs
+
+        # if stream_handler is used then stream is always True
+        if stream_handler:
+            assert called_kwargs["stream"]
+            # and stream_handler is passed as an instance of TokenStreamingHandler
+            called_args, called_kwargs = mock_post_stream.call_args
+            assert "stream_handler" in called_kwargs
+            assert isinstance(called_kwargs["stream_handler"], TokenStreamingHandler)
+        # if stream_handler is not used then stream is always False
+        else:
+            assert not called_kwargs["stream"]
+
+
+@pytest.mark.unit
+def test_supports():
+    """
+    Test that supports() returns True only for supported Cohere models
+    """
+    # See command and generate models at https://docs.cohere.com/docs/models
+    # doesn't support fake model
+    assert not CohereInvocationLayer.supports("fake_model", api_key="fake_key")
+
+    # supports cohere command with api_key
+    assert CohereInvocationLayer.supports("command", api_key="fake_key")
+
+    # supports cohere command-light with api_key
+    assert CohereInvocationLayer.supports("command-light", api_key="fake_key")
+
+    # supports cohere base with api_key
+    assert CohereInvocationLayer.supports("base", api_key="fake_key")
+
+    assert CohereInvocationLayer.supports("base-light", api_key="fake_key")
+
+    # doesn't support other models that have base substring only i.e. google/flan-t5-base
+    assert not CohereInvocationLayer.supports("google/flan-t5-base")
diff --git a/test/prompt/test_handlers.py b/test/prompt/test_handlers.py
new file mode 100644
index 000000000..26173d871
--- /dev/null
+++ b/test/prompt/test_handlers.py
@@ -0,0 +1,60 @@
+import pytest
+
+from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler
+
+
+@pytest.mark.integration
+def test_prompt_handler_basics():
+    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
+    assert callable(handler)
+
+    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20)
+    assert handler.max_length == 100
+
+
+@pytest.mark.integration
+def test_gpt2_prompt_handler():
+    # test gpt2 BPE based tokenizer
+    handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
+
+    # test no resize
+    assert handler("This is a test") == {
+        "prompt_length": 4,
+        "resized_prompt": "This is a test",
+        "max_length": 10,
+        "model_max_length": 20,
+        "new_prompt_length": 4,
+    }
+
+    # test resize
+    assert handler("This is a prompt that will be resized because it is longer than allowed") == {
+        "prompt_length": 15,
+        "resized_prompt": "This is a prompt that will be resized because",
+        "max_length": 10,
+        "model_max_length": 20,
+        "new_prompt_length": 10,
+    }
+
+
+@pytest.mark.integration
+def test_flan_prompt_handler():
+    # test google/flan-t5-xxl tokenizer
+    handler = DefaultPromptHandler(model_name_or_path="google/flan-t5-xxl", model_max_length=20, max_length=10)
+
+    # test no resize
+    assert handler("This is a test") == {
+        "prompt_length": 5,
+        "resized_prompt": "This is a test",
+        "max_length": 10,
+        "model_max_length": 20,
+        "new_prompt_length": 5,
+    }
+
+    # test resize
+    assert handler("This is a prompt that will be resized because it is longer than allowed") == {
+        "prompt_length": 17,
+        "resized_prompt": "This is a prompt that will be re",
+        "max_length": 10,
+        "model_max_length": 20,
+        "new_prompt_length": 10,
+    }
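Assuming the repository's usual pytest markers are registered, the tests added in this diff can be run selectively, for example with pytest -m unit test/prompt/invocation_layer/test_cohere.py for the mocked unit tests and pytest -m integration test/prompt/test_handlers.py for the tokenizer-dependent integration tests.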