haystack/test/test_document_classifier.py
tstadel 14515a861b
Tutorial for DocumentClassifier at Index Time (#1697)
* basic example of document classifier in preprocessing logic

* add batch_size to TransformersDocumentClassifier

* complete tutorial16

* Add latest docstring and tutorial changes

* fix missing batch_size

* add notebook

* test for batch_size use added

* add tutorial 16 to headers.py

* Add latest docstring and tutorial changes

* make DocumentClassifier indexing pipeline ready

* Add latest docstring and tutorial changes

* flexibility improvements for DocumentClassifier in Pipelines

* Add latest docstring and tutorial changes

* fix index time usage

* remove query from documentclassifier tests

* improve classification_field resolving + minor fixes

* Add latest docstring and tutorial changes

* tutorial 16 extended with zero shot and pipelines

* Add latest docstring and tutorial changes

* install graphviz in notebook

* Add latest docstring and tutorial changes

* remove convert_to_dicts

* Add latest docstring and tutorial changes

* Fix typo

* Add latest docstring and tutorial changes

* remove retriever from indexing pipeline

* Add latest docstring and tutorial changes

* fix save_to_yaml when using FileTypeClassifier

* emphasize the impact with zero shot classification

* Add latest docstring and tutorial changes

* adjust use_gpu to boolean in test

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
2021-11-09 18:43:00 +01:00

114 lines
3.7 KiB
Python

import pytest
from haystack.schema import Document
from haystack.nodes.document_classifier.base import BaseDocumentClassifier
@pytest.mark.slow
def test_document_classifier(document_classifier):
    """Check the labels that ``predict`` attaches to a positive and a negative document.

    The first document repeats its sentence 700 times so the input is extra
    long and truncation is exercised. The label for each returned document is
    read from ``meta["classification"]["label"]`` and must be ``"joy"`` and
    ``"sadness"``, in that order.
    """
    assert isinstance(document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(
            content="""That's bad. I don't like it.""",
            meta={"name": "1"},
            id="2",
        ),
    ]
    results = document_classifier.predict(documents=docs)

    expected_labels = ["joy", "sadness"]
    # Compare whole lists: a per-item loop would pass vacuously if the
    # classifier returned fewer documents than expected (or none at all).
    predicted_labels = [
        doc.to_dict()["meta"]["classification"]["label"] for doc in results
    ]
    assert predicted_labels == expected_labels
@pytest.mark.slow
def test_zero_shot_document_classifier(zero_shot_document_classifier):
    """Check the labels the zero-shot classifier fixture attaches via ``predict``.

    The first document repeats its sentence 700 times so the input is extra
    long and truncation is exercised. The label for each returned document is
    read from ``meta["classification"]["label"]`` and must be ``"positive"``
    and ``"negative"``, in that order.
    """
    assert isinstance(zero_shot_document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(
            content="""That's bad. I don't like it.""",
            meta={"name": "1"},
            id="2",
        ),
    ]
    results = zero_shot_document_classifier.predict(documents=docs)

    expected_labels = ["positive", "negative"]
    # Compare whole lists: a per-item loop would pass vacuously if the
    # classifier returned fewer documents than expected (or none at all).
    predicted_labels = [
        doc.to_dict()["meta"]["classification"]["label"] for doc in results
    ]
    assert predicted_labels == expected_labels
@pytest.mark.slow
def test_document_classifier_batch_size(batched_document_classifier):
    """Check the labels produced by ``predict`` on the batched classifier fixture.

    The first document repeats its sentence 700 times so the input is extra
    long and truncation is exercised. The label for each returned document is
    read from ``meta["classification"]["label"]`` and must be ``"joy"`` and
    ``"sadness"``, in that order.

    NOTE(review): the batch size itself is set by the fixture, which is not
    visible here -- confirm in conftest.
    """
    assert isinstance(batched_document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(
            content="""That's bad. I don't like it.""",
            meta={"name": "1"},
            id="2",
        ),
    ]
    results = batched_document_classifier.predict(documents=docs)

    expected_labels = ["joy", "sadness"]
    # Compare whole lists: a per-item loop would pass vacuously if the
    # classifier returned fewer documents than expected (or none at all).
    predicted_labels = [
        doc.to_dict()["meta"]["classification"]["label"] for doc in results
    ]
    assert predicted_labels == expected_labels
@pytest.mark.slow
def test_document_classifier_as_index_node(indexing_document_classifier):
    """Check ``run`` with ``root_node="File"`` on plain-dict documents.

    Each input dict carries a ``class_field`` whose sentiment is the opposite
    of its ``content``. The expected labels (``"sadness"``, ``"joy"``) follow
    ``class_field``, so presumably the fixture is configured to classify that
    field rather than ``content`` -- verify against the fixture definition.
    """
    assert isinstance(indexing_document_classifier, BaseDocumentClassifier)

    docs = [
        {
            "content": """That's good. I like it.""" * 700,  # extra long text to check truncation
            "meta": {"name": "0"},
            "id": "1",
            "class_field": "That's bad.",
        },
        {
            "content": """That's bad. I like it.""",
            "meta": {"name": "1"},
            "id": "2",
            "class_field": "That's good.",
        },
    ]
    # The second element of the returned pair is not needed by this test.
    output, _ = indexing_document_classifier.run(documents=docs, root_node="File")

    expected_labels = ["sadness", "joy"]
    # Compare whole lists: a per-item loop would pass vacuously if the node
    # returned fewer documents than expected (or none at all).
    predicted_labels = [
        doc["meta"]["classification"]["label"] for doc in output["documents"]
    ]
    assert predicted_labels == expected_labels
@pytest.mark.slow
def test_document_classifier_as_query_node(document_classifier):
    """Check ``run`` with ``root_node="Query"`` on ``Document`` objects.

    The first document repeats its sentence 700 times so the input is extra
    long and truncation is exercised. The label for each document in
    ``output["documents"]`` is read from ``meta["classification"]["label"]``
    and must be ``"joy"`` and ``"sadness"``, in that order.
    """
    assert isinstance(document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(
            content="""That's bad. I don't like it.""",
            meta={"name": "1"},
            id="2",
        ),
    ]
    # The second element of the returned pair is not needed by this test.
    output, _ = document_classifier.run(documents=docs, root_node="Query")

    expected_labels = ["joy", "sadness"]
    # Compare whole lists: a per-item loop would pass vacuously if the node
    # returned fewer documents than expected (or none at all).
    predicted_labels = [
        doc.to_dict()["meta"]["classification"]["label"]
        for doc in output["documents"]
    ]
    assert predicted_labels == expected_labels