feat: Add TransformersTextRouter component (#7801)

* First pass at adding TransformersTextRouter

* Fix tests

* Add release notes

* Add optional labels param

* Add verification in the warm_up

* Fix tests

* Add labels to to_dict

* Feedback from review

* Add component to docs

* Add extra tests
Sebastian Husch Lee 2024-06-05 15:28:53 +02:00 committed by GitHub
parent e6b8b7529b
commit d815c78198
7 changed files with 401 additions and 5 deletions

docs/pydoc/config/routers_api.yml

@@ -7,6 +7,7 @@ loaders:
         "file_type_router",
         "metadata_router",
         "text_language_router",
+        "transformers_text_router",
         "zero_shot_text_router",
       ]
     ignore_when_discovered: ["__init__"]

haystack/components/routers/__init__.py

@@ -6,6 +6,7 @@ from haystack.components.routers.conditional_router import ConditionalRouter
 from haystack.components.routers.file_type_router import FileTypeRouter
 from haystack.components.routers.metadata_router import MetadataRouter
 from haystack.components.routers.text_language_router import TextLanguageRouter
+from haystack.components.routers.transformers_text_router import TransformersTextRouter
 from haystack.components.routers.zero_shot_text_router import TransformersZeroShotTextRouter

 __all__ = [
@@ -14,4 +15,5 @@ __all__ = [
     "TextLanguageRouter",
     "ConditionalRouter",
     "TransformersZeroShotTextRouter",
+    "TransformersTextRouter",
 ]

haystack/components/routers/transformers_text_router.py

@@ -0,0 +1,209 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
from transformers import AutoConfig, pipeline
from haystack.utils.hf import ( # pylint: disable=ungrouped-imports
deserialize_hf_model_kwargs,
resolve_hf_pipeline_kwargs,
serialize_hf_model_kwargs,
)
@component
class TransformersTextRouter:
"""
Routes a text input onto different output connections depending on which label it has been categorized into.
This is useful for routing queries to different models in a pipeline depending on their categorization.
The set of labels used for categorization is provided by the selected model.
Example usage in a query pipeline that routes English queries to a text generator optimized for English text and
German queries to a text generator optimized for German text:
```python
from haystack.core.pipeline import Pipeline
from haystack.components.routers import TransformersTextRouter
from haystack.components.builders import PromptBuilder
from haystack.components.generators import HuggingFaceLocalGenerator
p = Pipeline()
p.add_component(
instance=TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection"),
name="text_router"
)
p.add_component(
instance=PromptBuilder(template="Answer the question: {{query}}\nAnswer:"),
name="english_prompt_builder"
)
p.add_component(
instance=PromptBuilder(template="Beantworte die Frage: {{query}}\nAntwort:"),
name="german_prompt_builder"
)
p.add_component(
instance=HuggingFaceLocalGenerator(model="DiscoResearch/Llama3-DiscoLeo-Instruct-8B-v0.1"),
name="german_llm"
)
p.add_component(
instance=HuggingFaceLocalGenerator(model="microsoft/Phi-3-mini-4k-instruct"),
name="english_llm"
)
p.connect("text_router.en", "english_prompt_builder.query")
p.connect("text_router.de", "german_prompt_builder.query")
p.connect("english_prompt_builder.prompt", "english_llm.prompt")
p.connect("german_prompt_builder.prompt", "german_llm.prompt")
# English Example
print(p.run({"text_router": {"text": "What is the capital of Germany?"}}))
# German Example
print(p.run({"text_router": {"text": "Was ist die Hauptstadt von Deutschland?"}}))
```
"""
def __init__(
self,
model: str,
labels: Optional[List[str]] = None,
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Initializes the TransformersTextRouter.
:param model: The name or path of a Hugging Face model for text classification.
:param labels: The list of labels that the model has been trained to predict. If not provided, the labels
are fetched from the model configuration file hosted on the Hugging Face Hub using
`transformers.AutoConfig.from_pretrained`.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
:param token: The API token used to download private models from Hugging Face.
If `token` is set to `True`, the token generated when running
`transformers-cli login` (stored in ~/.huggingface) is used.
:param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
Hugging Face pipeline for text classification.
"""
torch_and_transformers_import.check()
self.token = token
huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
model=model,
task="text-classification",
supported_tasks=["text-classification"],
device=device,
token=token,
)
self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
if labels is None:
config = AutoConfig.from_pretrained(
huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
)
self.labels = list(config.label2id.keys())
else:
self.labels = labels
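# Expose one str output connection per label (e.g. "text_router.en"), so each
# label can be connected to a different downstream component.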
component.set_output_types(self, **{label: str for label in self.labels})
self.pipeline = None
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
if isinstance(self.huggingface_pipeline_kwargs["model"], str):
return {"model": self.huggingface_pipeline_kwargs["model"]}
return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
def warm_up(self):
"""
Initializes the component.
"""
if self.pipeline is None:
self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
# Verify labels from the model configuration file match provided labels
labels = set(self.pipeline.model.config.label2id.keys())
if set(self.labels) != labels:
raise ValueError(
f"The provided labels do not match the labels in the model configuration file. "
f"Provided labels: {self.labels}. Model labels: {labels}"
)
def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
serialization_dict = default_to_dict(
self,
labels=self.labels,
model=self.huggingface_pipeline_kwargs["model"],
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
token=self.token.to_dict() if self.token else None,
)
huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
huggingface_pipeline_kwargs.pop("token", None)
serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
return serialization_dict
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TransformersTextRouter":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
return default_from_dict(cls, data)
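# NOTE: the per-label str outputs registered via component.set_output_types() in
# __init__ are set on the instance and are expected to take precedence over this
# class-level output declaration.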
@component.output_types(documents=Dict[str, str])
def run(self, text: str):
"""
Run the TransformersTextRouter.
This method routes the text to one of the different edges based on which label it has been categorized into.
:param text: A str to route to one of the different edges.
:returns:
A dictionary with the label as key and the text as value.
:raises TypeError:
If the input is not a str.
:raises RuntimeError:
If the pipeline has not been loaded because warm_up() was not called before.
"""
if self.pipeline is None:
raise RuntimeError(
"The component TextTransformersRouter wasn't warmed up. Run 'warm_up()' before calling 'run()'."
)
if not isinstance(text, str):
raise TypeError("TransformersTextRouter expects a str as input.")
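# return_all_scores=False keeps only the top prediction per input;
# function_to_apply="none" skips the activation function and returns raw logits
# (the argmax label is the same either way).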
prediction = self.pipeline([text], return_all_scores=False, function_to_apply="none")
label = prediction[0]["label"]
return {label: text}
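
For reference, a minimal standalone sketch of the new component, using the same model as the docstring example above (the output shown is what the integration test below expects for an English input):

```python
from haystack.components.routers import TransformersTextRouter

router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
router.warm_up()  # loads the HF pipeline and verifies the labels

result = router.run(text="What is the capital of Germany?")
# The text is emitted on the output named after the predicted label:
# {"en": "What is the capital of Germany?"}
```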

haystack/components/routers/zero_shot_text_router.py

@@ -119,7 +119,7 @@ class TransformersZeroShotTextRouter:
     :param device: The device on which the model is loaded. If `None`, the default device is automatically
         selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
     :param token: The API token used to download private models from Hugging Face.
-        If this parameter is set to `True`, the token generated when running
+        If `token` is set to `True`, the token generated when running
         `transformers-cli login` (stored in ~/.huggingface) is used.
     :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
         Hugging Face pipeline for zero shot text classification.
@@ -205,11 +205,11 @@ class TransformersZeroShotTextRouter:
         :raises TypeError:
             If the input is not a str.
         :raises RuntimeError:
-            If the pipeline has not been loaded.
+            If the pipeline has not been loaded because warm_up() was not called before.
         """
         if self.pipeline is None:
             raise RuntimeError(
-                "The zero-shot classification pipeline has not been loaded. Please call warm_up() before running."
+                "The component TransformersZeroShotTextRouter wasn't warmed up. Run 'warm_up()' before calling 'run()'."
             )

         if not isinstance(text, str):

releasenotes/notes/…

@@ -0,0 +1,4 @@
---
features:
- |
Introduces the TransformersTextRouter! This component uses a transformers text classification pipeline to route text inputs onto different output connections based on the labels of the chosen model.
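
A short sketch of the label check added in `warm_up()` (mirroring the `test_wrong_labels` integration test below): if the optional `labels` argument doesn't match the labels in the model config, warming up raises a `ValueError`.

```python
from haystack.components.routers import TransformersTextRouter

# This model predicts 20 language codes, so "en"/"de" alone won't match.
router = TransformersTextRouter(
    model="papluca/xlm-roberta-base-language-detection", labels=["en", "de"]
)
router.warm_up()  # raises ValueError: provided labels don't match the model config
```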

test/components/routers/test_transformers_text_router.py

@@ -0,0 +1,180 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import patch, MagicMock
import pytest
from haystack.components.routers.transformers_text_router import TransformersTextRouter
from haystack.utils import ComponentDevice, Secret
class TestTransformersTextRouter:
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
def test_to_dict(self, mock_auto_config_from_pretrained):
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
router_dict = router.to_dict()
assert router_dict == {
"type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
"init_parameters": {
"labels": ["en", "de"],
"model": "papluca/xlm-roberta-base-language-detection",
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"huggingface_pipeline_kwargs": {
"model": "papluca/xlm-roberta-base-language-detection",
"device": ComponentDevice.resolve_device(None).to_hf(),
"task": "text-classification",
},
},
}
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
def test_to_dict_with_cpu_device(self, mock_auto_config_from_pretrained):
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
router = TransformersTextRouter(
model="papluca/xlm-roberta-base-language-detection", device=ComponentDevice.from_str("cpu")
)
router_dict = router.to_dict()
assert router_dict == {
"type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
"init_parameters": {
"labels": ["en", "de"],
"model": "papluca/xlm-roberta-base-language-detection",
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"huggingface_pipeline_kwargs": {
"model": "papluca/xlm-roberta-base-language-detection",
"device": ComponentDevice.from_str("cpu").to_hf(),
"task": "text-classification",
},
},
}
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
def test_from_dict(self, mock_auto_config_from_pretrained, monkeypatch):
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
monkeypatch.delenv("HF_API_TOKEN", raising=False)
data = {
"type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
"init_parameters": {
"model": "papluca/xlm-roberta-base-language-detection",
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"huggingface_pipeline_kwargs": {
"model": "papluca/xlm-roberta-base-language-detection",
"device": ComponentDevice.resolve_device(None).to_hf(),
"task": "zero-shot-classification",
},
},
}
component = TransformersTextRouter.from_dict(data)
assert component.labels == ["en", "de"]
assert component.pipeline is None
assert component.token == Secret.from_dict({"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"})
assert component.huggingface_pipeline_kwargs == {
"model": "papluca/xlm-roberta-base-language-detection",
"device": ComponentDevice.resolve_device(None).to_hf(),
"task": "text-classification",
"token": None,
}
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
def test_from_dict_with_cpu_device(self, mock_auto_config_from_pretrained, monkeypatch):
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
monkeypatch.delenv("HF_API_TOKEN", raising=False)
data = {
"type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
"init_parameters": {
"model": "papluca/xlm-roberta-base-language-detection",
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"huggingface_pipeline_kwargs": {
"model": "papluca/xlm-roberta-base-language-detection",
"device": ComponentDevice.from_str("cpu").to_hf(),
"task": "zero-shot-classification",
},
},
}
component = TransformersTextRouter.from_dict(data)
assert component.labels == ["en", "de"]
assert component.pipeline is None
assert component.token == Secret.from_dict({"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"})
assert component.huggingface_pipeline_kwargs == {
"model": "papluca/xlm-roberta-base-language-detection",
"device": ComponentDevice.from_str("cpu").to_hf(),
"task": "text-classification",
"token": None,
}
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
@patch("haystack.components.routers.transformers_text_router.pipeline")
def test_warm_up(self, hf_pipeline_mock, mock_auto_config_from_pretrained):
hf_pipeline_mock.return_value = MagicMock(model=MagicMock(config=MagicMock(label2id={"en": 0, "de": 1})))
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
router.warm_up()
assert router.pipeline is not None
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
def test_run_fails_without_warm_up(self, mock_auto_config_from_pretrained):
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
with pytest.raises(RuntimeError):
router.run(text="test")
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
@patch("haystack.components.routers.transformers_text_router.pipeline")
def test_run_fails_with_non_string_input(self, hf_pipeline_mock, mock_auto_config_from_pretrained):
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
hf_pipeline_mock.return_value = MagicMock(model=MagicMock(config=MagicMock(label2id={"en": 0, "de": 1})))
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
router.warm_up()
with pytest.raises(TypeError):
router.run(text=["wrong_input"])
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
@patch("haystack.components.routers.transformers_text_router.pipeline")
def test_run_unit(self, hf_pipeline_mock, mock_auto_config_from_pretrained):
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
hf_pipeline_mock.return_value = [{"label": "en", "score": 0.9}]
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
router.pipeline = hf_pipeline_mock
out = router.run("What is the color of the sky?")
assert router.pipeline is not None
assert out == {"en": "What is the color of the sky?"}
@pytest.mark.integration
def test_run(self):
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
router.warm_up()
out = router.run("What is the color of the sky?")
assert set(router.labels) == {
"ar",
"bg",
"de",
"el",
"en",
"es",
"fr",
"hi",
"it",
"ja",
"nl",
"pl",
"pt",
"ru",
"sw",
"th",
"tr",
"ur",
"vi",
"zh",
}
assert router.pipeline is not None
assert out == {"en": "What is the color of the sky?"}
@pytest.mark.integration
def test_wrong_labels(self):
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection", labels=["en", "de"])
with pytest.raises(ValueError):
router.warm_up()

test/components/routers/test_zero_shot_text_router.py

@@ -58,13 +58,13 @@ class TestTransformersZeroShotTextRouter:
         router.warm_up()
         assert router.pipeline is not None

-    def test_run_error(self):
+    def test_run_fails_without_warm_up(self):
         router = TransformersZeroShotTextRouter(labels=["query", "passage"])
         with pytest.raises(RuntimeError):
             router.run(text="test")

     @patch("haystack.components.routers.zero_shot_text_router.pipeline")
-    def test_run_not_str_error(self, hf_pipeline_mock):
+    def test_run_fails_with_non_string_input(self, hf_pipeline_mock):
         hf_pipeline_mock.return_value = " "
         router = TransformersZeroShotTextRouter(labels=["query", "passage"])
         router.warm_up()