From 8511b8cd79c5caaa190e44c4293d73ea6ef4a24c Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Date: Thu, 2 Nov 2023 15:29:38 +0100
Subject: [PATCH] feat: `HuggingFaceLocalGenerator` - allow passing
 `generation_kwargs` in `run` method (#6220)

* allow custom generation_kwargs in run

* reno

* make pylint ignore too-many-public-methods
---
 .../hugging_face/hugging_face_local.py           | 14 ++++++++++++--
 ...eration-kwargs-in-run-2bde10d398a3712a.yaml   |  5 +++++
 .../test_hugging_face_local_generator.py         | 18 ++++++++++++++++++
 3 files changed, 35 insertions(+), 2 deletions(-)
 create mode 100644 releasenotes/notes/hflocalgenerator-generation-kwargs-in-run-2bde10d398a3712a.yaml

diff --git a/haystack/preview/components/generators/hugging_face/hugging_face_local.py b/haystack/preview/components/generators/hugging_face/hugging_face_local.py
index 8f1d9874f..3837b5e51 100644
--- a/haystack/preview/components/generators/hugging_face/hugging_face_local.py
+++ b/haystack/preview/components/generators/hugging_face/hugging_face_local.py
@@ -206,14 +206,24 @@ class HuggingFaceLocalGenerator:
         )
 
     @component.output_types(replies=List[str])
-    def run(self, prompt: str):
+    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+        """
+        Run the text generation model on the given prompt.
+
+        :param prompt: A string representing the prompt.
+        :param generation_kwargs: Additional keyword arguments for text generation.
+        :return: A dictionary containing the generated replies.
+        """
         if self.pipeline is None:
             raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
 
         if not prompt:
             return {"replies": []}
 
-        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **self.generation_kwargs)
+        # merge generation kwargs from init method with those from run method
+        updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)
         replies = [o["generated_text"] for o in output if "generated_text" in o]
 
         if self.stop_words:

diff --git a/releasenotes/notes/hflocalgenerator-generation-kwargs-in-run-2bde10d398a3712a.yaml b/releasenotes/notes/hflocalgenerator-generation-kwargs-in-run-2bde10d398a3712a.yaml
new file mode 100644
index 000000000..cf605f9c3
--- /dev/null
+++ b/releasenotes/notes/hflocalgenerator-generation-kwargs-in-run-2bde10d398a3712a.yaml
@@ -0,0 +1,5 @@
+---
+preview:
+  - |
+    Allow passing `generation_kwargs` in the `run` method of the `HuggingFaceLocalGenerator`.
+    This makes this common operation faster.
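
For reference, a minimal usage sketch of the new `run` signature (the model, task, and
init-time `generation_kwargs` are taken from the test below; the prompt text and the
run-time override values are illustrative assumptions, not part of the patch):

    from haystack.preview.components.generators.hugging_face.hugging_face_local import (
        HuggingFaceLocalGenerator,
    )

    # generation_kwargs passed at init act as defaults for every call
    generator = HuggingFaceLocalGenerator(
        model_name_or_path="google/flan-t5-base",
        task="text2text-generation",
        generation_kwargs={"max_new_tokens": 100},
    )
    generator.warm_up()  # loads the pipeline; run() raises RuntimeError without it

    # run-time generation_kwargs are merged on top of the init defaults, so
    # max_new_tokens=200 overrides the init value of 100 for this call only
    result = generator.run(
        prompt="What is the capital of Italy?",
        generation_kwargs={"max_new_tokens": 200},
    )
    print(result["replies"])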
diff --git a/test/preview/components/generators/hugging_face/test_hugging_face_local_generator.py b/test/preview/components/generators/hugging_face/test_hugging_face_local_generator.py
index 17f03b1ed..54690c9af 100644
--- a/test/preview/components/generators/hugging_face/test_hugging_face_local_generator.py
+++ b/test/preview/components/generators/hugging_face/test_hugging_face_local_generator.py
@@ -1,3 +1,4 @@
+# pylint: disable=too-many-public-methods
 from unittest.mock import patch, Mock
 
 import pytest
@@ -242,6 +243,23 @@ class TestHuggingFaceLocalGenerator:
 
         assert results == {"replies": []}
 
+    @pytest.mark.unit
+    def test_run_with_generation_kwargs(self):
+        generator = HuggingFaceLocalGenerator(
+            model_name_or_path="google/flan-t5-base",
+            task="text2text-generation",
+            generation_kwargs={"max_new_tokens": 100},
+        )
+
+        # create the pipeline object (simulating the warm_up)
+        generator.pipeline = Mock(return_value=[{"generated_text": "Rome"}])
+
+        generator.run(prompt="irrelevant", generation_kwargs={"max_new_tokens": 200, "temperature": 0.5})
+
+        generator.pipeline.assert_called_once_with(
+            "irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None
+        )
+
     @pytest.mark.unit
     def test_run_fails_without_warm_up(self):
         generator = HuggingFaceLocalGenerator(
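
The assertion in the test above pins down the merge semantics: run-time kwargs take
precedence over init-time kwargs via a plain dict merge. A standalone sketch of that
behavior, using the same values as the test:

    init_kwargs = {"max_new_tokens": 100}  # passed to __init__
    run_kwargs = {"max_new_tokens": 200, "temperature": 0.5}  # passed to run()

    # the expression used in the patch: later entries override earlier ones
    merged = {**init_kwargs, **(run_kwargs or {})}
    assert merged == {"max_new_tokens": 200, "temperature": 0.5}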