Mirror of https://github.com/deepset-ai/haystack.git
Fix CohereInvocationLayer _ensure_token_limit not returning resized prompt (#4978)
parent 00bee17b79
commit 524d2cba36
@@ -155,7 +155,7 @@ class CohereInvocationLayer(PromptModelInvocationLayer):
         **kwargs,
     ) -> requests.Response:
         """
-        Post data to the HF inference model. It takes in a prompt and returns a list of responses using a REST
+        Post data to the Cohere inference model. It takes in a prompt and returns a list of responses using a REST
         invocation.
         :param data: The data to be sent to the model.
         :param stream: Whether to stream the response.
@@ -192,19 +192,21 @@ class CohereInvocationLayer(PromptModelInvocationLayer):
         return response
 
     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
-        # the prompt for this model will be of the type str
-        resize_info = self.prompt_handler(prompt)  # type: ignore
+        if isinstance(prompt, List):
+            raise ValueError("Cohere invocation layer doesn't support a dictionary as prompt")
+
+        resize_info = self.prompt_handler(prompt)
         if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
             logger.warning(
                 "The prompt has been truncated from %s tokens to %s tokens so that the prompt length and "
                 "answer length (%s tokens) fit within the max token limit (%s tokens). "
-                "Shorten the prompt to prevent it from being cut off",
+                "Reduce the length of the prompt to prevent it from being cut off.",
                 resize_info["prompt_length"],
                 max(0, resize_info["model_max_length"] - resize_info["max_length"]),  # type: ignore
                 resize_info["max_length"],
                 resize_info["model_max_length"],
             )
-        return prompt
+        return str(resize_info["resized_prompt"])
 
     @classmethod
     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
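For context on what the behaviour change looks like from the caller's side, here is a minimal, self-contained sketch. The names below (fake_prompt_handler, the whitespace tokenizer, the hard-coded limits) are illustrative stand-ins, not Haystack's actual prompt handler: before the patch, _ensure_token_limit computed the resize info and logged the warning but still returned the original prompt; after the patch, it returns the resized text the handler already produced.

# Sketch only; fake_prompt_handler is a hypothetical stand-in for self.prompt_handler.
from typing import Any, Dict


def fake_prompt_handler(prompt: str, tokens_allowed: int = 6) -> Dict[str, Any]:
    # Counts whitespace tokens and truncates to tokens_allowed.
    tokens = prompt.split()
    return {
        "prompt_length": len(tokens),
        "new_prompt_length": min(len(tokens), tokens_allowed),
        "resized_prompt": " ".join(tokens[:tokens_allowed]),
        "max_length": 4090,
        "model_max_length": 4096,
    }


def ensure_token_limit_before(prompt: str) -> str:
    resize_info = fake_prompt_handler(prompt)
    return prompt  # bug: the truncated text in resize_info is discarded


def ensure_token_limit_after(prompt: str) -> str:
    resize_info = fake_prompt_handler(prompt)
    return str(resize_info["resized_prompt"])  # fix: hand back the resized prompt


long_prompt = "This is a very very very very very much longer prompt"  # 11 tokens
assert ensure_token_limit_before(long_prompt) == long_prompt                # still too long
assert ensure_token_limit_after(long_prompt) == "This is a very very very"  # cut to 6 tokens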
@@ -178,3 +178,42 @@ def test_supports():
 
     # doesn't support other models that have base substring only i.e. google/flan-t5-base
     assert not CohereInvocationLayer.supports("google/flan-t5-base")
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_fails_if_called_with_list():
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")
+    with pytest.raises(ValueError):
+        layer._ensure_token_limit(prompt=[])
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_with_small_max_length(caplog):
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=10)
+    res = layer._ensure_token_limit(prompt="Short prompt")
+
+    assert res == "Short prompt"
+    assert not caplog.records
+
+    res = layer._ensure_token_limit(prompt="This is a very very very very very much longer prompt")
+    assert res == "This is a very very very very very much longer prompt"
+    assert not caplog.records
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_with_huge_max_length(caplog):
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=4090)
+    res = layer._ensure_token_limit(prompt="Short prompt")
+
+    assert res == "Short prompt"
+    assert not caplog.records
+
+    res = layer._ensure_token_limit(prompt="This is a very very very very very much longer prompt")
+    assert res == "This is a very very very"
+    assert len(caplog.records) == 1
+    expected_message_log = (
+        "The prompt has been truncated from 11 tokens to 6 tokens so that the prompt length and "
+        "answer length (4090 tokens) fit within the max token limit (4096 tokens). "
+        "Reduce the length of the prompt to prevent it from being cut off."
+    )
+    assert caplog.records[0].message == expected_message_log
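The numbers in expected_message_log follow directly from the hunk above: with max_length=4090 and the 4096-token limit named in the warning, max(0, 4096 - 4090) leaves 6 tokens for the prompt, and the test prompt is counted as 11 tokens. A quick sanity check of that arithmetic (the whitespace split below is only an approximation that happens to match how this particular prompt is counted):

# Sanity check of the values asserted in test_ensure_token_limit_with_huge_max_length.
model_max_length = 4096  # max token limit reported in the expected log message
max_length = 4090        # answer budget passed to CohereInvocationLayer
assert max(0, model_max_length - max_length) == 6  # tokens left for the prompt

prompt = "This is a very very very very very much longer prompt"
assert len(prompt.split()) == 11  # original prompt length reported in the log
assert " ".join(prompt.split()[:6]) == "This is a very very very"  # truncated result the test expects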