Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-08-17 13:07:42 +00:00
feat: `HuggingFaceLocalGenerator` - allow passing `generation_kwargs` in `run` method (#6220)

* allow custom generation_kwargs in run
* reno
* make pylint ignore too-many-public-methods
parent f2db68ef0b
commit 8511b8cd79
@@ -206,14 +206,24 @@ class HuggingFaceLocalGenerator:
         )

     @component.output_types(replies=List[str])
-    def run(self, prompt: str):
+    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         """
         Run the text generation model on the given prompt.

         :param prompt: A string representing the prompt.
+        :param generation_kwargs: Additional keyword arguments for text generation.
         :return: A dictionary containing the generated replies.
         """
         if self.pipeline is None:
             raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

         if not prompt:
             return {"replies": []}

-        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **self.generation_kwargs)
+        # merge generation kwargs from init method with those from run method
+        updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)
         replies = [o["generated_text"] for o in output if "generated_text" in o]

         if self.stop_words:
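The merge added above relies on Python dict unpacking, where keys from the later mapping win. A minimal sketch of that precedence, with illustrative variable names not taken from the diff:

# constructor-level defaults (what self.generation_kwargs holds after __init__)
init_kwargs = {"max_new_tokens": 100}

# per-call overrides (what a caller passes to run)
run_kwargs = {"max_new_tokens": 200, "temperature": 0.5}

# later unpacking wins, so run-time values override the defaults;
# (run_kwargs or {}) keeps the merge safe when run() gets no kwargs
merged = {**init_kwargs, **(run_kwargs or {})}
assert merged == {"max_new_tokens": 200, "temperature": 0.5}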
The release note (reno) added by this commit:

@@ -0,0 +1,5 @@
+---
+preview:
+  - |
+    Allow passing `generation_kwargs` in the `run` method of the `HuggingFaceLocalGenerator`.
+    This makes this common operation faster.
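A short usage sketch of the behaviour this note describes. The import path is an assumption based on the preview package this commit targets and may differ across Haystack versions:

from haystack.preview.components.generators import HuggingFaceLocalGenerator  # assumed import path

generator = HuggingFaceLocalGenerator(
    model_name_or_path="google/flan-t5-base",
    task="text2text-generation",
    generation_kwargs={"max_new_tokens": 100},  # default for every call
)
generator.warm_up()  # loads the underlying transformers pipeline

# per-call kwargs override the constructor defaults for this call only
result = generator.run(
    prompt="What is the capital of Italy?",
    generation_kwargs={"max_new_tokens": 200, "temperature": 0.5},
)
print(result["replies"])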
In the test module, the pylint ignore mentioned in the commit message:

@@ -1,3 +1,4 @@
+# pylint: disable=too-many-public-methods
 from unittest.mock import patch, Mock

 import pytest
@@ -242,6 +243,23 @@ class TestHuggingFaceLocalGenerator:

         assert results == {"replies": []}

+    @pytest.mark.unit
+    def test_run_with_generation_kwargs(self):
+        generator = HuggingFaceLocalGenerator(
+            model_name_or_path="google/flan-t5-base",
+            task="text2text-generation",
+            generation_kwargs={"max_new_tokens": 100},
+        )
+
+        # create the pipeline object (simulating the warm_up)
+        generator.pipeline = Mock(return_value=[{"generated_text": "Rome"}])
+
+        generator.run(prompt="irrelevant", generation_kwargs={"max_new_tokens": 200, "temperature": 0.5})
+
+        generator.pipeline.assert_called_once_with(
+            "irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None
+        )
+
     @pytest.mark.unit
     def test_run_fails_without_warm_up(self):
         generator = HuggingFaceLocalGenerator(
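The new test never loads a real model: replacing generator.pipeline with a Mock lets assert_called_once_with verify that the merged kwargs reach the pipeline (max_new_tokens=200 from the run call wins over the 100 set at construction, and stopping_criteria is None because no stop words were configured). The same Mock pattern can be checked in isolation; a minimal sketch:

from unittest.mock import Mock

pipeline = Mock(return_value=[{"generated_text": "Rome"}])
pipeline("irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None)

# passes only if the recorded call matches exactly, positional and keyword args alike
pipeline.assert_called_once_with("irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None)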