mirror of https://github.com/deepset-ai/haystack.git (synced 2025-08-14 19:47:39 +00:00)

feat: Add TransformersTextRouter component (#7801)

* First pass at adding TransformersTextRouter
* Fix tests
* Add release notes
* Add optional labels param
* Add verification in the warm_up
* Fix tests
* Add labels to to_dict
* Feedback from review
* Add component to docs
* Added extra tests

This commit is contained in:
parent e6b8b7529b
commit d815c78198
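Before the diff itself, a minimal usage sketch of the new component; this is not part of the commit and assumes the `papluca/xlm-roberta-base-language-detection` model used throughout the tests below:

```python
from haystack.components.routers import TransformersTextRouter

# Route text onto the output connection named after the predicted label.
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
router.warm_up()  # loads the Hugging Face pipeline; run() raises RuntimeError without it

result = router.run(text="What is the capital of Germany?")
# The output is keyed by the predicted label, e.g. {"en": "What is the capital of Germany?"}
```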
@@ -7,6 +7,7 @@ loaders:
         "file_type_router",
         "metadata_router",
         "text_language_router",
+        "transformers_text_router",
         "zero_shot_text_router",
       ]
     ignore_when_discovered: ["__init__"]
@@ -6,6 +6,7 @@ from haystack.components.routers.conditional_router import ConditionalRouter
 from haystack.components.routers.file_type_router import FileTypeRouter
 from haystack.components.routers.metadata_router import MetadataRouter
 from haystack.components.routers.text_language_router import TextLanguageRouter
+from haystack.components.routers.transformers_text_router import TransformersTextRouter
 from haystack.components.routers.zero_shot_text_router import TransformersZeroShotTextRouter

 __all__ = [
@@ -14,4 +15,5 @@ __all__ = [
     "TextLanguageRouter",
     "ConditionalRouter",
     "TransformersZeroShotTextRouter",
+    "TransformersTextRouter",
 ]
haystack/components/routers/transformers_text_router.py (Normal file, 209 lines)
@@ -0,0 +1,209 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    from transformers import AutoConfig, pipeline

    from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
        deserialize_hf_model_kwargs,
        resolve_hf_pipeline_kwargs,
        serialize_hf_model_kwargs,
    )


@component
class TransformersTextRouter:
    """
    Routes a text input onto different output connections depending on which label it has been categorized into.

    This is useful for routing queries to different models in a pipeline depending on their categorization.
    The set of labels to be used for categorization is provided by the selected model.

    Example usage in a query pipeline that routes English queries to a text generator optimized for English text
    and German queries to a text generator optimized for German text:

    ```python
    from haystack.core.pipeline import Pipeline
    from haystack.components.routers import TransformersTextRouter
    from haystack.components.builders import PromptBuilder
    from haystack.components.generators import HuggingFaceLocalGenerator

    p = Pipeline()
    p.add_component(
        instance=TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection"),
        name="text_router"
    )
    p.add_component(
        instance=PromptBuilder(template="Answer the question: {{query}}\nAnswer:"),
        name="english_prompt_builder"
    )
    p.add_component(
        instance=PromptBuilder(template="Beantworte die Frage: {{query}}\nAntwort:"),
        name="german_prompt_builder"
    )

    p.add_component(
        instance=HuggingFaceLocalGenerator(model="DiscoResearch/Llama3-DiscoLeo-Instruct-8B-v0.1"),
        name="german_llm"
    )
    p.add_component(
        instance=HuggingFaceLocalGenerator(model="microsoft/Phi-3-mini-4k-instruct"),
        name="english_llm"
    )

    p.connect("text_router.en", "english_prompt_builder.query")
    p.connect("text_router.de", "german_prompt_builder.query")
    p.connect("english_prompt_builder.prompt", "english_llm.prompt")
    p.connect("german_prompt_builder.prompt", "german_llm.prompt")

    # English Example
    print(p.run({"text_router": {"text": "What is the capital of Germany?"}}))

    # German Example
    print(p.run({"text_router": {"text": "Was ist die Hauptstadt von Deutschland?"}}))
    ```
    """

    def __init__(
        self,
        model: str,
        labels: Optional[List[str]] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initializes the TransformersTextRouter.

        :param model: The name or path of a Hugging Face model for text classification.
        :param labels: The list of labels that the model has been trained to predict. If not provided, the labels
            are fetched from the model configuration file hosted on the Hugging Face Hub using
            `transformers.AutoConfig.from_pretrained`.
        :param device: The device on which the model is loaded. If `None`, the default device is automatically
            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this
            parameter.
        :param token: The API token used to download private models from Hugging Face.
            If `token` is set to `True`, the token generated when running
            `transformers-cli login` (stored in ~/.huggingface) is used.
        :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
            Hugging Face pipeline for text classification.
        """
        torch_and_transformers_import.check()

        self.token = token

        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task="text-classification",
            supported_tasks=["text-classification"],
            device=device,
            token=token,
        )
        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs

        if labels is None:
            config = AutoConfig.from_pretrained(
                huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
            )
            self.labels = list(config.label2id.keys())
        else:
            self.labels = labels
        component.set_output_types(self, **{label: str for label in self.labels})

        self.pipeline = None

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self):
        """
        Initializes the component.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

            # Verify that the labels from the model configuration file match the provided labels
            labels = set(self.pipeline.model.config.label2id.keys())
            if set(self.labels) != labels:
                raise ValueError(
                    f"The provided labels do not match the labels in the model configuration file. "
                    f"Provided labels: {self.labels}. Model labels: {labels}"
                )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            labels=self.labels,
            model=self.huggingface_pipeline_kwargs["model"],
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            token=self.token.to_dict() if self.token else None,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TransformersTextRouter":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
        return default_from_dict(cls, data)

    @component.output_types(documents=Dict[str, str])
    def run(self, text: str):
        """
        Run the TransformersTextRouter.

        This method routes the text to one of the different edges based on which label it has been categorized into.

        :param text: A str to route to one of the different edges.
        :returns:
            A dictionary with the label as key and the text as value.

        :raises TypeError:
            If the input is not a str.
        :raises RuntimeError:
            If the pipeline has not been loaded because warm_up() was not called before.
        """
        if self.pipeline is None:
            raise RuntimeError(
                "The component TransformersTextRouter wasn't warmed up. Run 'warm_up()' before calling 'run()'."
            )

        if not isinstance(text, str):
            raise TypeError("TransformersTextRouter expects a str as input.")

        prediction = self.pipeline([text], return_all_scores=False, function_to_apply="none")
        label = prediction[0]["label"]
        return {label: text}
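As a sanity check on the `to_dict`/`from_dict` logic above, a hedged round-trip sketch (not part of the commit; note that explicit `labels` are only validated against the model later, in `warm_up()`):

```python
from haystack.components.routers import TransformersTextRouter

# Sketch: serialize and restore the router. to_dict() drops the resolved token
# from huggingface_pipeline_kwargs; from_dict() re-resolves secrets and model kwargs.
router = TransformersTextRouter(
    model="papluca/xlm-roberta-base-language-detection",
    labels=["en", "de"],  # explicit labels skip the AutoConfig lookup in __init__
)
data = router.to_dict()
restored = TransformersTextRouter.from_dict(data)
assert restored.labels == ["en", "de"]
assert restored.pipeline is None  # warm_up() has not been called on the copy
```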
@@ -119,7 +119,7 @@ class TransformersZeroShotTextRouter:
         :param device: The device on which the model is loaded. If `None`, the default device is automatically
             selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
         :param token: The API token used to download private models from Hugging Face.
-            If this parameter is set to `True`, the token generated when running
+            If `token` is set to `True`, the token generated when running
             `transformers-cli login` (stored in ~/.huggingface) is used.
         :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
             Hugging Face pipeline for zero shot text classification.
@@ -205,11 +205,11 @@ class TransformersZeroShotTextRouter:
         :raises TypeError:
             If the input is not a str.
         :raises RuntimeError:
-            If the pipeline has not been loaded.
+            If the pipeline has not been loaded because warm_up() was not called before.
         """
         if self.pipeline is None:
             raise RuntimeError(
-                "The zero-shot classification pipeline has not been loaded. Please call warm_up() before running."
+                "The component TransformersZeroShotTextRouter wasn't warmed up. Run 'warm_up()' before calling 'run()'."
             )

         if not isinstance(text, str):
@@ -0,0 +1,4 @@
---
features:
  - |
    Introduces the TransformersTextRouter! This component uses a transformers text classification pipeline to route text inputs onto different output connections based on the labels of the chosen text classification model.
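The labels behind those output connections come straight from the model configuration; a hedged sketch of the lookup that `__init__` performs when no explicit `labels` are passed (this mirrors the code in the diff above and is not part of the commit):

```python
from transformers import AutoConfig

# One output connection is created per entry in the model's label2id mapping
# when `labels` is not supplied (see __init__ in transformers_text_router.py).
config = AutoConfig.from_pretrained("papluca/xlm-roberta-base-language-detection")
labels = list(config.label2id.keys())  # 20 language codes for this model
```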
test/components/routers/test_transformers_text_router.py (Normal file, 180 lines)
@@ -0,0 +1,180 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import patch, MagicMock

import pytest

from haystack.components.routers.transformers_text_router import TransformersTextRouter
from haystack.utils import ComponentDevice, Secret


class TestTransformersTextRouter:
    @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
    def test_to_dict(self, mock_auto_config_from_pretrained):
        mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
        router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
        router_dict = router.to_dict()
        assert router_dict == {
            "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
            "init_parameters": {
                "labels": ["en", "de"],
                "model": "papluca/xlm-roberta-base-language-detection",
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": "papluca/xlm-roberta-base-language-detection",
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "text-classification",
                },
            },
        }

    @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
    def test_to_dict_with_cpu_device(self, mock_auto_config_from_pretrained):
        mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
        router = TransformersTextRouter(
            model="papluca/xlm-roberta-base-language-detection", device=ComponentDevice.from_str("cpu")
        )
        router_dict = router.to_dict()
        assert router_dict == {
            "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
            "init_parameters": {
                "labels": ["en", "de"],
                "model": "papluca/xlm-roberta-base-language-detection",
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": "papluca/xlm-roberta-base-language-detection",
                    "device": ComponentDevice.from_str("cpu").to_hf(),
                    "task": "text-classification",
                },
            },
        }

    @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
    def test_from_dict(self, mock_auto_config_from_pretrained, monkeypatch):
        mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
        monkeypatch.delenv("HF_API_TOKEN", raising=False)
        data = {
            "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
            "init_parameters": {
                "model": "papluca/xlm-roberta-base-language-detection",
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": "papluca/xlm-roberta-base-language-detection",
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "text-classification",
                },
            },
        }

        component = TransformersTextRouter.from_dict(data)
        assert component.labels == ["en", "de"]
        assert component.pipeline is None
        assert component.token == Secret.from_dict({"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"})
        assert component.huggingface_pipeline_kwargs == {
            "model": "papluca/xlm-roberta-base-language-detection",
            "device": ComponentDevice.resolve_device(None).to_hf(),
            "task": "text-classification",
            "token": None,
        }

    @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
    def test_from_dict_with_cpu_device(self, mock_auto_config_from_pretrained, monkeypatch):
        mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
        monkeypatch.delenv("HF_API_TOKEN", raising=False)
        data = {
            "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
            "init_parameters": {
                "model": "papluca/xlm-roberta-base-language-detection",
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": "papluca/xlm-roberta-base-language-detection",
                    "device": ComponentDevice.from_str("cpu").to_hf(),
                    "task": "text-classification",
                },
            },
        }

        component = TransformersTextRouter.from_dict(data)
        assert component.labels == ["en", "de"]
        assert component.pipeline is None
        assert component.token == Secret.from_dict({"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"})
        assert component.huggingface_pipeline_kwargs == {
            "model": "papluca/xlm-roberta-base-language-detection",
            "device": ComponentDevice.from_str("cpu").to_hf(),
            "task": "text-classification",
            "token": None,
        }

    @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
    @patch("haystack.components.routers.transformers_text_router.pipeline")
    def test_warm_up(self, hf_pipeline_mock, mock_auto_config_from_pretrained):
        hf_pipeline_mock.return_value = MagicMock(model=MagicMock(config=MagicMock(label2id={"en": 0, "de": 1})))
        mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
        router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
        router.warm_up()
        assert router.pipeline is not None

    @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
    def test_run_fails_without_warm_up(self, mock_auto_config_from_pretrained):
        mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
        router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
        with pytest.raises(RuntimeError):
            router.run(text="test")

    @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
    @patch("haystack.components.routers.transformers_text_router.pipeline")
    def test_run_fails_with_non_string_input(self, hf_pipeline_mock, mock_auto_config_from_pretrained):
        mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
        hf_pipeline_mock.return_value = MagicMock(model=MagicMock(config=MagicMock(label2id={"en": 0, "de": 1})))
        router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
        router.warm_up()
        with pytest.raises(TypeError):
            router.run(text=["wrong_input"])

    @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
    @patch("haystack.components.routers.transformers_text_router.pipeline")
    def test_run_unit(self, hf_pipeline_mock, mock_auto_config_from_pretrained):
        mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
        hf_pipeline_mock.return_value = [{"label": "en", "score": 0.9}]
        router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
        router.pipeline = hf_pipeline_mock
        out = router.run("What is the color of the sky?")
        assert router.pipeline is not None
        assert out == {"en": "What is the color of the sky?"}

    @pytest.mark.integration
    def test_run(self):
        router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
        router.warm_up()
        out = router.run("What is the color of the sky?")
        assert set(router.labels) == {
            "ar",
            "bg",
            "de",
            "el",
            "en",
            "es",
            "fr",
            "hi",
            "it",
            "ja",
            "nl",
            "pl",
            "pt",
            "ru",
            "sw",
            "th",
            "tr",
            "ur",
            "vi",
            "zh",
        }
        assert router.pipeline is not None
        assert out == {"en": "What is the color of the sky?"}

    @pytest.mark.integration
    def test_wrong_labels(self):
        router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection", labels=["en", "de"])
        with pytest.raises(ValueError):
            router.warm_up()
@@ -58,13 +58,13 @@ class TestTransformersZeroShotTextRouter:
         router.warm_up()
         assert router.pipeline is not None

-    def test_run_error(self):
+    def test_run_fails_without_warm_up(self):
         router = TransformersZeroShotTextRouter(labels=["query", "passage"])
         with pytest.raises(RuntimeError):
             router.run(text="test")

     @patch("haystack.components.routers.zero_shot_text_router.pipeline")
-    def test_run_not_str_error(self, hf_pipeline_mock):
+    def test_run_fails_with_non_string_input(self, hf_pipeline_mock):
         hf_pipeline_mock.return_value = " "
         router = TransformersZeroShotTextRouter(labels=["query", "passage"])
         router.warm_up()