Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-10-27 07:48:43 +00:00)
Fix CohereInvocationLayer _ensure_token_limit not returning resized prompt (#4978)
commit 524d2cba36
parent 00bee17b79
@@ -155,7 +155,7 @@ class CohereInvocationLayer(PromptModelInvocationLayer):
         **kwargs,
     ) -> requests.Response:
         """
-        Post data to the HF inference model. It takes in a prompt and returns a list of responses using a REST
+        Post data to the Cohere inference model. It takes in a prompt and returns a list of responses using a REST
         invocation.
         :param data: The data to be sent to the model.
         :param stream: Whether to stream the response.
@@ -192,19 +192,21 @@ class CohereInvocationLayer(PromptModelInvocationLayer):
         return response

     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
-        # the prompt for this model will be of the type str
-        resize_info = self.prompt_handler(prompt)  # type: ignore
+        if isinstance(prompt, List):
+            raise ValueError("Cohere invocation layer doesn't support a dictionary as prompt")
+
+        resize_info = self.prompt_handler(prompt)
         if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
             logger.warning(
                 "The prompt has been truncated from %s tokens to %s tokens so that the prompt length and "
                 "answer length (%s tokens) fit within the max token limit (%s tokens). "
-                "Shorten the prompt to prevent it from being cut off",
+                "Reduce the length of the prompt to prevent it from being cut off.",
                 resize_info["prompt_length"],
                 max(0, resize_info["model_max_length"] - resize_info["max_length"]),  # type: ignore
                 resize_info["max_length"],
                 resize_info["model_max_length"],
             )
-        return prompt
+        return str(resize_info["resized_prompt"])

     @classmethod
     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
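For context, the fix hinges on the dictionary returned by the prompt handler: the old code logged the truncation warning but still returned the original prompt, discarding the resized text the handler had already computed. Below is a minimal sketch of that dictionary and the corrected return path; the dict shape is an assumption drawn only from the keys used in the diff above, and the values mirror the "huge max_length" test case further down.

# Sketch only: resize_info as the diff consumes it, with illustrative values
# taken from the test case below (not from a real tokenizer run).
resize_info = {
    "prompt_length": 11,        # tokens in the original prompt
    "new_prompt_length": 6,     # tokens that still fit after truncation
    "resized_prompt": "This is a very very very",
    "max_length": 4090,         # tokens reserved for the generated answer
    "model_max_length": 4096,   # the model's total context limit
}

# Before the fix: the warning fired, but the original prompt was returned.
# After the fix: the resized prompt is returned instead.
truncated_prompt = str(resize_info["resized_prompt"])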
@@ -178,3 +178,42 @@ def test_supports():

     # doesn't support other models that have base substring only i.e. google/flan-t5-base
     assert not CohereInvocationLayer.supports("google/flan-t5-base")
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_fails_if_called_with_list():
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")
+    with pytest.raises(ValueError):
+        layer._ensure_token_limit(prompt=[])
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_with_small_max_length(caplog):
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=10)
+    res = layer._ensure_token_limit(prompt="Short prompt")
+
+    assert res == "Short prompt"
+    assert not caplog.records
+
+    res = layer._ensure_token_limit(prompt="This is a very very very very very much longer prompt")
+    assert res == "This is a very very very very very much longer prompt"
+    assert not caplog.records
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_with_huge_max_length(caplog):
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=4090)
+    res = layer._ensure_token_limit(prompt="Short prompt")
+
+    assert res == "Short prompt"
+    assert not caplog.records
+
+    res = layer._ensure_token_limit(prompt="This is a very very very very very much longer prompt")
+    assert res == "This is a very very very"
+    assert len(caplog.records) == 1
+    expected_message_log = (
+        "The prompt has been truncated from 11 tokens to 6 tokens so that the prompt length and "
+        "answer length (4090 tokens) fit within the max token limit (4096 tokens). "
+        "Reduce the length of the prompt to prevent it from being cut off."
+    )
+    assert caplog.records[0].message == expected_message_log
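Why the long prompt collapses to "This is a very very very" in the last test: with max_length=4090 tokens reserved for the answer against the model's 4096-token context window, only 4096 - 4090 = 6 tokens are left for the prompt, so the 11-token prompt keeps its first six tokens. A back-of-the-envelope check, assuming purely for illustration that each word counts as one token (which is what the expected log message implies for this prompt):

model_max_length = 4096   # context window used in the expected log message
max_length = 4090         # tokens reserved for the generated answer
prompt_budget = max(0, model_max_length - max_length)   # -> 6

prompt = "This is a very very very very very much longer prompt"
words = prompt.split()                     # 11 "tokens" under the one-word-per-token assumption
resized = " ".join(words[:prompt_budget])
print(resized)                             # -> "This is a very very very"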