# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch

import pytest

from haystack import Document, Pipeline
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils import ComponentDevice, Secret


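# Tests for TransformersZeroShotDocumentClassifier, which assigns one of a fixed
# set of candidate labels to each Document using a zero-shot NLI cross-encoder.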
class TestTransformersZeroShotDocumentClassifier:
    def test_init(self):
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        assert component.labels == ["positive", "negative"]
        assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
        assert component.multi_label is False
        assert component.pipeline is None
        assert component.classification_field is None

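    # to_dict() should record every init parameter, including the resolved device
    # inside huggingface_pipeline_kwargs, so the component can be reconstructed.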
    def test_to_dict(self):
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        component_dict = component.to_dict()
        assert component_dict == {
            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
            "init_parameters": {
                "model": "cross-encoder/nli-deberta-v3-xsmall",
                "labels": ["positive", "negative"],
                "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": "cross-encoder/nli-deberta-v3-xsmall",
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "zero-shot-classification",
                },
            },
        }

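    # from_dict() should round-trip the dictionary produced by to_dict(); the HF
    # env vars are cleared first so the Secret resolves against a clean environment.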
    def test_from_dict(self, monkeypatch):
        monkeypatch.delenv("HF_API_TOKEN", raising=False)
        monkeypatch.delenv("HF_TOKEN", raising=False)
        data = {
            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
            "init_parameters": {
                "model": "cross-encoder/nli-deberta-v3-xsmall",
                "labels": ["positive", "negative"],
                "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": "cross-encoder/nli-deberta-v3-xsmall",
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "zero-shot-classification",
                },
            },
        }
        component = TransformersZeroShotDocumentClassifier.from_dict(data)
        assert component.labels == ["positive", "negative"]
        assert component.pipeline is None
        assert component.token == Secret.from_dict(
            {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
        )
        assert component.huggingface_pipeline_kwargs == {
            "model": "cross-encoder/nli-deberta-v3-xsmall",
            "device": ComponentDevice.resolve_device(None).to_hf(),
            "task": "zero-shot-classification",
            "token": None,
        }

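    # When only the required init parameters are present, deserialization should
    # fall back to the defaults for token and huggingface_pipeline_kwargs.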
    def test_from_dict_no_default_parameters(self, monkeypatch):
        monkeypatch.delenv("HF_API_TOKEN", raising=False)
        monkeypatch.delenv("HF_TOKEN", raising=False)
        data = {
            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
            "init_parameters": {"model": "cross-encoder/nli-deberta-v3-xsmall", "labels": ["positive", "negative"]},
        }
        component = TransformersZeroShotDocumentClassifier.from_dict(data)
        assert component.labels == ["positive", "negative"]
        assert component.pipeline is None
        assert component.token == Secret.from_dict(
            {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
        )
        assert component.huggingface_pipeline_kwargs == {
            "model": "cross-encoder/nli-deberta-v3-xsmall",
            "device": ComponentDevice.resolve_device(None).to_hf(),
            "task": "zero-shot-classification",
            "token": None,
        }

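    # The transformers pipeline factory is patched, so warm_up() can be exercised
    # without downloading a model.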
    @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
    def test_warm_up(self, hf_pipeline_mock):
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        component.warm_up()
        assert component.pipeline is not None

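    # Calling run() before warm_up() should raise a RuntimeError.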
    def test_run_fails_without_warm_up(self):
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        positive_documents = [Document(content="That's good. I like it.")]
        with pytest.raises(RuntimeError):
            component.run(documents=positive_documents)

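    # Plain strings are not Documents; run() should reject them with a TypeError.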
    @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
    def test_run_fails_with_non_document_input(self, hf_pipeline_mock):
        hf_pipeline_mock.return_value = " "
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        component.warm_up()
        text_list = ["That's good. I like it.", "That's bad. I don't like it."]
        with pytest.raises(TypeError):
            component.run(documents=text_list)

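    # With canned zero-shot predictions from the mocked pipeline, run() should
    # attach the top-scoring label to each document under "classification".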
    @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
    def test_run_unit(self, hf_pipeline_mock):
        hf_pipeline_mock.return_value = [
            {"sequence": "That's good. I like it.", "labels": ["positive", "negative"], "scores": [0.99, 0.01]},
            {"sequence": "That's bad. I don't like it.", "labels": ["negative", "positive"], "scores": [0.99, 0.01]},
        ]
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        component.pipeline = hf_pipeline_mock
        positive_document = Document(content="That's good. I like it.")
        negative_document = Document(content="That's bad. I don't like it.")
        result = component.run(documents=[positive_document, negative_document])
        assert component.pipeline is not None
        assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
        assert result["documents"][1].to_dict()["classification"]["label"] == "negative"

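    # End-to-end check against the real model, marked slow/integration so CI can
    # skip it; the very long positive document presumably also exercises how the
    # classifier handles inputs beyond the model's maximum sequence length.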
    @pytest.mark.integration
    @pytest.mark.slow
    def test_run(self, monkeypatch):
        monkeypatch.delenv("HF_API_TOKEN", raising=False)  # https://github.com/deepset-ai/haystack/issues/8811
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        component.warm_up()
        positive_document = Document(content="That's good. I like it. " * 1000)
        negative_document = Document(content="That's bad. I don't like it.")
        result = component.run(documents=[positive_document, negative_document])
        assert component.pipeline is not None
        assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
        assert result["documents"][1].to_dict()["classification"]["label"] == "negative"

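    # A Pipeline containing the classifier should survive a dumps()/loads()
    # round trip and compare equal to the original.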
    def test_serialization_and_deserialization_pipeline(self):
        pipeline = Pipeline()
        document_store = InMemoryDocumentStore()
        retriever = InMemoryBM25Retriever(document_store=document_store)
        document_classifier = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )

        pipeline.add_component(instance=retriever, name="retriever")
        pipeline.add_component(instance=document_classifier, name="document_classifier")
        pipeline.connect("retriever", "document_classifier")
        pipeline_dump = pipeline.dumps()

        new_pipeline = Pipeline.loads(pipeline_dump)

        assert new_pipeline == pipeline