haystack/test/test_document_classifier.py
Julian Risch 24483d7bad
TransformersDocumentClassifier replacing FARMClassifier (#1540)
* Initial draft of TransformersClassifier

* Add transformers classifier implementation

* Add test for SentenceTransformersClassifier

* Add truncation and corresponding test case to Classifier

* Add zero-shot classification and test

* Add document classifier documentation

* Add latest docstring and tutorial changes

* print meta data with print_documents()

* Add latest docstring and tutorial changes

* Remove top_k param from Classifier usage example

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2021-10-01 11:22:56 +02:00

49 lines
1.5 KiB
Python

import pytest
from haystack import Document
from haystack.document_classifier.base import BaseDocumentClassifier
@pytest.mark.slow
def test_document_classifier(document_classifier):
    """Check the labels the ``document_classifier`` fixture assigns to two texts.

    The first document repeats its sentence 700 times so that its text is
    very long; this is meant to exercise truncation of over-long input.
    The second document is a single short sentence.
    """
    assert isinstance(document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            text="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(
            text="""That's bad. I don't like it.""",
            meta={"name": "1"},
            id="2",
        ),
    ]
    expected_labels = ["joy", "sadness"]

    results = document_classifier.predict(documents=docs)

    # Guard against a vacuous pass: with an empty or short result list the
    # comparison loop below would not check every expected label.
    assert len(results) == len(expected_labels)
    for doc, expected_label in zip(results, expected_labels):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_label
@pytest.mark.slow
def test_zero_shot_document_classifier(zero_shot_document_classifier):
    """Check the labels the ``zero_shot_document_classifier`` fixture assigns to two texts.

    The first document repeats its sentence 700 times so that its text is
    very long; this is meant to exercise truncation of over-long input.
    The second document is a single short sentence.
    """
    assert isinstance(zero_shot_document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            text="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(
            text="""That's bad. I don't like it.""",
            meta={"name": "1"},
            id="2",
        ),
    ]
    expected_labels = ["positive", "negative"]

    results = zero_shot_document_classifier.predict(documents=docs)

    # Guard against a vacuous pass: with an empty or short result list the
    # comparison loop below would not check every expected label.
    assert len(results) == len(expected_labels)
    for doc, expected_label in zip(results, expected_labels):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_label