diff --git a/docs/pydoc/config/routers_api.yml b/docs/pydoc/config/routers_api.yml
index e9430e289..126d08d7e 100644
--- a/docs/pydoc/config/routers_api.yml
+++ b/docs/pydoc/config/routers_api.yml
@@ -7,6 +7,7 @@ loaders:
         "file_type_router",
         "metadata_router",
         "text_language_router",
+        "transformers_text_router",
         "zero_shot_text_router",
       ]
     ignore_when_discovered: ["__init__"]
diff --git a/haystack/components/routers/__init__.py b/haystack/components/routers/__init__.py
index a74118e0a..f22d69917 100644
--- a/haystack/components/routers/__init__.py
+++ b/haystack/components/routers/__init__.py
@@ -6,6 +6,7 @@ from haystack.components.routers.conditional_router import ConditionalRouter
 from haystack.components.routers.file_type_router import FileTypeRouter
 from haystack.components.routers.metadata_router import MetadataRouter
 from haystack.components.routers.text_language_router import TextLanguageRouter
+from haystack.components.routers.transformers_text_router import TransformersTextRouter
 from haystack.components.routers.zero_shot_text_router import TransformersZeroShotTextRouter
 
 __all__ = [
@@ -14,4 +15,5 @@ __all__ = [
     "TextLanguageRouter",
     "ConditionalRouter",
     "TransformersZeroShotTextRouter",
+    "TransformersTextRouter",
 ]
diff --git a/haystack/components/routers/transformers_text_router.py b/haystack/components/routers/transformers_text_router.py
new file mode 100644
index 000000000..066e9cc68
--- /dev/null
+++ b/haystack/components/routers/transformers_text_router.py
@@ -0,0 +1,209 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, List, Optional
+
+from haystack import component, default_from_dict, default_to_dict, logging
+from haystack.lazy_imports import LazyImport
+from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
+
+logger = logging.getLogger(__name__)
+
+
+with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
+    from transformers import AutoConfig, pipeline
+
+    from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
+        deserialize_hf_model_kwargs,
+        resolve_hf_pipeline_kwargs,
+        serialize_hf_model_kwargs,
+    )
+
+
+@component
+class TransformersTextRouter:
+    """
+    Routes a text input onto different output connections depending on which label it has been categorized into.
+
+    This is useful for routing queries to different models in a pipeline depending on their categorization.
+    The set of labels to be used for categorization is provided by the selected model.
+
+    Example usage in a query pipeline that routes English queries to a text generator optimized for English text
+    and German queries to a text generator optimized for German text:
+
+    ```python
+    from haystack.core.pipeline import Pipeline
+    from haystack.components.routers import TransformersTextRouter
+    from haystack.components.builders import PromptBuilder
+    from haystack.components.generators import HuggingFaceLocalGenerator
+
+    p = Pipeline()
+    p.add_component(
+        instance=TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection"),
+        name="text_router"
+    )
+    p.add_component(
+        instance=PromptBuilder(template="Answer the question: {{query}}\nAnswer:"),
+        name="english_prompt_builder"
+    )
+    p.add_component(
+        instance=PromptBuilder(template="Beantworte die Frage: {{query}}\nAntwort:"),
+        name="german_prompt_builder"
+    )
+
+    p.add_component(
+        instance=HuggingFaceLocalGenerator(model="DiscoResearch/Llama3-DiscoLeo-Instruct-8B-v0.1"),
+        name="german_llm"
+    )
+    p.add_component(
+        instance=HuggingFaceLocalGenerator(model="microsoft/Phi-3-mini-4k-instruct"),
+        name="english_llm"
+    )
+
+    p.connect("text_router.en", "english_prompt_builder.query")
+    p.connect("text_router.de", "german_prompt_builder.query")
+    p.connect("english_prompt_builder.prompt", "english_llm.prompt")
+    p.connect("german_prompt_builder.prompt", "german_llm.prompt")
+
+    # English Example
+    print(p.run({"text_router": {"text": "What is the capital of Germany?"}}))
+
+    # German Example
+    print(p.run({"text_router": {"text": "Was ist die Hauptstadt von Deutschland?"}}))
+    ```
+    """
+
+    def __init__(
+        self,
+        model: str,
+        labels: Optional[List[str]] = None,
+        device: Optional[ComponentDevice] = None,
+        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
+        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Initializes the TransformersTextRouter.
+
+        :param model: The name or path of a Hugging Face model for text classification.
+        :param labels: The list of labels that the model has been trained to predict. If not provided, the labels
+            are fetched from the model configuration file hosted on the Hugging Face Hub using
+            `transformers.AutoConfig.from_pretrained`.
+        :param device: The device on which the model is loaded. If `None`, the default device is automatically
+            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this
+            parameter.
+        :param token: The API token used to download private models from Hugging Face.
+            If `token` is set to `True`, the token generated when running
+            `transformers-cli login` (stored in ~/.huggingface) is used.
+        :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
+            Hugging Face pipeline for text classification.
+        """
+        torch_and_transformers_import.check()
+
+        self.token = token
+
+        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
+            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
+            model=model,
+            task="text-classification",
+            supported_tasks=["text-classification"],
+            device=device,
+            token=token,
+        )
+        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
+
+        if labels is None:
+            config = AutoConfig.from_pretrained(
+                huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
+            )
+            self.labels = list(config.label2id.keys())
+        else:
+            self.labels = labels
+        component.set_output_types(self, **{label: str for label in self.labels})
+
+        self.pipeline = None
+
+    def _get_telemetry_data(self) -> Dict[str, Any]:
+        """
+        Data that is sent to Posthog for usage analytics.
+ """ + if isinstance(self.huggingface_pipeline_kwargs["model"], str): + return {"model": self.huggingface_pipeline_kwargs["model"]} + return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"} + + def warm_up(self): + """ + Initializes the component. + """ + if self.pipeline is None: + self.pipeline = pipeline(**self.huggingface_pipeline_kwargs) + + # Verify labels from the model configuration file match provided labels + labels = set(self.pipeline.model.config.label2id.keys()) + if set(self.labels) != labels: + raise ValueError( + f"The provided labels do not match the labels in the model configuration file. " + f"Provided labels: {self.labels}. Model labels: {labels}" + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + serialization_dict = default_to_dict( + self, + labels=self.labels, + model=self.huggingface_pipeline_kwargs["model"], + huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs, + token=self.token.to_dict() if self.token else None, + ) + + huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"] + huggingface_pipeline_kwargs.pop("token", None) + + serialize_hf_model_kwargs(huggingface_pipeline_kwargs) + return serialization_dict + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TransformersTextRouter": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"]) + return default_from_dict(cls, data) + + @component.output_types(documents=Dict[str, str]) + def run(self, text: str): + """ + Run the TransformersTextRouter. + + This method routes the text to one of the different edges based on which label it has been categorized into. + + :param text: A str to route to one of the different edges. + :returns: + A dictionary with the label as key and the text as value. + + :raises TypeError: + If the input is not a str. + :raises RuntimeError: + If the pipeline has not been loaded because warm_up() was not called before. + """ + if self.pipeline is None: + raise RuntimeError( + "The component TextTransformersRouter wasn't warmed up. Run 'warm_up()' before calling 'run()'." + ) + + if not isinstance(text, str): + raise TypeError("TransformersTextRouter expects a str as input.") + + prediction = self.pipeline([text], return_all_scores=False, function_to_apply="none") + label = prediction[0]["label"] + return {label: text} diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 7c3a34666..b94a4c6b0 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -119,7 +119,7 @@ class TransformersZeroShotTextRouter: :param device: The device on which the model is loaded. If `None`, the default device is automatically selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter. :param token: The API token used to download private models from Hugging Face. - If this parameter is set to `True`, the token generated when running + If `token` is set to `True`, the token generated when running `transformers-cli login` (stored in ~/.huggingface) is used. 
diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py
index 7c3a34666..b94a4c6b0 100644
--- a/haystack/components/routers/zero_shot_text_router.py
+++ b/haystack/components/routers/zero_shot_text_router.py
@@ -119,7 +119,7 @@ class TransformersZeroShotTextRouter:
     :param device: The device on which the model is loaded. If `None`, the default device is automatically
         selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
     :param token: The API token used to download private models from Hugging Face.
-        If this parameter is set to `True`, the token generated when running
+        If `token` is set to `True`, the token generated when running
         `transformers-cli login` (stored in ~/.huggingface) is used.
     :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
         Hugging Face pipeline for zero shot text classification.
@@ -205,11 +205,11 @@ class TransformersZeroShotTextRouter:
         :raises TypeError:
             If the input is not a str.
         :raises RuntimeError:
-            If the pipeline has not been loaded.
+            If the pipeline has not been loaded because warm_up() was not called beforehand.
         """
         if self.pipeline is None:
             raise RuntimeError(
-                "The zero-shot classification pipeline has not been loaded. Please call warm_up() before running."
+                "The component TransformersZeroShotTextRouter wasn't warmed up. Run 'warm_up()' before calling 'run()'."
             )
 
         if not isinstance(text, str):
diff --git a/releasenotes/notes/add-transformers-text-router-5542b9d13a3c91c9.yaml b/releasenotes/notes/add-transformers-text-router-5542b9d13a3c91c9.yaml
new file mode 100644
index 000000000..bf0d7c8e8
--- /dev/null
+++ b/releasenotes/notes/add-transformers-text-router-5542b9d13a3c91c9.yaml
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Introduces the TransformersTextRouter! This component uses a transformers text classification pipeline to route text inputs onto different output connections based on the labels of the chosen text classification model.
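The tests below pin down the serialization format; as a quick illustration, this is the round trip they exercise, sketched under the assumption that the model config is reachable on the Hub (constructing the router without `labels` fetches it via `AutoConfig.from_pretrained`):

```python
from haystack.components.routers import TransformersTextRouter

router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")

# to_dict() nests everything under "init_parameters" and pops the resolved token
# out of huggingface_pipeline_kwargs, so the raw secret is never serialized.
data = router.to_dict()
assert "token" not in data["init_parameters"]["huggingface_pipeline_kwargs"]

# from_dict() restores the component; the pipeline itself stays unloaded
# until warm_up() is called.
restored = TransformersTextRouter.from_dict(data)
assert restored.pipeline is None
```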
ComponentDevice.from_str("cpu").to_hf(), + "task": "text-classification", + }, + }, + } + + @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained") + def test_from_dict(self, mock_auto_config_from_pretrained, monkeypatch): + mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) + monkeypatch.delenv("HF_API_TOKEN", raising=False) + data = { + "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter", + "init_parameters": { + "model": "papluca/xlm-roberta-base-language-detection", + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "huggingface_pipeline_kwargs": { + "model": "papluca/xlm-roberta-base-language-detection", + "device": ComponentDevice.resolve_device(None).to_hf(), + "task": "zero-shot-classification", + }, + }, + } + + component = TransformersTextRouter.from_dict(data) + assert component.labels == ["en", "de"] + assert component.pipeline is None + assert component.token == Secret.from_dict({"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}) + assert component.huggingface_pipeline_kwargs == { + "model": "papluca/xlm-roberta-base-language-detection", + "device": ComponentDevice.resolve_device(None).to_hf(), + "task": "text-classification", + "token": None, + } + + @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained") + def test_from_dict_with_cpu_device(self, mock_auto_config_from_pretrained, monkeypatch): + mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) + monkeypatch.delenv("HF_API_TOKEN", raising=False) + data = { + "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter", + "init_parameters": { + "model": "papluca/xlm-roberta-base-language-detection", + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "huggingface_pipeline_kwargs": { + "model": "papluca/xlm-roberta-base-language-detection", + "device": ComponentDevice.from_str("cpu").to_hf(), + "task": "zero-shot-classification", + }, + }, + } + + component = TransformersTextRouter.from_dict(data) + assert component.labels == ["en", "de"] + assert component.pipeline is None + assert component.token == Secret.from_dict({"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}) + assert component.huggingface_pipeline_kwargs == { + "model": "papluca/xlm-roberta-base-language-detection", + "device": ComponentDevice.from_str("cpu").to_hf(), + "task": "text-classification", + "token": None, + } + + @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained") + @patch("haystack.components.routers.transformers_text_router.pipeline") + def test_warm_up(self, hf_pipeline_mock, mock_auto_config_from_pretrained): + hf_pipeline_mock.return_value = MagicMock(model=MagicMock(config=MagicMock(label2id={"en": 0, "de": 1}))) + mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) + router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection") + router.warm_up() + assert router.pipeline is not None + + @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained") + def test_run_fails_without_warm_up(self, mock_auto_config_from_pretrained): + mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) + router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection") + with pytest.raises(RuntimeError): + 
router.run(text="test") + + @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained") + @patch("haystack.components.routers.transformers_text_router.pipeline") + def test_run_fails_with_non_string_input(self, hf_pipeline_mock, mock_auto_config_from_pretrained): + mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) + hf_pipeline_mock.return_value = MagicMock(model=MagicMock(config=MagicMock(label2id={"en": 0, "de": 1}))) + router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection") + router.warm_up() + with pytest.raises(TypeError): + router.run(text=["wrong_input"]) + + @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained") + @patch("haystack.components.routers.transformers_text_router.pipeline") + def test_run_unit(self, hf_pipeline_mock, mock_auto_config_from_pretrained): + mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) + hf_pipeline_mock.return_value = [{"label": "en", "score": 0.9}] + router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection") + router.pipeline = hf_pipeline_mock + out = router.run("What is the color of the sky?") + assert router.pipeline is not None + assert out == {"en": "What is the color of the sky?"} + + @pytest.mark.integration + def test_run(self): + router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection") + router.warm_up() + out = router.run("What is the color of the sky?") + assert set(router.labels) == { + "ar", + "bg", + "de", + "el", + "en", + "es", + "fr", + "hi", + "it", + "ja", + "nl", + "pl", + "pt", + "ru", + "sw", + "th", + "tr", + "ur", + "vi", + "zh", + } + assert router.pipeline is not None + assert out == {"en": "What is the color of the sky?"} + + @pytest.mark.integration + def test_wrong_labels(self): + router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection", labels=["en", "de"]) + with pytest.raises(ValueError): + router.warm_up() diff --git a/test/components/routers/test_zero_shot_text_router.py b/test/components/routers/test_zero_shot_text_router.py index 2ecc883cd..b1aabc394 100644 --- a/test/components/routers/test_zero_shot_text_router.py +++ b/test/components/routers/test_zero_shot_text_router.py @@ -58,13 +58,13 @@ class TestTransformersZeroShotTextRouter: router.warm_up() assert router.pipeline is not None - def test_run_error(self): + def test_run_fails_without_warm_up(self): router = TransformersZeroShotTextRouter(labels=["query", "passage"]) with pytest.raises(RuntimeError): router.run(text="test") @patch("haystack.components.routers.zero_shot_text_router.pipeline") - def test_run_not_str_error(self, hf_pipeline_mock): + def test_run_fails_with_non_string_input(self, hf_pipeline_mock): hf_pipeline_mock.return_value = " " router = TransformersZeroShotTextRouter(labels=["query", "passage"]) router.warm_up()