fix: num_return_sequences should be less than num_beams, not top_k (#5280)

* formatting

* remove top_k variable

* add pytest

* add numbers

* string formatting

* fix formatting

* revert

* extend tests with assertions for num_return_sequences

---------

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Fanli Lin 2023-07-11 18:20:21 +08:00 committed by GitHub
parent 41668f26d6
commit 514f93a6eb
2 changed files with 57 additions and 7 deletions


@@ -196,7 +196,6 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         """
         output: List[Dict[str, str]] = []
         stop_words = kwargs.pop("stop_words", None)
-        top_k = kwargs.pop("top_k", None)
         # either stream is True (will use default handler) or stream_handler is provided for custom handler
         stream = kwargs.get("stream", self.stream)
         stream_handler = kwargs.get("stream_handler", self.stream_handler)
@@ -241,12 +240,21 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         if stop_words:
             sw = StopWordsCriteria(tokenizer=self.pipe.tokenizer, stop_words=stop_words, device=self.pipe.device)
             model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw])
-        if top_k:
-            model_input_kwargs["num_return_sequences"] = top_k
-            if "num_beams" not in model_input_kwargs or model_input_kwargs["num_beams"] < top_k:
-                if "num_beams" in model_input_kwargs:
-                    logger.warning("num_beams should not be less than top_k, hence setting it to %s", top_k)
-                model_input_kwargs["num_beams"] = top_k
+        if "num_beams" in model_input_kwargs:
+            num_beams = model_input_kwargs["num_beams"]
+            if (
+                "num_return_sequences" in model_input_kwargs
+                and model_input_kwargs["num_return_sequences"] > num_beams
+            ):
+                num_return_sequences = model_input_kwargs["num_return_sequences"]
+                logger.warning(
+                    "num_return_sequences %s should not be larger than num_beams %s, hence setting it equal to num_beams",
+                    num_return_sequences,
+                    num_beams,
+                )
+                model_input_kwargs["num_return_sequences"] = num_beams
         # max_new_tokens is used for text-generation and max_length for text2text-generation
         if is_text_generation:
             model_input_kwargs["max_new_tokens"] = model_input_kwargs.pop("max_length", self.max_length)


@@ -204,6 +204,48 @@ def test_ensure_token_limit_negative(caplog):
     assert caplog.records[0].message == expected_message
 
 
+@pytest.mark.unit
+def test_num_return_sequences_no_larger_than_num_beams(mock_pipeline, mock_get_task, caplog):
+    """
+    Test that num_return_sequences cannot be larger than num_beams and that a warning is logged
+    """
+    layer = HFLocalInvocationLayer("google/flan-t5-base")
+    with patch.object(layer.pipe, "run_single", MagicMock()):
+        layer.invoke(prompt="What does 42 mean?", generation_kwargs={"num_beams": 5, "num_return_sequences": 8})
+
+    expected_message = (
+        "num_return_sequences 8 should not be larger than num_beams 5, hence setting it equal to num_beams"
+    )
+    # check that the warning is logged
+    assert caplog.records[0].message == expected_message
+
+    # check that num_return_sequences is set to num_beams
+    _, kwargs = layer.pipe.call_args
+    assert kwargs["num_beams"] == 5
+    assert kwargs["num_return_sequences"] == 5
+
+
+@pytest.mark.unit
+def test_num_beams_larger_than_num_return_sequences(mock_pipeline, mock_get_task, caplog):
+    """
+    Test that num_beams can be larger than num_return_sequences and that no warning is logged
+    """
+    layer = HFLocalInvocationLayer("google/flan-t5-base")
+    with patch.object(layer.pipe, "run_single", MagicMock()):
+        layer.invoke(prompt="What does 42 mean?", generation_kwargs={"num_beams": 8, "num_return_sequences": 5})
+
+    # check that no warning is logged
+    assert not caplog.records
+
+    # check that num_return_sequences remains unchanged
+    _, kwargs = layer.pipe.call_args
+    assert kwargs["num_beams"] == 8
+    assert kwargs["num_return_sequences"] == 5
+
+
 @pytest.mark.unit
 def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
     """