feat: HuggingFaceLocalGenerator- allow passing generation_kwargs in run method (#6220)

* allow custom generation_kwargs in run

* reno

* make pylint ignore too-many-public-methods
This commit is contained in:
Stefano Fiorucci 2023-11-02 15:29:38 +01:00 committed by GitHub
parent f2db68ef0b
commit 8511b8cd79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 2 deletions

View File

@ -206,14 +206,24 @@ class HuggingFaceLocalGenerator:
)
@component.output_types(replies=List[str])
def run(self, prompt: str):
def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Run the text generation model on the given prompt.
:param prompt: A string representing the prompt.
:param generation_kwargs: Additional keyword arguments for text generation.
:return: A dictionary containing the generated replies.
"""
if self.pipeline is None:
raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
if not prompt:
return {"replies": []}
output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **self.generation_kwargs)
# merge generation kwargs from init method with those from run method
updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)
replies = [o["generated_text"] for o in output if "generated_text" in o]
if self.stop_words:

View File

@ -0,0 +1,5 @@
---
preview:
- |
Allow passing `generation_kwargs` in the `run` method of the `HuggingFaceLocalGenerator`.
This makes this common operation faster.

View File

@ -1,3 +1,4 @@
# pylint: disable=too-many-public-methods
from unittest.mock import patch, Mock
import pytest
@ -242,6 +243,23 @@ class TestHuggingFaceLocalGenerator:
assert results == {"replies": []}
@pytest.mark.unit
def test_run_with_generation_kwargs(self):
generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base",
task="text2text-generation",
generation_kwargs={"max_new_tokens": 100},
)
# create the pipeline object (simulating the warm_up)
generator.pipeline = Mock(return_value=[{"generated_text": "Rome"}])
generator.run(prompt="irrelevant", generation_kwargs={"max_new_tokens": 200, "temperature": 0.5})
generator.pipeline.assert_called_once_with(
"irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None
)
@pytest.mark.unit
def test_run_fails_without_warm_up(self):
generator = HuggingFaceLocalGenerator(