mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-23 17:00:41 +00:00

* Starting to add TransformersTextRouter * First pass at a TextRouter based off of the zero shot classification model on HuggingFace * Fix pylint * Remove unneeded imports * Update documentation example * Update error message strings * Starting to add unit tests * Release notes * Fix pylint * Add tests for to dict and from dict * Update patches in tests to be correct with respect to changes * Doc strings and fixes * Adding more tests * Change name * Adding to init * Use Haystack logger * Beef up docstrings * Make example runnable * Rename to huggingface_pipeline_kwargs * Fix example
88 lines
3.8 KiB
Python
88 lines
3.8 KiB
Python
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from haystack.components.routers.zero_shot_text_router import TransformersZeroShotTextRouter
|
|
from haystack.utils import ComponentDevice, Secret
|
|
|
|
|
|
class TestTransformersZeroShotTextRouter:
    """Tests for ``TransformersZeroShotTextRouter``: serialization, warm-up and routing."""

    def test_to_dict(self):
        """``to_dict`` emits the type path, labels, env-var token and pipeline kwargs.

        The expected ``huggingface_pipeline_kwargs`` carries no ``token`` entry.
        """
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        router_dict = router.to_dict()
        assert router_dict == {
            "type": "haystack.components.routers.zero_shot_text_router.TransformersZeroShotTextRouter",
            "init_parameters": {
                "labels": ["query", "passage"],
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33",
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "zero-shot-classification",
                },
            },
        }

    def test_from_dict(self, monkeypatch):
        """``from_dict`` rebuilds a router whose labels, token and pipeline kwargs match the input.

        :param monkeypatch: pytest's built-in fixture, used to isolate the test
            from the caller's environment.
        """
        # The `"token": None` expectation below only holds when HF_API_TOKEN is
        # unset (presumably the component resolves the env-var Secret into the
        # pipeline kwargs at construction time -- confirm against the component).
        # Clear it so the test does not fail on machines that export the variable.
        monkeypatch.delenv("HF_API_TOKEN", raising=False)

        data = {
            "type": "haystack.components.routers.zero_shot_text_router.TransformersZeroShotTextRouter",
            "init_parameters": {
                "labels": ["query", "passage"],
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33",
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "zero-shot-classification",
                },
            },
        }

        component = TransformersZeroShotTextRouter.from_dict(data)
        assert component.labels == ["query", "passage"]
        # No pipeline is attached until warm_up() is called.
        assert component.pipeline is None
        assert component.token == Secret.from_dict({"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"})
        assert component.huggingface_pipeline_kwargs == {
            "model": "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33",
            "device": ComponentDevice.resolve_device(None).to_hf(),
            "task": "zero-shot-classification",
            "token": None,
        }

    @patch("haystack.components.routers.zero_shot_text_router.pipeline")
    def test_warm_up(self, hf_pipeline_mock):
        """``warm_up`` leaves ``router.pipeline`` set, with the module's ``pipeline`` name mocked."""
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        router.warm_up()
        assert router.pipeline is not None

    def test_run_error(self):
        """``run`` raises ``RuntimeError`` when called on a router that was never warmed up."""
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        with pytest.raises(RuntimeError):
            router.run(text="test")

    @patch("haystack.components.routers.zero_shot_text_router.pipeline")
    def test_run_not_str_error(self, hf_pipeline_mock):
        """``run`` raises ``TypeError`` when ``text`` is a list instead of a string."""
        hf_pipeline_mock.return_value = " "
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        router.warm_up()
        with pytest.raises(TypeError):
            router.run(text=["wrong_input"])

    @patch("haystack.components.routers.zero_shot_text_router.pipeline")
    def test_run_unit(self, hf_pipeline_mock):
        """``run`` returns the text under the ``query`` key when the mock scores ``query`` highest."""
        hf_pipeline_mock.return_value = [
            {"sequence": "What is the color of the sky?", "labels": ["query", "passage"], "scores": [0.9, 0.1]}
        ]
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        # Install the mock directly as the pipeline instead of calling warm_up().
        router.pipeline = hf_pipeline_mock
        out = router.run("What is the color of the sky?")
        assert router.pipeline is not None
        assert out == {"query": "What is the color of the sky?"}

    @pytest.mark.integration
    def test_run(self):
        """Integration test: warm up without mocks and expect the question under ``query``."""
        router = TransformersZeroShotTextRouter(labels=["query", "passage"])
        router.warm_up()
        out = router.run("What is the color of the sky?")
        assert router.pipeline is not None
        assert out == {"query": "What is the color of the sky?"}
|