From 93491c64be526d13c8d9443ff6c63f52bace5d99 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 18 Dec 2025 11:52:56 +0100 Subject: [PATCH] test: refactor test_run_fails_without_warm_up tests (#10267) --- .../test_zero_shot_document_classifier.py | 12 +++++++++--- .../routers/test_transformers_text_router.py | 10 +++++++--- .../components/routers/test_zero_shot_text_router.py | 10 +++++++--- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/test/components/classifiers/test_zero_shot_document_classifier.py b/test/components/classifiers/test_zero_shot_document_classifier.py index 677ca5169..4167d7dc7 100644 --- a/test/components/classifiers/test_zero_shot_document_classifier.py +++ b/test/components/classifiers/test_zero_shot_document_classifier.py @@ -100,13 +100,19 @@ class TestTransformersZeroShotDocumentClassifier: component.warm_up() assert component.pipeline is not None - def test_run_fails_without_warm_up(self): + @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline") + @patch.object(TransformersZeroShotDocumentClassifier, "warm_up") + def test_run_calls_warm_up(self, warm_up_mock, hf_pipeline_mock): + hf_pipeline_mock.return_value = [ + {"sequence": "That's good. I like it.", "labels": ["positive", "negative"], "scores": [0.99, 0.01]} + ] component = TransformersZeroShotDocumentClassifier( model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"] ) + warm_up_mock.side_effect = lambda: setattr(component, "pipeline", hf_pipeline_mock) positive_documents = [Document(content="That's good. I like it.")] - with pytest.raises(RuntimeError): - component.run(documents=positive_documents) + component.run(documents=positive_documents) + warm_up_mock.assert_called_once() @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline") def test_run_fails_with_non_document_input(self, hf_pipeline_mock): diff --git a/test/components/routers/test_transformers_text_router.py b/test/components/routers/test_transformers_text_router.py index 46a7c6cf4..2b53af098 100644 --- a/test/components/routers/test_transformers_text_router.py +++ b/test/components/routers/test_transformers_text_router.py @@ -145,11 +145,15 @@ class TestTransformersTextRouter: assert router.pipeline is not None @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained") - def test_run_fails_without_warm_up(self, mock_auto_config_from_pretrained): + @patch("haystack.components.routers.transformers_text_router.pipeline") + @patch.object(TransformersTextRouter, "warm_up") + def test_run_calls_warm_up(self, warm_up_mock, hf_pipeline_mock, mock_auto_config_from_pretrained): mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) + hf_pipeline_mock.return_value = [{"label": "en", "score": 0.9}] router = TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection") - with pytest.raises(RuntimeError): - router.run(text="test") + warm_up_mock.side_effect = lambda: setattr(router, "pipeline", hf_pipeline_mock) + router.run(text="test") + warm_up_mock.assert_called_once() @patch("haystack.components.routers.transformers_text_router.AutoConfig.from_pretrained") @patch("haystack.components.routers.transformers_text_router.pipeline") diff --git a/test/components/routers/test_zero_shot_text_router.py b/test/components/routers/test_zero_shot_text_router.py index 33f8e6ec2..c56b4c42e 100644 --- a/test/components/routers/test_zero_shot_text_router.py +++ b/test/components/routers/test_zero_shot_text_router.py @@ -82,10 +82,14 @@ class TestTransformersZeroShotTextRouter: router.warm_up() assert router.pipeline is not None - def test_run_fails_without_warm_up(self): + @patch("haystack.components.routers.zero_shot_text_router.pipeline") + @patch.object(TransformersZeroShotTextRouter, "warm_up") + def test_run_calls_warm_up(self, warm_up_mock, hf_pipeline_mock): + hf_pipeline_mock.return_value = [{"sequence": "test", "labels": ["query", "passage"], "scores": [0.9, 0.1]}] router = TransformersZeroShotTextRouter(labels=["query", "passage"]) - with pytest.raises(RuntimeError): - router.run(text="test") + warm_up_mock.side_effect = lambda: setattr(router, "pipeline", hf_pipeline_mock) + router.run(text="test") + warm_up_mock.assert_called_once() @patch("haystack.components.routers.zero_shot_text_router.pipeline") def test_run_fails_with_non_string_input(self, hf_pipeline_mock):