# SPDX-FileCopyrightText: 2022-present deepset GmbH
#
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch

import pytest

from haystack import Document, Pipeline
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils import ComponentDevice, Secret

# Model identifier used by every test in this module.
_MODEL = "cross-encoder/nli-deberta-v3-xsmall"

# Fully qualified component name as it appears in serialized form.
_COMPONENT_TYPE = "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier"  # noqa: E501

# Import path of the Hugging Face ``pipeline`` factory replaced by a mock in the unit tests.
_PIPELINE_PATCH_TARGET = "haystack.components.classifiers.zero_shot_document_classifier.pipeline"


def _make_classifier():
    """Build the classifier under test with a fresh ``["positive", "negative"]`` label list."""
    return TransformersZeroShotDocumentClassifier(model=_MODEL, labels=["positive", "negative"])


class TestTransformersZeroShotDocumentClassifier:
    """Tests for construction, (de)serialization and execution of the zero-shot document classifier."""

    def test_init(self):
        """A freshly constructed component exposes its defaults and has no pipeline loaded."""
        component = _make_classifier()
        assert component.labels == ["positive", "negative"]
        assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
        assert component.multi_label is False
        assert component.pipeline is None
        assert component.classification_field is None

    def test_to_dict(self):
        """``to_dict`` emits the type path and the init parameters checked below."""
        component = _make_classifier()
        component_dict = component.to_dict()
        assert component_dict == {
            "type": _COMPONENT_TYPE,
            "init_parameters": {
                "model": _MODEL,
                "labels": ["positive", "negative"],
                "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": _MODEL,
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "zero-shot-classification",
                },
            },
        }

    def test_from_dict(self, monkeypatch):
        """``from_dict`` rebuilds the component from a fully populated serialized form."""
        # With both variables unset, the expected pipeline kwargs carry ``"token": None``.
        monkeypatch.delenv("HF_API_TOKEN", raising=False)
        monkeypatch.delenv("HF_TOKEN", raising=False)
        data = {
            "type": _COMPONENT_TYPE,
            "init_parameters": {
                "model": _MODEL,
                "labels": ["positive", "negative"],
                "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": _MODEL,
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "zero-shot-classification",
                },
            },
        }

        component = TransformersZeroShotDocumentClassifier.from_dict(data)

        assert component.labels == ["positive", "negative"]
        assert component.pipeline is None
        assert component.token == Secret.from_dict(
            {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
        )
        assert component.huggingface_pipeline_kwargs == {
            "model": _MODEL,
            "device": ComponentDevice.resolve_device(None).to_hf(),
            "task": "zero-shot-classification",
            "token": None,
        }

    def test_from_dict_no_default_parameters(self, monkeypatch):
        """``from_dict`` fills in defaults when only ``model`` and ``labels`` are serialized."""
        monkeypatch.delenv("HF_API_TOKEN", raising=False)
        monkeypatch.delenv("HF_TOKEN", raising=False)
        data = {
            "type": _COMPONENT_TYPE,
            "init_parameters": {"model": _MODEL, "labels": ["positive", "negative"]},
        }

        component = TransformersZeroShotDocumentClassifier.from_dict(data)

        assert component.labels == ["positive", "negative"]
        assert component.pipeline is None
        assert component.token == Secret.from_dict(
            {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
        )
        assert component.huggingface_pipeline_kwargs == {
            "model": _MODEL,
            "device": ComponentDevice.resolve_device(None).to_hf(),
            "task": "zero-shot-classification",
            "token": None,
        }

    @patch(_PIPELINE_PATCH_TARGET)
    def test_warm_up(self, hf_pipeline_mock):
        """``warm_up`` populates ``pipeline`` (the Hugging Face factory is mocked)."""
        component = _make_classifier()
        component.warm_up()
        assert component.pipeline is not None

    def test_run_fails_without_warm_up(self):
        """``run`` raises ``RuntimeError`` when ``warm_up`` has not been called."""
        component = _make_classifier()
        positive_documents = [Document(content="That's good. I like it.")]
        with pytest.raises(RuntimeError):
            component.run(documents=positive_documents)

    @patch(_PIPELINE_PATCH_TARGET)
    def test_run_fails_with_non_document_input(self, hf_pipeline_mock):
        """``run`` raises ``TypeError`` when given plain strings instead of ``Document`` objects."""
        hf_pipeline_mock.return_value = " "
        component = _make_classifier()
        component.warm_up()
        text_list = ["That's good. I like it.", "That's bad. I don't like it."]
        with pytest.raises(TypeError):
            component.run(documents=text_list)

    @patch(_PIPELINE_PATCH_TARGET)
    def test_run_unit(self, hf_pipeline_mock):
        """Run against a mocked pipeline and check the label stored on each document."""
        hf_pipeline_mock.return_value = [
            {"sequence": "That's good. I like it.", "labels": ["positive", "negative"], "scores": [0.99, 0.01]},
            {"sequence": "That's bad. I don't like it.", "labels": ["negative", "positive"], "scores": [0.99, 0.01]},
        ]
        component = _make_classifier()
        # Assign the mock directly so no real model is loaded.
        component.pipeline = hf_pipeline_mock
        positive_document = Document(content="That's good. I like it.")
        negative_document = Document(content="That's bad. I don't like it.")

        result = component.run(documents=[positive_document, negative_document])

        assert component.pipeline is not None
        assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
        assert result["documents"][1].to_dict()["classification"]["label"] == "negative"

    @pytest.mark.integration
    @pytest.mark.slow
    def test_run(self, monkeypatch):
        """Run the real model end to end on one positive and one negative document."""
        # https://github.com/deepset-ai/haystack/issues/8811
        # The component's token secret is backed by both variables, so clear both;
        # removing only HF_API_TOKEN would leave the test dependent on an ambient HF_TOKEN.
        monkeypatch.delenv("HF_API_TOKEN", raising=False)
        monkeypatch.delenv("HF_TOKEN", raising=False)
        component = _make_classifier()
        component.warm_up()
        positive_document = Document(content="That's good. I like it. " * 1000)
        negative_document = Document(content="That's bad. I don't like it.")

        result = component.run(documents=[positive_document, negative_document])

        assert component.pipeline is not None
        assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
        assert result["documents"][1].to_dict()["classification"]["label"] == "negative"

    def test_serialization_and_deserialization_pipeline(self):
        """A pipeline containing the classifier survives a ``dumps``/``loads`` round trip."""
        pipeline = Pipeline()
        document_store = InMemoryDocumentStore()
        retriever = InMemoryBM25Retriever(document_store=document_store)
        document_classifier = _make_classifier()

        pipeline.add_component(instance=retriever, name="retriever")
        pipeline.add_component(instance=document_classifier, name="document_classifier")
        pipeline.connect("retriever", "document_classifier")

        pipeline_dump = pipeline.dumps()
        new_pipeline = Pipeline.loads(pipeline_dump)

        assert new_pipeline == pipeline