From 2be1a68fce0bca49d1b0926e9e0d0a14a4dafae0 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Tue, 25 Apr 2023 22:08:06 +0800
Subject: [PATCH] fix: Allow to set `num_beams` in HFInvocationLayer (#4731)

Signed-off-by: Wang, Yi A
Co-authored-by: bogdankostic
---
 haystack/nodes/prompt/invocation_layer/hugging_face.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py
index bd7d2ed34..7a85becdc 100644
--- a/haystack/nodes/prompt/invocation_layer/hugging_face.py
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -193,7 +193,10 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
             model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw])
         if top_k:
             model_input_kwargs["num_return_sequences"] = top_k
-            model_input_kwargs["num_beams"] = top_k
+            if "num_beams" not in model_input_kwargs or model_input_kwargs["num_beams"] < top_k:
+                if "num_beams" in model_input_kwargs:
+                    logger.warning("num_beams should not be less than top_k, hence setting it to %s", top_k)
+                model_input_kwargs["num_beams"] = top_k
         # max_new_tokens is used for text-generation and max_length for text2text-generation
         if is_text_generation:
             model_input_kwargs["max_new_tokens"] = self.max_length
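
Below is a minimal, standalone sketch of the guard this patch introduces. The helper name `apply_top_k` is hypothetical (in the real file the logic lives inline in `HFLocalInvocationLayer`). The behavior it mirrors: Hugging Face transformers' `generate` raises a ValueError when `num_return_sequences` exceeds `num_beams`, so the patch raises `num_beams` to `top_k` only when it is missing or too small, instead of unconditionally overwriting a user-supplied value.

import logging

logger = logging.getLogger(__name__)


def apply_top_k(model_input_kwargs: dict, top_k: int) -> dict:
    # Hypothetical helper mirroring the patched logic: request top_k return
    # sequences and make sure the beam count can support them, while keeping
    # any larger user-supplied num_beams untouched.
    if top_k:
        model_input_kwargs["num_return_sequences"] = top_k
        if "num_beams" not in model_input_kwargs or model_input_kwargs["num_beams"] < top_k:
            if "num_beams" in model_input_kwargs:
                logger.warning("num_beams should not be less than top_k, hence setting it to %s", top_k)
            model_input_kwargs["num_beams"] = top_k
    return model_input_kwargs


# A user-supplied num_beams of 2 is raised to top_k=3 with a warning;
# num_beams=8 would be left as-is.
print(apply_top_k({"num_beams": 2}, top_k=3))
# -> {'num_beams': 3, 'num_return_sequences': 3}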