diff --git a/haystack/preview/components/generators/hugging_face/hugging_face_local.py b/haystack/preview/components/generators/hugging_face/hugging_face_local.py index 8f1d9874f..3837b5e51 100644 --- a/haystack/preview/components/generators/hugging_face/hugging_face_local.py +++ b/haystack/preview/components/generators/hugging_face/hugging_face_local.py @@ -206,14 +206,24 @@ class HuggingFaceLocalGenerator: ) @component.output_types(replies=List[str]) - def run(self, prompt: str): + def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Run the text generation model on the given prompt. + + :param prompt: A string representing the prompt. + :param generation_kwargs: Additional keyword arguments for text generation. + :return: A dictionary containing the generated replies. + """ if self.pipeline is None: raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.") if not prompt: return {"replies": []} - output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **self.generation_kwargs) + # merge generation kwargs from init method with those from run method + updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs) replies = [o["generated_text"] for o in output if "generated_text" in o] if self.stop_words: diff --git a/releasenotes/notes/hflocalgenerator-generation-kwargs-in-run-2bde10d398a3712a.yaml b/releasenotes/notes/hflocalgenerator-generation-kwargs-in-run-2bde10d398a3712a.yaml new file mode 100644 index 000000000..cf605f9c3 --- /dev/null +++ b/releasenotes/notes/hflocalgenerator-generation-kwargs-in-run-2bde10d398a3712a.yaml @@ -0,0 +1,5 @@ +--- +preview: + - | + Allow passing `generation_kwargs` in the `run` method of the `HuggingFaceLocalGenerator`. + This makes this common operation faster. diff --git a/test/preview/components/generators/hugging_face/test_hugging_face_local_generator.py b/test/preview/components/generators/hugging_face/test_hugging_face_local_generator.py index 17f03b1ed..54690c9af 100644 --- a/test/preview/components/generators/hugging_face/test_hugging_face_local_generator.py +++ b/test/preview/components/generators/hugging_face/test_hugging_face_local_generator.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-public-methods from unittest.mock import patch, Mock import pytest @@ -242,6 +243,23 @@ class TestHuggingFaceLocalGenerator: assert results == {"replies": []} + @pytest.mark.unit + def test_run_with_generation_kwargs(self): + generator = HuggingFaceLocalGenerator( + model_name_or_path="google/flan-t5-base", + task="text2text-generation", + generation_kwargs={"max_new_tokens": 100}, + ) + + # create the pipeline object (simulating the warm_up) + generator.pipeline = Mock(return_value=[{"generated_text": "Rome"}]) + + generator.run(prompt="irrelevant", generation_kwargs={"max_new_tokens": 200, "temperature": 0.5}) + + generator.pipeline.assert_called_once_with( + "irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None + ) + @pytest.mark.unit def test_run_fails_without_warm_up(self): generator = HuggingFaceLocalGenerator(