Fix CohereInvocationLayer _ensure_token_limit not returning resized prompt (#4978)
Silvano Cerza 2023-05-23 17:58:01 +02:00 committed by GitHub
parent 00bee17b79
commit 524d2cba36
2 changed files with 46 additions and 5 deletions


@@ -155,7 +155,7 @@ class CohereInvocationLayer(PromptModelInvocationLayer):
         **kwargs,
     ) -> requests.Response:
         """
-        Post data to the HF inference model. It takes in a prompt and returns a list of responses using a REST
+        Post data to the Cohere inference model. It takes in a prompt and returns a list of responses using a REST
         invocation.
         :param data: The data to be sent to the model.
         :param stream: Whether to stream the response.
@@ -192,19 +192,21 @@ class CohereInvocationLayer(PromptModelInvocationLayer):
         return response
 
     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         # the prompt for this model will be of the type str
-        resize_info = self.prompt_handler(prompt)  # type: ignore
+        if isinstance(prompt, List):
+            raise ValueError("Cohere invocation layer doesn't support a dictionary as prompt")
+        resize_info = self.prompt_handler(prompt)
         if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
             logger.warning(
                 "The prompt has been truncated from %s tokens to %s tokens so that the prompt length and "
                 "answer length (%s tokens) fit within the max token limit (%s tokens). "
-                "Shorten the prompt to prevent it from being cut off",
+                "Reduce the length of the prompt to prevent it from being cut off.",
                 resize_info["prompt_length"],
                 max(0, resize_info["model_max_length"] - resize_info["max_length"]),  # type: ignore
                 resize_info["max_length"],
                 resize_info["model_max_length"],
             )
-        return prompt
+        return str(resize_info["resized_prompt"])
 
     @classmethod
     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
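
Not part of the commit itself, but as an illustration of the fixed behavior, the sketch below mirrors the unit tests added in the second file. The import path is an assumption based on Haystack's package layout around this release, and the values (the "command" model name, the fake API key, max_length=4090, and the expected truncated text) are taken from the tests; `_ensure_token_limit` only runs the prompt handler, so no real Cohere API call is involved here.

```python
# Illustrative sketch (not part of this commit); the import path is assumed from
# Haystack's layout at the time and may differ in other versions.
from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer

# Constructed as in the unit tests below; max_length reserves tokens for the answer,
# leaving only a small budget for the prompt itself.
layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=4090)

# Previously the original prompt was returned unchanged even after the truncation warning;
# with this fix the resized prompt from the prompt handler is returned instead.
resized = layer._ensure_token_limit(prompt="This is a very very very very very much longer prompt")
assert resized == "This is a very very very"  # matches the expectation in the new test

# Non-string prompts are now rejected explicitly.
try:
    layer._ensure_token_limit(prompt=[])
except ValueError:
    pass  # the Cohere layer only accepts string prompts
```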


@@ -178,3 +178,42 @@ def test_supports():
     # doesn't support other models that have base substring only i.e. google/flan-t5-base
     assert not CohereInvocationLayer.supports("google/flan-t5-base")
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_fails_if_called_with_list():
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")
+    with pytest.raises(ValueError):
+        layer._ensure_token_limit(prompt=[])
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_with_small_max_length(caplog):
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=10)
+    res = layer._ensure_token_limit(prompt="Short prompt")
+    assert res == "Short prompt"
+    assert not caplog.records
+
+    res = layer._ensure_token_limit(prompt="This is a very very very very very much longer prompt")
+    assert res == "This is a very very very very very much longer prompt"
+    assert not caplog.records
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_with_huge_max_length(caplog):
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=4090)
+    res = layer._ensure_token_limit(prompt="Short prompt")
+    assert res == "Short prompt"
+    assert not caplog.records
+
+    res = layer._ensure_token_limit(prompt="This is a very very very very very much longer prompt")
+    assert res == "This is a very very very"
+    assert len(caplog.records) == 1
+    expected_message_log = (
+        "The prompt has been truncated from 11 tokens to 6 tokens so that the prompt length and "
+        "answer length (4090 tokens) fit within the max token limit (4096 tokens). "
+        "Reduce the length of the prompt to prevent it from being cut off."
+    )
+    assert caplog.records[0].message == expected_message_log
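
The numbers in that expected log message follow from simple arithmetic, assuming the layer's model token limit defaults to 4096 (as the message itself reports): with max_length=4090 reserved for the answer, only 6 tokens of budget remain, so the 11-token prompt is cut down to its first 6 tokens.

```python
# Prompt budget behind the expected log message above (values taken from the test itself):
model_max_length = 4096  # model token limit reported in the log message
max_length = 4090        # tokens reserved for the generated answer in the test
prompt_budget = max(0, model_max_length - max_length)
assert prompt_budget == 6  # hence the 11-token prompt is truncated to its first 6 tokens
```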