From 514f93a6eb575d376b21d22e32080fac62cf785f Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Tue, 11 Jul 2023 18:20:21 +0800 Subject: [PATCH] fix: num_return_sequences should be less than num_beams, not top_k (#5280) * formatting * remove top_k variable * add pytest * add numbers * string formatting * fix formatting * revert * extend tests with assertions for num_return_sequences --------- Co-authored-by: Julian Risch --- .../prompt/invocation_layer/hugging_face.py | 22 ++++++---- .../invocation_layer/test_hugging_face.py | 42 +++++++++++++++++++ 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py index 86d93b6f1..9fac3ff3d 100644 --- a/haystack/nodes/prompt/invocation_layer/hugging_face.py +++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py @@ -196,7 +196,6 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer): """ output: List[Dict[str, str]] = [] stop_words = kwargs.pop("stop_words", None) - top_k = kwargs.pop("top_k", None) # either stream is True (will use default handler) or stream_handler is provided for custom handler stream = kwargs.get("stream", self.stream) stream_handler = kwargs.get("stream_handler", self.stream_handler) @@ -241,12 +240,21 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer): if stop_words: sw = StopWordsCriteria(tokenizer=self.pipe.tokenizer, stop_words=stop_words, device=self.pipe.device) model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw]) - if top_k: - model_input_kwargs["num_return_sequences"] = top_k - if "num_beams" not in model_input_kwargs or model_input_kwargs["num_beams"] < top_k: - if "num_beams" in model_input_kwargs: - logger.warning("num_beams should not be less than top_k, hence setting it to %s", top_k) - model_input_kwargs["num_beams"] = top_k + + if "num_beams" in model_input_kwargs: + num_beams = model_input_kwargs["num_beams"] + if ( + "num_return_sequences" in model_input_kwargs + and model_input_kwargs["num_return_sequences"] > num_beams + ): + num_return_sequences = model_input_kwargs["num_return_sequences"] + logger.warning( + "num_return_sequences %s should not be larger than num_beams %s, hence setting it equal to num_beams", + num_return_sequences, + num_beams, + ) + model_input_kwargs["num_return_sequences"] = num_beams + # max_new_tokens is used for text-generation and max_length for text2text-generation if is_text_generation: model_input_kwargs["max_new_tokens"] = model_input_kwargs.pop("max_length", self.max_length) diff --git a/test/prompt/invocation_layer/test_hugging_face.py b/test/prompt/invocation_layer/test_hugging_face.py index 5152dac95..db30713d7 100644 --- a/test/prompt/invocation_layer/test_hugging_face.py +++ b/test/prompt/invocation_layer/test_hugging_face.py @@ -204,6 +204,48 @@ def test_ensure_token_limit_negative(caplog): assert caplog.records[0].message == expected_message +@pytest.mark.unit +def test_num_return_sequences_no_larger_than_num_beams(mock_pipeline, mock_get_task, caplog): + """ + Test that num_return_sequences cannot be larger than num_beams and that a warning is logged + """ + + layer = HFLocalInvocationLayer("google/flan-t5-base") + + with patch.object(layer.pipe, "run_single", MagicMock()): + layer.invoke(prompt="What does 42 mean?", generation_kwargs={"num_beams": 5, "num_return_sequences": 8}) + + expected_message = ( + "num_return_sequences 8 should not be larger than num_beams 5, hence setting it equal to num_beams" + ) + # check that the warning is logged + assert caplog.records[0].message == expected_message + + # check that num_return_sequences is set to num_beams + _, kwargs = layer.pipe.call_args + assert kwargs["num_beams"] == 5 + assert kwargs["num_return_sequences"] == 5 + + +@pytest.mark.unit +def test_num_beams_larger_than_num_return_sequences(mock_pipeline, mock_get_task, caplog): + """ + Test that num_beams can be larger than num_return_sequences and that no warning is logged + """ + layer = HFLocalInvocationLayer("google/flan-t5-base") + + with patch.object(layer.pipe, "run_single", MagicMock()): + layer.invoke(prompt="What does 42 mean?", generation_kwargs={"num_beams": 8, "num_return_sequences": 5}) + + # check that no warning is logged + assert not caplog.records + + # check that num_return_sequences remains unchanged + _, kwargs = layer.pipe.call_args + assert kwargs["num_beams"] == 8 + assert kwargs["num_return_sequences"] == 5 + + @pytest.mark.unit def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task): """