mirror of https://github.com/deepset-ai/haystack.git, synced 2025-12-03 10:26:14 +00:00
feat: HuggingFaceLocalGenerator - first implementation (#6022)
* draft
* still a raw draft
* still a raw draft
* improvements
* minimal impl ok
* tests
* reno
* better language
* examples of generation_kwargs
* incorporate feedback
* lg and format updates
* don't save valid str tokens
* fix style

---------

Co-authored-by: Darja Fokina <daria.f93@gmail.com>
parent 41fd0c5458
commit fbd22bc1e9
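In practice, the new component is used as follows. This is a minimal sketch assembled from the code added below; note that `warm_up()` must be called to load the underlying transformers pipeline before `run()`:

```python
from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator

generator = HuggingFaceLocalGenerator(
    model_name_or_path="google/flan-t5-base",
    task="text2text-generation",
    generation_kwargs={"max_new_tokens": 100},
)

generator.warm_up()  # loads the transformers pipeline; run() raises RuntimeError otherwise
print(generator.run(prompt="What's the capital of Italy?"))
# e.g. {'replies': ['rome']} (the exact output depends on the model)
```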
haystack/preview/components/generators/__init__.py
@@ -0,0 +1,4 @@
from haystack.preview.components.generators.openai.gpt import GPTGenerator
from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator

__all__ = ["GPTGenerator", "HuggingFaceLocalGenerator"]
haystack/preview/components/generators/hugging_face/hugging_face_local.py
@@ -0,0 +1,148 @@
import logging
from typing import Any, Dict, List, Literal, Optional, Union
from copy import deepcopy

from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview.lazy_imports import LazyImport

with LazyImport(message="Run 'pip install transformers'") as transformers_import:
    from huggingface_hub import model_info
    from transformers import pipeline


logger = logging.getLogger(__name__)


SUPPORTED_TASKS = ["text-generation", "text2text-generation"]


@component
class HuggingFaceLocalGenerator:
    """
    Generator based on a Hugging Face model.
    This component provides an interface to generate text using a Hugging Face model that runs locally.

    Usage example:
    ```python
    from haystack.preview.components.generators.hugging_face import HuggingFaceLocalGenerator

    generator = HuggingFaceLocalGenerator(
        model_name_or_path="google/flan-t5-large",
        task="text2text-generation",
        generation_kwargs={"max_new_tokens": 100, "temperature": 0.9},
    )

    generator.warm_up()
    print(generator.run("Who is the best American actor?"))
    # {'replies': ['John Cusack']}
    ```
    """

    def __init__(
        self,
        model_name_or_path: str = "google/flan-t5-base",
        task: Optional[Literal["text-generation", "text2text-generation"]] = None,
        device: Optional[str] = None,
        token: Optional[Union[str, bool]] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        pipeline_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        :param model_name_or_path: The name or path of a Hugging Face model for text generation,
            for example, "google/flan-t5-large".
            If the model is also specified in `pipeline_kwargs`, this parameter is ignored.
        :param task: The task for the Hugging Face pipeline.
            Possible values are "text-generation" and "text2text-generation".
            Generally, decoder-only models like GPT support "text-generation",
            while encoder-decoder models like T5 support "text2text-generation".
            If the task is also specified in `pipeline_kwargs`, this parameter is ignored.
            If not specified, the component infers the task from the model name,
            calling the Hugging Face Hub API.
        :param device: The device on which the model is loaded (e.g., "cpu", "cuda:0").
            If `device` or `device_map` is specified in `pipeline_kwargs`, this parameter is ignored.
        :param token: The token to use as HTTP bearer authorization for remote files.
            If True, the token generated when running `huggingface-cli login` (stored in ~/.huggingface) is used.
            If the token is also specified in `pipeline_kwargs`, this parameter is ignored.
        :param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
            Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`, ...
            See Hugging Face's documentation for more information:
            - https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation
            - https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig
        :param pipeline_kwargs: Dictionary containing keyword arguments used to initialize the pipeline.
            These keyword arguments provide fine-grained control over the pipeline.
            In case of duplication, these kwargs override the `model_name_or_path`, `task`, `device`,
            and `token` init parameters.
            See Hugging Face's [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task)
            for more information on the available kwargs.
            In this dictionary, you can also include `model_kwargs` to specify the kwargs
            for model initialization:
            https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained
        """
        transformers_import.check()

        pipeline_kwargs = pipeline_kwargs or {}
        generation_kwargs = generation_kwargs or {}

        # check if pipeline_kwargs contains the essential parameters;
        # otherwise, populate them with values from the other init parameters
        pipeline_kwargs.setdefault("model", model_name_or_path)
        pipeline_kwargs.setdefault("token", token)
        if device is not None and "device" not in pipeline_kwargs and "device_map" not in pipeline_kwargs:
            pipeline_kwargs["device"] = device

        # task identification and validation
        if task is None:
            if "task" in pipeline_kwargs:
                task = pipeline_kwargs["task"]
            elif isinstance(pipeline_kwargs["model"], str):
                task = model_info(pipeline_kwargs["model"], token=pipeline_kwargs["token"]).pipeline_tag

        if task not in SUPPORTED_TASKS:
            raise ValueError(f"Task '{task}' is not supported. The supported tasks are: {', '.join(SUPPORTED_TASKS)}.")
        pipeline_kwargs["task"] = task

        # if not specified, set return_full_text to False for text-generation,
        # so that only the generated text is returned (excluding the prompt)
        if task == "text-generation":
            generation_kwargs.setdefault("return_full_text", False)

        self.pipeline_kwargs = pipeline_kwargs
        self.generation_kwargs = generation_kwargs
        self.pipeline = None

    def warm_up(self):
        if self.pipeline is None:
            self.pipeline = pipeline(**self.pipeline_kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        pipeline_kwargs_to_serialize = deepcopy(self.pipeline_kwargs)

        # we don't want to serialize valid tokens
        if isinstance(pipeline_kwargs_to_serialize["token"], str):
            pipeline_kwargs_to_serialize["token"] = None

        return default_to_dict(
            self, pipeline_kwargs=pipeline_kwargs_to_serialize, generation_kwargs=self.generation_kwargs
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalGenerator":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    @component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
    def run(self, prompt: str):
        if self.pipeline is None:
            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

        replies = []
        if prompt:
            output = self.pipeline(prompt, **self.generation_kwargs)
            replies = [o["generated_text"] for o in output if "generated_text" in o]

        return {"replies": replies}
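Since `pipeline_kwargs` override the other init parameters and may carry `model_kwargs`, fine-grained loading options can be expressed directly. A sketch under illustrative assumptions (the bfloat16 dtype and `device_map="auto"` are example choices, and the latter additionally requires the `accelerate` package):

```python
import torch

from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator

generator = HuggingFaceLocalGenerator(
    pipeline_kwargs={
        "model": "google/flan-t5-large",
        "task": "text2text-generation",  # given explicitly, so no Hub API call is needed
        "device_map": "auto",            # requires the accelerate package
        "model_kwargs": {"torch_dtype": torch.bfloat16},
    },
)
generator.warm_up()
```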
Release note:
@@ -0,0 +1,5 @@
---
preview:
  - |
    Add a minimal version of HuggingFaceLocalGenerator, a component that can run
    Hugging Face models locally to generate text.
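Besides the feature itself, the tests below also pin down the serialization contract. A round-trip sketch using the names from this commit (note that string tokens are deliberately dropped by `to_dict()`, so they must be re-supplied after deserialization):

```python
from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator

generator = HuggingFaceLocalGenerator(model_name_or_path="gpt2", task="text-generation", token="secret-token")

data = generator.to_dict()
assert data["init_parameters"]["pipeline_kwargs"]["token"] is None  # the token is not serialized

restored = HuggingFaceLocalGenerator.from_dict(data)
assert restored.pipeline_kwargs["model"] == "gpt2"
```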
Tests:
@@ -0,0 +1,258 @@
from unittest.mock import patch, Mock

import pytest

from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator


class TestHuggingFaceLocalGenerator:
    @pytest.mark.unit
    @patch("haystack.preview.components.generators.hugging_face.hugging_face_local.model_info")
    def test_init_default(self, model_info_mock):
        model_info_mock.return_value.pipeline_tag = "text2text-generation"
        generator = HuggingFaceLocalGenerator()

        assert generator.pipeline_kwargs == {
            "model": "google/flan-t5-base",
            "task": "text2text-generation",
            "token": None,
        }
        assert generator.generation_kwargs == {}
        assert generator.pipeline is None

    @pytest.mark.unit
    def test_init_custom_token(self):
        generator = HuggingFaceLocalGenerator(
            model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token"
        )

        assert generator.pipeline_kwargs == {
            "model": "google/flan-t5-base",
            "task": "text2text-generation",
            "token": "test-token",
        }

    @pytest.mark.unit
    def test_init_custom_device(self):
        generator = HuggingFaceLocalGenerator(
            model_name_or_path="google/flan-t5-base", task="text2text-generation", device="cuda:0"
        )

        assert generator.pipeline_kwargs == {
            "model": "google/flan-t5-base",
            "task": "text2text-generation",
            "token": None,
            "device": "cuda:0",
        }

    @pytest.mark.unit
    def test_init_task_parameter(self):
        generator = HuggingFaceLocalGenerator(task="text2text-generation")

        assert generator.pipeline_kwargs == {
            "model": "google/flan-t5-base",
            "task": "text2text-generation",
            "token": None,
        }

    @pytest.mark.unit
    def test_init_task_in_pipeline_kwargs(self):
        generator = HuggingFaceLocalGenerator(pipeline_kwargs={"task": "text2text-generation"})

        assert generator.pipeline_kwargs == {
            "model": "google/flan-t5-base",
            "task": "text2text-generation",
            "token": None,
        }

    @pytest.mark.unit
    @patch("haystack.preview.components.generators.hugging_face.hugging_face_local.model_info")
    def test_init_task_inferred_from_model_name(self, model_info_mock):
        model_info_mock.return_value.pipeline_tag = "text2text-generation"
        generator = HuggingFaceLocalGenerator(model_name_or_path="google/flan-t5-base")

        assert generator.pipeline_kwargs == {
            "model": "google/flan-t5-base",
            "task": "text2text-generation",
            "token": None,
        }

    @pytest.mark.unit
    def test_init_invalid_task(self):
        with pytest.raises(ValueError, match="is not supported."):
            HuggingFaceLocalGenerator(task="text-classification")

    @pytest.mark.unit
    def test_init_pipeline_kwargs_override_other_parameters(self):
        """
        pipeline_kwargs represent the main configuration of this component.
        If they are provided, they should override the other init parameters.
        """
        pipeline_kwargs = {
            "model": "gpt2",
            "task": "text-generation",
            "device": "cuda:0",
            "token": "another-test-token",
        }

        generator = HuggingFaceLocalGenerator(
            model_name_or_path="google/flan-t5-base",
            task="text2text-generation",
            device="cpu",
            token="test-token",
            pipeline_kwargs=pipeline_kwargs,
        )

        assert generator.pipeline_kwargs == pipeline_kwargs

    @pytest.mark.unit
    def test_init_generation_kwargs(self):
        generator = HuggingFaceLocalGenerator(task="text2text-generation", generation_kwargs={"max_new_tokens": 100})

        assert generator.generation_kwargs == {"max_new_tokens": 100}

    @pytest.mark.unit
    def test_init_set_return_full_text(self):
        """
        If not specified, return_full_text is set to False for the text-generation task
        (only the generated text is returned, excluding the prompt).
        """
        generator = HuggingFaceLocalGenerator(task="text-generation")

        assert generator.generation_kwargs == {"return_full_text": False}

    @pytest.mark.unit
    @patch("haystack.preview.components.generators.hugging_face.hugging_face_local.model_info")
    def test_to_dict_default(self, model_info_mock):
        model_info_mock.return_value.pipeline_tag = "text2text-generation"

        component = HuggingFaceLocalGenerator()
        data = component.to_dict()

        assert data == {
            "type": "HuggingFaceLocalGenerator",
            "init_parameters": {
                "pipeline_kwargs": {"model": "google/flan-t5-base", "task": "text2text-generation", "token": None},
                "generation_kwargs": {},
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_parameters(self):
        component = HuggingFaceLocalGenerator(
            model_name_or_path="gpt2",
            task="text-generation",
            device="cuda:0",
            token="test-token",
            generation_kwargs={"max_new_tokens": 100},
        )
        data = component.to_dict()

        assert data == {
            "type": "HuggingFaceLocalGenerator",
            "init_parameters": {
                "pipeline_kwargs": {
                    "model": "gpt2",
                    "task": "text-generation",
                    "token": None,  # we don't want to serialize valid tokens
                    "device": "cuda:0",
                },
                "generation_kwargs": {"max_new_tokens": 100, "return_full_text": False},
            },
        }

    @pytest.mark.unit
    def test_from_dict(self):
        data = {
            "type": "HuggingFaceLocalGenerator",
            "init_parameters": {
                "pipeline_kwargs": {
                    "model": "gpt2",
                    "task": "text-generation",
                    "token": "test-token",
                    "device": "cuda:0",
                },
                "generation_kwargs": {"max_new_tokens": 100, "return_full_text": False},
            },
        }

        component = HuggingFaceLocalGenerator.from_dict(data)

        assert component.pipeline_kwargs == {
            "model": "gpt2",
            "task": "text-generation",
            "token": "test-token",
            "device": "cuda:0",
        }
        assert component.generation_kwargs == {"max_new_tokens": 100, "return_full_text": False}

    @pytest.mark.unit
    @patch("haystack.preview.components.generators.hugging_face.hugging_face_local.pipeline")
    def test_warm_up(self, pipeline_mock):
        generator = HuggingFaceLocalGenerator(
            model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token"
        )
        pipeline_mock.assert_not_called()

        generator.warm_up()

        pipeline_mock.assert_called_once_with(
            model="google/flan-t5-base", task="text2text-generation", token="test-token"
        )

    @pytest.mark.unit
    @patch("haystack.preview.components.generators.hugging_face.hugging_face_local.pipeline")
    def test_warm_up_doesnt_reload(self, pipeline_mock):
        generator = HuggingFaceLocalGenerator(
            model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token"
        )

        pipeline_mock.assert_not_called()

        generator.warm_up()
        generator.warm_up()

        pipeline_mock.assert_called_once()

    @pytest.mark.unit
    def test_run(self):
        generator = HuggingFaceLocalGenerator(
            model_name_or_path="google/flan-t5-base",
            task="text2text-generation",
            generation_kwargs={"max_new_tokens": 100},
        )

        # create the pipeline object (simulating the warm_up)
        generator.pipeline = Mock(return_value=[{"generated_text": "Rome"}])

        results = generator.run(prompt="What's the capital of Italy?")

        generator.pipeline.assert_called_once_with("What's the capital of Italy?", max_new_tokens=100)
        assert results == {"replies": ["Rome"]}

    @pytest.mark.unit
    @patch("haystack.preview.components.generators.hugging_face.hugging_face_local.pipeline")
    def test_run_empty_prompt(self, pipeline_mock):
        generator = HuggingFaceLocalGenerator(
            model_name_or_path="google/flan-t5-base",
            task="text2text-generation",
            generation_kwargs={"max_new_tokens": 100},
        )

        generator.warm_up()

        results = generator.run(prompt="")

        assert results == {"replies": []}

    @pytest.mark.unit
    def test_run_fails_without_warm_up(self):
        generator = HuggingFaceLocalGenerator(
            model_name_or_path="google/flan-t5-base",
            task="text2text-generation",
            generation_kwargs={"max_new_tokens": 100},
        )

        with pytest.raises(RuntimeError, match="The generation model has not been loaded."):
            generator.run(prompt="irrelevant")