diff --git a/haystack/nodes/prompt/invocation_layer/cohere.py b/haystack/nodes/prompt/invocation_layer/cohere.py
index 1f8e3d100..374d41894 100644
--- a/haystack/nodes/prompt/invocation_layer/cohere.py
+++ b/haystack/nodes/prompt/invocation_layer/cohere.py
@@ -155,7 +155,7 @@ class CohereInvocationLayer(PromptModelInvocationLayer):
         **kwargs,
     ) -> requests.Response:
         """
-        Post data to the HF inference model. It takes in a prompt and returns a list of responses using a REST
+        Post data to the Cohere inference model. It takes in a prompt and returns a list of responses using a REST
         invocation.
         :param data: The data to be sent to the model.
         :param stream: Whether to stream the response.
@@ -192,19 +192,21 @@ class CohereInvocationLayer(PromptModelInvocationLayer):
         return response
 
     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
-        # the prompt for this model will be of the type str
-        resize_info = self.prompt_handler(prompt)  # type: ignore
+        if isinstance(prompt, List):
+            raise ValueError("Cohere invocation layer doesn't support a dictionary as prompt")
+
+        resize_info = self.prompt_handler(prompt)
         if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
             logger.warning(
                 "The prompt has been truncated from %s tokens to %s tokens so that the prompt length and "
                 "answer length (%s tokens) fit within the max token limit (%s tokens). "
-                "Shorten the prompt to prevent it from being cut off",
+                "Reduce the length of the prompt to prevent it from being cut off.",
                 resize_info["prompt_length"],
                 max(0, resize_info["model_max_length"] - resize_info["max_length"]),  # type: ignore
                 resize_info["max_length"],
                 resize_info["model_max_length"],
             )
-        return prompt
+        return str(resize_info["resized_prompt"])
 
     @classmethod
     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
diff --git a/test/prompt/invocation_layer/test_cohere.py b/test/prompt/invocation_layer/test_cohere.py
index 6e481c755..3beb1176d 100644
--- a/test/prompt/invocation_layer/test_cohere.py
+++ b/test/prompt/invocation_layer/test_cohere.py
@@ -178,3 +178,42 @@ def test_supports():
     # doesn't support other models that have base substring only i.e. google/flan-t5-base
     assert not CohereInvocationLayer.supports("google/flan-t5-base")
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_fails_if_called_with_list():
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")
+    with pytest.raises(ValueError):
+        layer._ensure_token_limit(prompt=[])
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_with_small_max_length(caplog):
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=10)
+    res = layer._ensure_token_limit(prompt="Short prompt")
+
+    assert res == "Short prompt"
+    assert not caplog.records
+
+    res = layer._ensure_token_limit(prompt="This is a very very very very very much longer prompt")
+    assert res == "This is a very very very very very much longer prompt"
+    assert not caplog.records
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_with_huge_max_length(caplog):
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=4090)
+    res = layer._ensure_token_limit(prompt="Short prompt")
+
+    assert res == "Short prompt"
+    assert not caplog.records
+
+    res = layer._ensure_token_limit(prompt="This is a very very very very very much longer prompt")
+    assert res == "This is a very very very"
+    assert len(caplog.records) == 1
+    expected_message_log = (
+        "The prompt has been truncated from 11 tokens to 6 tokens so that the prompt length and "
+        "answer length (4090 tokens) fit within the max token limit (4096 tokens). "
+        "Reduce the length of the prompt to prevent it from being cut off."
+    )
+    assert caplog.records[0].message == expected_message_log
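A minimal sketch of the handler contract that the new `return str(resize_info["resized_prompt"])` relies on. This is not the actual haystack prompt handler; the `toy_prompt_handler` name and the whitespace tokenization below are assumptions used only to illustrate which keys `_ensure_token_limit` reads and why an 11-token prompt comes back as 6 tokens when max_length=4090 and model_max_length=4096:

    from typing import Dict, Union

    def toy_prompt_handler(
        prompt: str, max_length: int = 4090, model_max_length: int = 4096
    ) -> Dict[str, Union[str, int]]:
        # Naive whitespace "tokenizer"; the real handler uses the model's tokenizer.
        tokens = prompt.split()
        # Tokens left for the prompt after reserving max_length tokens for the answer.
        budget = max(0, model_max_length - max_length)
        resized = tokens[:budget]
        return {
            "prompt_length": len(tokens),
            "new_prompt_length": len(resized),
            "resized_prompt": " ".join(resized),
            "max_length": max_length,
            "model_max_length": model_max_length,
        }

    # 11 tokens in, 6 tokens back out, matching the expectations in
    # test_ensure_token_limit_with_huge_max_length above.
    info = toy_prompt_handler("This is a very very very very very much longer prompt")
    assert info["resized_prompt"] == "This is a very very very"
    assert (info["prompt_length"], info["new_prompt_length"]) == (11, 6)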