mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-23 00:42:28 +00:00

* trigger * try pinning sentence transformers * make integr tests run right away * pin transformers instead * older transformers version * rm transformers pin * try ignoring cache * change ubuntu version * try removing token * try again * more HF_API_TOKEN local deletions * restore test priority * rm leftover * more deletions * moreee * more * deletions * restore jobs order
211 lines
10 KiB
Python
211 lines
10 KiB
Python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
import pytest
|
|
|
|
from haystack.components.routers.transformers_text_router import TransformersTextRouter
|
|
from haystack.utils import ComponentDevice, Secret
|
|
|
|
|
|
class TestTransformersTextRouter:
|
|
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
|
|
def test_to_dict(self, mock_auto_config_from_pretrained):
|
|
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
|
|
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
|
|
router_dict = router.to_dict()
|
|
assert router_dict == {
|
|
"type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
|
|
"init_parameters": {
|
|
"labels": ["en", "de"],
|
|
"model": "papluca/xlm-roberta-base-language-detection",
|
|
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
|
|
"huggingface_pipeline_kwargs": {
|
|
"model": "papluca/xlm-roberta-base-language-detection",
|
|
"device": ComponentDevice.resolve_device(None).to_hf(),
|
|
"task": "text-classification",
|
|
},
|
|
},
|
|
}
|
|
|
|
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
|
|
def test_to_dict_with_cpu_device(self, mock_auto_config_from_pretrained):
|
|
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
|
|
router = TransformersTextRouter(
|
|
model="papluca/xlm-roberta-base-language-detection", device=ComponentDevice.from_str("cpu")
|
|
)
|
|
router_dict = router.to_dict()
|
|
assert router_dict == {
|
|
"type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
|
|
"init_parameters": {
|
|
"labels": ["en", "de"],
|
|
"model": "papluca/xlm-roberta-base-language-detection",
|
|
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
|
|
"huggingface_pipeline_kwargs": {
|
|
"model": "papluca/xlm-roberta-base-language-detection",
|
|
"device": ComponentDevice.from_str("cpu").to_hf(),
|
|
"task": "text-classification",
|
|
},
|
|
},
|
|
}
|
|
|
|
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
|
|
def test_from_dict(self, mock_auto_config_from_pretrained, monkeypatch):
|
|
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
|
|
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
|
monkeypatch.delenv("HF_TOKEN", raising=False)
|
|
data = {
|
|
"type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
|
|
"init_parameters": {
|
|
"model": "papluca/xlm-roberta-base-language-detection",
|
|
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
|
|
"huggingface_pipeline_kwargs": {
|
|
"model": "papluca/xlm-roberta-base-language-detection",
|
|
"device": ComponentDevice.resolve_device(None).to_hf(),
|
|
"task": "zero-shot-classification",
|
|
},
|
|
},
|
|
}
|
|
|
|
component = TransformersTextRouter.from_dict(data)
|
|
assert component.labels == ["en", "de"]
|
|
assert component.pipeline is None
|
|
assert component.token == Secret.from_dict(
|
|
{"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
|
|
)
|
|
assert component.huggingface_pipeline_kwargs == {
|
|
"model": "papluca/xlm-roberta-base-language-detection",
|
|
"device": ComponentDevice.resolve_device(None).to_hf(),
|
|
"task": "text-classification",
|
|
"token": None,
|
|
}
|
|
|
|
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
|
|
def test_from_dict_no_default_parameters(self, mock_auto_config_from_pretrained, monkeypatch):
|
|
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
|
|
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
|
monkeypatch.delenv("HF_TOKEN", raising=False)
|
|
data = {
|
|
"type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
|
|
"init_parameters": {"model": "papluca/xlm-roberta-base-language-detection"},
|
|
}
|
|
component = TransformersTextRouter.from_dict(data)
|
|
assert component.labels == ["en", "de"]
|
|
assert component.pipeline is None
|
|
assert component.token == Secret.from_dict(
|
|
{"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
|
|
)
|
|
assert component.huggingface_pipeline_kwargs == {
|
|
"model": "papluca/xlm-roberta-base-language-detection",
|
|
"device": ComponentDevice.resolve_device(None).to_hf(),
|
|
"task": "text-classification",
|
|
"token": None,
|
|
}
|
|
|
|
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
|
|
def test_from_dict_with_cpu_device(self, mock_auto_config_from_pretrained, monkeypatch):
|
|
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
|
|
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
|
monkeypatch.delenv("HF_TOKEN", raising=False)
|
|
data = {
|
|
"type": "haystack.components.routers.transformers_text_router.TransformersTextRouter",
|
|
"init_parameters": {
|
|
"model": "papluca/xlm-roberta-base-language-detection",
|
|
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
|
|
"huggingface_pipeline_kwargs": {
|
|
"model": "papluca/xlm-roberta-base-language-detection",
|
|
"device": ComponentDevice.from_str("cpu").to_hf(),
|
|
"task": "zero-shot-classification",
|
|
},
|
|
},
|
|
}
|
|
|
|
component = TransformersTextRouter.from_dict(data)
|
|
assert component.labels == ["en", "de"]
|
|
assert component.pipeline is None
|
|
assert component.token == Secret.from_dict(
|
|
{"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
|
|
)
|
|
assert component.huggingface_pipeline_kwargs == {
|
|
"model": "papluca/xlm-roberta-base-language-detection",
|
|
"device": ComponentDevice.from_str("cpu").to_hf(),
|
|
"task": "text-classification",
|
|
"token": None,
|
|
}
|
|
|
|
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
|
|
@patch("haystack.components.routers.transformers_text_router.pipeline")
|
|
def test_warm_up(self, hf_pipeline_mock, mock_auto_config_from_pretrained):
|
|
hf_pipeline_mock.return_value = MagicMock(model=MagicMock(config=MagicMock(label2id={"en": 0, "de": 1})))
|
|
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
|
|
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
|
|
router.warm_up()
|
|
assert router.pipeline is not None
|
|
|
|
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
|
|
def test_run_fails_without_warm_up(self, mock_auto_config_from_pretrained):
|
|
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
|
|
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
|
|
with pytest.raises(RuntimeError):
|
|
router.run(text="test")
|
|
|
|
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
|
|
@patch("haystack.components.routers.transformers_text_router.pipeline")
|
|
def test_run_fails_with_non_string_input(self, hf_pipeline_mock, mock_auto_config_from_pretrained):
|
|
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
|
|
hf_pipeline_mock.return_value = MagicMock(model=MagicMock(config=MagicMock(label2id={"en": 0, "de": 1})))
|
|
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
|
|
router.warm_up()
|
|
with pytest.raises(TypeError):
|
|
router.run(text=["wrong_input"])
|
|
|
|
@patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained")
|
|
@patch("haystack.components.routers.transformers_text_router.pipeline")
|
|
def test_run_unit(self, hf_pipeline_mock, mock_auto_config_from_pretrained):
|
|
mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1})
|
|
hf_pipeline_mock.return_value = [{"label": "en", "score": 0.9}]
|
|
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
|
|
router.pipeline = hf_pipeline_mock
|
|
out = router.run("What is the color of the sky?")
|
|
assert router.pipeline is not None
|
|
assert out == {"en": "What is the color of the sky?"}
|
|
|
|
@pytest.mark.integration
|
|
def test_run(self, monkeypatch):
|
|
monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811
|
|
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection")
|
|
router.warm_up()
|
|
out = router.run("What is the color of the sky?")
|
|
assert set(router.labels) == {
|
|
"ar",
|
|
"bg",
|
|
"de",
|
|
"el",
|
|
"en",
|
|
"es",
|
|
"fr",
|
|
"hi",
|
|
"it",
|
|
"ja",
|
|
"nl",
|
|
"pl",
|
|
"pt",
|
|
"ru",
|
|
"sw",
|
|
"th",
|
|
"tr",
|
|
"ur",
|
|
"vi",
|
|
"zh",
|
|
}
|
|
assert router.pipeline is not None
|
|
assert out == {"en": "What is the color of the sky?"}
|
|
|
|
@pytest.mark.integration
|
|
def test_wrong_labels(self, monkeypatch):
|
|
monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811
|
|
router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection", labels=["en", "de"])
|
|
with pytest.raises(ValueError):
|
|
router.warm_up()
|