haystack/test/components/routers/test_zero_shot_text_router.py
Sebastian Husch Lee 85c1e39fab
feat: Add Zero Shot Transformers Text Router (#7018)
* Starting to add TransformersTextRouter

* First pass at a TextRouter based on the zero-shot classification model from Hugging Face

* Fix pylint

* Remove unneeded imports

* Update documentation example

* Update error message strings

* Starting to add unit tests

* Release notes

* Fix pylint

* Add tests for `to_dict` and `from_dict`

* Update patches in tests to be correct with respect to changes

* Docstrings and fixes

* Adding more tests

* Change name

* Adding to init

* Use Haystack logger

* Beef up docstrings

* Make example runnable

* Rename to huggingface_pipeline_kwargs

* Fix example
2024-03-15 13:56:07 +01:00

88 lines
3.8 KiB
Python

from unittest.mock import patch
import pytest
from haystack.components.routers.zero_shot_text_router import TransformersZeroShotTextRouter
from haystack.utils import ComponentDevice, Secret
class TestTransformersZeroShotTextRouter:
    """Unit and integration tests for ``TransformersZeroShotTextRouter``."""

    @staticmethod
    def _expected_pipeline_kwargs():
        """Return the ``huggingface_pipeline_kwargs`` expected for a router built with only ``labels``."""
        return {
            "model": "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33",
            "device": ComponentDevice.resolve_device(None).to_hf(),
            "task": "zero-shot-classification",
        }

    @classmethod
    def _serialized_router(cls):
        """Return a fresh copy of the serialized form of ``TransformersZeroShotTextRouter(labels=["query", "passage"])``.

        A new dict is built on every call so that a test mutating it (e.g. in-place deserialization
        inside ``from_dict``) cannot leak state into another test.
        """
        return {
            "type": "haystack.components.routers.zero_shot_text_router.TransformersZeroShotTextRouter",
            "init_parameters": {
                "labels": ["query", "passage"],
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": cls._expected_pipeline_kwargs(),
            },
        }

    def test_to_dict(self):
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        assert router.to_dict() == self._serialized_router()

    def test_from_dict(self):
        component = TransformersZeroShotTextRouter.from_dict(self._serialized_router())
        assert component.labels == ["query", "passage"]
        # Deserializing must not load the model; that only happens in warm_up().
        assert component.pipeline is None
        assert component.token == Secret.from_dict({"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"})
        # The component adds the resolved token to the pipeline kwargs. Expecting None assumes
        # HF_API_TOKEN is not set in the test environment.
        assert component.huggingface_pipeline_kwargs == {**self._expected_pipeline_kwargs(), "token": None}

    @patch("haystack.components.routers.zero_shot_text_router.pipeline")
    def test_warm_up(self, hf_pipeline_mock):
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        router.warm_up()
        assert router.pipeline is not None
        # The pipeline must be built through the (patched) transformers factory, not by loading a real model.
        hf_pipeline_mock.assert_called_once()

    def test_run_error(self):
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        # run() without a prior warm_up() has no pipeline to call.
        with pytest.raises(RuntimeError):
            router.run(text="test")

    @patch("haystack.components.routers.zero_shot_text_router.pipeline")
    def test_run_not_str_error(self, hf_pipeline_mock):
        # Keep the factory's default (callable) MagicMock return value. A placeholder such as " " would itself
        # raise TypeError if invoked, letting this test pass even without the router's own input validation.
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        router.warm_up()
        with pytest.raises(TypeError):
            router.run(text=["wrong_input"])
        # The non-str input must be rejected before the model is ever called.
        router.pipeline.assert_not_called()

    @patch("haystack.components.routers.zero_shot_text_router.pipeline")
    def test_run_unit(self, hf_pipeline_mock):
        hf_pipeline_mock.return_value = [
            {"sequence": "What is the color of the sky?", "labels": ["query", "passage"], "scores": [0.9, 0.1]}
        ]
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        # Inject the mock directly as the loaded pipeline (skipping warm_up); calling it returns the canned prediction.
        router.pipeline = hf_pipeline_mock
        out = router.run("What is the color of the sky?")
        # The route must be derived from a pipeline call, not chosen without consulting the model.
        hf_pipeline_mock.assert_called_once()
        assert out == {"query": "What is the color of the sky?"}

    @pytest.mark.integration
    def test_run(self):
        # Loads the real model, so this is only run as an integration test.
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        router.warm_up()
        out = router.run("What is the color of the sky?")
        assert router.pipeline is not None
        assert out == {"query": "What is the color of the sky?"}