Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-08-17 13:07:42 +00:00
feat: HuggingFaceLocalGenerator - allow passing generation_kwargs in run method (#6220)

* allow custom generation_kwargs in run
* reno
* make pylint ignore too-many-public-methods
parent f2db68ef0b
commit 8511b8cd79
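For context, here is a minimal usage sketch of what this change enables. The import path is an assumption (it has moved across Haystack 2.x preview versions); the model and parameters are taken from the tests in this commit. The diff of the run() method itself follows.

    # Usage sketch only; the import path below is an assumption and may
    # differ depending on your Haystack version.
    from haystack.components.generators import HuggingFaceLocalGenerator

    generator = HuggingFaceLocalGenerator(
        model_name_or_path="google/flan-t5-base",
        task="text2text-generation",
        generation_kwargs={"max_new_tokens": 100},  # init-time defaults
    )
    generator.warm_up()  # loads the model; run() raises RuntimeError otherwise

    # New in this commit: per-call kwargs that override the init-time defaults
    result = generator.run(
        prompt="What is the capital of Italy?",
        generation_kwargs={"max_new_tokens": 200, "temperature": 0.5},
    )
    print(result["replies"])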
@@ -206,14 +206,24 @@ class HuggingFaceLocalGenerator:
         )
 
     @component.output_types(replies=List[str])
-    def run(self, prompt: str):
+    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         """
         Run the text generation model on the given prompt.
 
         :param prompt: A string representing the prompt.
+        :param generation_kwargs: Additional keyword arguments for text generation.
         :return: A dictionary containing the generated replies.
         """
         if self.pipeline is None:
             raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
 
         if not prompt:
             return {"replies": []}
 
-        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **self.generation_kwargs)
+        # merge generation kwargs from init method with those from run method
+        updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)
         replies = [o["generated_text"] for o in output if "generated_text" in o]
 
         if self.stop_words:
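The dict merge introduced above gives the per-call kwargs precedence over the init-time ones, because later `**`-unpacking wins on key collisions. A self-contained sketch of that behavior:

    # Plain-Python illustration of the merge used in run():
    init_kwargs = {"max_new_tokens": 100}
    run_kwargs = {"max_new_tokens": 200, "temperature": 0.5}

    # later unpacking wins, so per-call values override init-time values
    merged = {**init_kwargs, **(run_kwargs or {})}
    assert merged == {"max_new_tokens": 200, "temperature": 0.5}

    # passing None (the default) falls back to the init-time kwargs unchanged
    assert {**init_kwargs, **(None or {})} == {"max_new_tokens": 100}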
New release note (reno):

@@ -0,0 +1,5 @@
+---
+preview:
+  - |
+    Allow passing `generation_kwargs` in the `run` method of the `HuggingFaceLocalGenerator`.
+    This makes this common operation faster.
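The "faster" claim presumably refers to avoiding repeated model loads: before this change, trying out different generation settings meant constructing a new component and calling `warm_up()` (a full model load) for each configuration. A hedged before/after sketch, under the same import assumption as above:

    from haystack.components.generators import HuggingFaceLocalGenerator  # assumed path

    # Before (sketch): one component, and one model load, per configuration.
    gen_100 = HuggingFaceLocalGenerator(
        model_name_or_path="google/flan-t5-base",
        task="text2text-generation",
        generation_kwargs={"max_new_tokens": 100},
    )
    gen_100.warm_up()  # model load #1
    # ...and again for max_new_tokens=200, temperature=0.5, and so on.

    # After: a single warmed-up component, parameters varied per call.
    gen = HuggingFaceLocalGenerator(
        model_name_or_path="google/flan-t5-base", task="text2text-generation"
    )
    gen.warm_up()  # single model load
    short = gen.run(prompt="Summarize: ...", generation_kwargs={"max_new_tokens": 50})
    longer = gen.run(prompt="Summarize: ...", generation_kwargs={"max_new_tokens": 200})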
At the top of the test module, silence pylint for the growing test class:

@@ -1,3 +1,4 @@
+# pylint: disable=too-many-public-methods
 from unittest.mock import patch, Mock
 
 import pytest
@@ -242,6 +243,23 @@ class TestHuggingFaceLocalGenerator:
 
         assert results == {"replies": []}
 
+    @pytest.mark.unit
+    def test_run_with_generation_kwargs(self):
+        generator = HuggingFaceLocalGenerator(
+            model_name_or_path="google/flan-t5-base",
+            task="text2text-generation",
+            generation_kwargs={"max_new_tokens": 100},
+        )
+
+        # create the pipeline object (simulating the warm_up)
+        generator.pipeline = Mock(return_value=[{"generated_text": "Rome"}])
+
+        generator.run(prompt="irrelevant", generation_kwargs={"max_new_tokens": 200, "temperature": 0.5})
+
+        generator.pipeline.assert_called_once_with(
+            "irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None
+        )
+
     @pytest.mark.unit
     def test_run_fails_without_warm_up(self):
         generator = HuggingFaceLocalGenerator(
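The test above checks the merge end to end: the generator is initialized with `max_new_tokens=100`, run with `max_new_tokens=200`, and the mocked pipeline must see only the per-call values (`stopping_criteria` ends up `None` here since no stop words were set and `warm_up()` was bypassed). The same capture-and-assert pattern in isolation, with illustrative names:

    from unittest.mock import Mock

    # stand-in for the transformers pipeline; Mock records how it is called
    pipeline = Mock(return_value=[{"generated_text": "Rome"}])

    pipeline("irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None)

    # raises AssertionError if the call used different arguments
    pipeline.assert_called_once_with(
        "irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None
    )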