diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py
index 277d8d4a4..a0a48d902 100644
--- a/haystack/nodes/prompt/invocation_layer/hugging_face.py
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -195,6 +195,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         """
         output: List[Dict[str, str]] = []
         stop_words = kwargs.pop("stop_words", None)
+        top_k = kwargs.pop("top_k", None)
         # either stream is True (will use default handler) or stream_handler is provided for custom handler
         stream = kwargs.get("stream", self.stream)
         stream_handler = kwargs.get("stream_handler", self.stream_handler)
@@ -238,21 +239,12 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
             if stop_words:
                 sw = StopWordsCriteria(tokenizer=self.pipe.tokenizer, stop_words=stop_words, device=self.pipe.device)
                 model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw])
-
-            if "num_beams" in model_input_kwargs:
-                num_beams = model_input_kwargs["num_beams"]
-                if (
-                    "num_return_sequences" in model_input_kwargs
-                    and model_input_kwargs["num_return_sequences"] > num_beams
-                ):
-                    num_return_sequences = model_input_kwargs["num_return_sequences"]
-                    logger.warning(
-                        "num_return_sequences %s should not be larger than num_beams %s, hence setting it equal to num_beams",
-                        num_return_sequences,
-                        num_beams,
-                    )
-                    model_input_kwargs["num_return_sequences"] = num_beams
-
+            if top_k:
+                model_input_kwargs["num_return_sequences"] = top_k
+                if "num_beams" not in model_input_kwargs or model_input_kwargs["num_beams"] < top_k:
+                    if "num_beams" in model_input_kwargs:
+                        logger.warning("num_beams should not be less than top_k, hence setting it to %s", top_k)
+                    model_input_kwargs["num_beams"] = top_k
             # max_new_tokens is used for text-generation and max_length for text2text-generation
             if is_text_generation:
                 model_input_kwargs["max_new_tokens"] = model_input_kwargs.pop("max_length", self.max_length)
diff --git a/test/prompt/invocation_layer/test_hugging_face.py b/test/prompt/invocation_layer/test_hugging_face.py
index 2ff1f52e1..e5ca043ef 100644
--- a/test/prompt/invocation_layer/test_hugging_face.py
+++ b/test/prompt/invocation_layer/test_hugging_face.py
@@ -205,48 +205,6 @@ def test_ensure_token_limit_negative(caplog):
     assert caplog.records[0].message == expected_message
 
 
-@pytest.mark.unit
-def test_num_return_sequences_no_larger_than_num_beams(mock_pipeline, mock_get_task, caplog):
-    """
-    Test that num_return_sequences cannot be larger than num_beams and that a warning is logged
-    """
-
-    layer = HFLocalInvocationLayer("google/flan-t5-base")
-
-    with patch.object(layer.pipe, "run_single", MagicMock()):
-        layer.invoke(prompt="What does 42 mean?", generation_kwargs={"num_beams": 5, "num_return_sequences": 8})
-
-    expected_message = (
-        "num_return_sequences 8 should not be larger than num_beams 5, hence setting it equal to num_beams"
-    )
-    # check that the warning is logged
-    assert caplog.records[0].message == expected_message
-
-    # check that num_return_sequences is set to num_beams
-    _, kwargs = layer.pipe.call_args
-    assert kwargs["num_beams"] == 5
-    assert kwargs["num_return_sequences"] == 5
-
-
-@pytest.mark.unit
-def test_num_beams_larger_than_num_return_sequences(mock_pipeline, mock_get_task, caplog):
-    """
-    Test that num_beams can be larger than num_return_sequences and that no warning is logged
-    """
-    layer = HFLocalInvocationLayer("google/flan-t5-base")
-
-    with patch.object(layer.pipe, "run_single", MagicMock()):
-        layer.invoke(prompt="What does 42 mean?", generation_kwargs={"num_beams": 8, "num_return_sequences": 5})
-
-    # check that no warning is logged
-    assert not caplog.records
-
-    # check that num_return_sequences remains unchanged
-    _, kwargs = layer.pipe.call_args
-    assert kwargs["num_beams"] == 8
-    assert kwargs["num_return_sequences"] == 5
-
-
 @pytest.mark.unit
 def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
     """
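
For review purposes, the effect of the new branch can be seen in isolation. The sketch below mirrors the added logic as a standalone function; `apply_top_k` is a hypothetical helper name used only for illustration, not part of the Haystack API. When `top_k` is set, `num_return_sequences` is forced to `top_k`, and `num_beams` is raised to at least `top_k` (with a warning when an explicit, smaller `num_beams` is overridden):

```python
# Minimal sketch of the diff's new top_k handling. apply_top_k is a
# hypothetical name; in the PR this logic lives inline in invoke().
import logging
from typing import Optional

logger = logging.getLogger(__name__)


def apply_top_k(model_input_kwargs: dict, top_k: Optional[int]) -> dict:
    # top_k asks for top_k answers, so the pipeline must return that many
    # sequences, and beam search needs at least that many beams.
    if top_k:
        model_input_kwargs["num_return_sequences"] = top_k
        if "num_beams" not in model_input_kwargs or model_input_kwargs["num_beams"] < top_k:
            if "num_beams" in model_input_kwargs:
                logger.warning("num_beams should not be less than top_k, hence setting it to %s", top_k)
            model_input_kwargs["num_beams"] = top_k
    return model_input_kwargs


# num_beams=3 is too small for top_k=5, so it is raised and a warning is logged.
print(apply_top_k({"num_beams": 3}, top_k=5))
# {'num_beams': 5, 'num_return_sequences': 5}
```

Note the inversion relative to the removed code: previously `num_return_sequences` was capped at a user-supplied `num_beams`, whereas now `top_k` takes precedence and `num_beams` is grown to accommodate it, which is why the two old tests pinning the capping behavior are deleted.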