# haystack/test/nodes/test_document_classifier.py
import pytest
# NOTE: imports follow the new-style package layout introduced by the
# `haystack` package refactoring (#1624).
from haystack.schema import Document
from haystack.nodes.document_classifier.base import BaseDocumentClassifier
@pytest.mark.slow
def test_document_classifier(document_classifier):
    """predict() labels each document with the expected emotion."""
    assert isinstance(document_classifier, BaseDocumentClassifier)
    documents = [
        Document(
            # Repeat the text to produce an extra-long input and exercise truncation.
            content="That's good. I like it." * 700,
            meta={"name": "0"},
            id="1",
        ),
        Document(content="That's bad. I don't like it.", meta={"name": "1"}, id="2"),
    ]
    classified = document_classifier.predict(documents=documents)
    for doc, expected in zip(classified, ["joy", "sadness"]):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected
@pytest.mark.slow
def test_document_classifier_batch_single_doc_list(document_classifier):
    """predict_batch() on a flat document list returns one labeled doc per input."""
    documents = [
        Document(content="That's good. I like it.", meta={"name": "0"}, id="1"),
        Document(content="That's bad. I don't like it.", meta={"name": "1"}, id="2"),
    ]
    classified = document_classifier.predict_batch(documents=documents)
    for doc, expected in zip(classified, ["joy", "sadness"]):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected
@pytest.mark.slow
def test_document_classifier_batch_multiple_doc_lists(document_classifier):
    """predict_batch() on a list of lists returns one labeled list per input list."""
    documents = [
        Document(content="That's good. I like it.", meta={"name": "0"}, id="1"),
        Document(content="That's bad. I don't like it.", meta={"name": "1"}, id="2"),
    ]
    classified = document_classifier.predict_batch(documents=[documents, documents])
    # Two input lists in, two result lists out.
    assert len(classified) == 2
    for doc, expected in zip(classified[0], ["joy", "sadness"]):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected
@pytest.mark.slow
def test_zero_shot_document_classifier(zero_shot_document_classifier):
    """The zero-shot classifier assigns the expected sentiment label to each document."""
    assert isinstance(zero_shot_document_classifier, BaseDocumentClassifier)
    documents = [
        Document(
            # Repeat the text to produce an extra-long input and exercise truncation.
            content="That's good. I like it." * 700,
            meta={"name": "0"},
            id="1",
        ),
        Document(content="That's bad. I don't like it.", meta={"name": "1"}, id="2"),
    ]
    classified = zero_shot_document_classifier.predict(documents=documents)
    for doc, expected in zip(classified, ["positive", "negative"]):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected
@pytest.mark.slow
def test_document_classifier_batch_size(batched_document_classifier):
    """A classifier configured with a batch size still labels every document correctly."""
    assert isinstance(batched_document_classifier, BaseDocumentClassifier)
    documents = [
        Document(
            # Repeat the text to produce an extra-long input and exercise truncation.
            content="That's good. I like it." * 700,
            meta={"name": "0"},
            id="1",
        ),
        Document(content="That's bad. I don't like it.", meta={"name": "1"}, id="2"),
    ]
    classified = batched_document_classifier.predict(documents=documents)
    for doc, expected in zip(classified, ["joy", "sadness"]):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected
@pytest.mark.slow
def test_document_classifier_as_index_node(indexing_document_classifier):
    """run() in an indexing pipeline classifies dict documents via their class_field."""
    assert isinstance(indexing_document_classifier, BaseDocumentClassifier)
    # Indexing pipelines pass plain dicts rather than Document objects;
    # classification is driven by the "class_field" value here.
    documents = [
        {
            "content": "That's good. I like it." * 700,  # extra long text to check truncation
            "meta": {"name": "0"},
            "id": "1",
            "class_field": "That's bad.",
        },
        {"content": "That's bad. I like it.", "meta": {"name": "1"}, "id": "2", "class_field": "That's good."},
    ]
    output, _ = indexing_document_classifier.run(documents=documents, root_node="File")
    for doc, expected in zip(output["documents"], ["sadness", "joy"]):
        assert doc["meta"]["classification"]["label"] == expected
@pytest.mark.slow
def test_document_classifier_as_query_node(document_classifier):
    """run() in a query pipeline labels Document objects by their content."""
    assert isinstance(document_classifier, BaseDocumentClassifier)
    documents = [
        Document(
            # Repeat the text to produce an extra-long input and exercise truncation.
            content="That's good. I like it." * 700,
            meta={"name": "0"},
            id="1",
        ),
        Document(content="That's bad. I don't like it.", meta={"name": "1"}, id="2"),
    ]
    output, _ = document_classifier.run(documents=documents, root_node="Query")
    for doc, expected in zip(output["documents"], ["joy", "sadness"]):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected