fix: Allow setting num_beams in HFInvocationLayer (#4731)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: bogdankostic <bogdankostic@web.de>
Wang, Yi 2023-04-25 22:08:06 +08:00 committed by GitHub
parent 7fa3591f5f
commit 2be1a68fce

@@ -193,7 +193,10 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
                 model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw])
             if top_k:
                 model_input_kwargs["num_return_sequences"] = top_k
-                model_input_kwargs["num_beams"] = top_k
+                if "num_beams" not in model_input_kwargs or model_input_kwargs["num_beams"] < top_k:
+                    if "num_beams" in model_input_kwargs:
+                        logger.warning("num_beams should not be less than top_k, hence setting it to %s", top_k)
+                    model_input_kwargs["num_beams"] = top_k
             # max_new_tokens is used for text-generation and max_length for text2text-generation
             if is_text_generation:
                 model_input_kwargs["max_new_tokens"] = self.max_length