import pytest

from haystack.schema import Document
from haystack.nodes.document_classifier.base import BaseDocumentClassifier


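# The classifier fixtures used throughout this module (document_classifier,
# zero_shot_document_classifier, batched_document_classifier,
# indexing_document_classifier) are provided by conftest.py, which is not shown here.
# The sketch below is only an illustration of how such a fixture could be built;
# the model name and parameters are assumptions, not the actual conftest setup.
@pytest.fixture
def example_document_classifier():
    from haystack.nodes import TransformersDocumentClassifier

    # An emotion model is assumed here because the tests below expect "joy"/"sadness"
    # labels; top_k=2 matches the "details" assertion in test_document_classifier_details.
    return TransformersDocumentClassifier(
        model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
        use_gpu=False,
        top_k=2,
    )

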
@pytest.mark.integration
def test_document_classifier(document_classifier):
    assert isinstance(document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
    ]
    results = document_classifier.predict(documents=docs)
    expected_labels = ["joy", "sadness"]
    for i, doc in enumerate(results):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]


@pytest.mark.integration
def test_document_classifier_details(document_classifier):
    docs = [Document(content="""That's good. I like it."""), Document(content="""That's bad. I don't like it.""")]
    results = document_classifier.predict(documents=docs)
    for doc in results:
        assert "details" in doc.meta["classification"]
        assert len(doc.meta["classification"]["details"]) == 2  # top_k = 2


@pytest.mark.integration
def test_document_classifier_batch_single_doc_list(document_classifier):
    docs = [
        Document(content="""That's good. I like it.""", meta={"name": "0"}, id="1"),
        Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
    ]
    results = document_classifier.predict_batch(documents=docs)
    expected_labels = ["joy", "sadness"]
    for i, doc in enumerate(results):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]


@pytest.mark.integration
def test_document_classifier_batch_multiple_doc_lists(document_classifier):
    docs = [
        Document(content="""That's good. I like it.""", meta={"name": "0"}, id="1"),
        Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
    ]
    results = document_classifier.predict_batch(documents=[docs, docs])
    assert len(results) == 2  # 2 Document lists
    expected_labels = ["joy", "sadness"]
    for i, doc in enumerate(results[0]):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]


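# As with document_classifier, the zero_shot_document_classifier fixture is defined in
# conftest.py. A possible shape, assuming a TransformersDocumentClassifier in zero-shot
# mode (the model choice is an assumption; the candidate labels follow from the
# "positive"/"negative" expectations in the tests below):
@pytest.fixture
def example_zero_shot_document_classifier():
    from haystack.nodes import TransformersDocumentClassifier

    return TransformersDocumentClassifier(
        model_name_or_path="cross-encoder/nli-distilroberta-base",
        task="zero-shot-classification",
        labels=["positive", "negative"],
        use_gpu=False,
    )

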
@pytest.mark.integration
def test_zero_shot_document_classifier(zero_shot_document_classifier):
    assert isinstance(zero_shot_document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
    ]
    results = zero_shot_document_classifier.predict(documents=docs)
    expected_labels = ["positive", "negative"]
    for i, doc in enumerate(results):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]


@pytest.mark.integration
def test_zero_shot_document_classifier_details(zero_shot_document_classifier):
    docs = [Document(content="""That's good. I like it."""), Document(content="""That's bad. I don't like it.""")]
    results = zero_shot_document_classifier.predict(documents=docs)
    for doc in results:
        assert "details" in doc.meta["classification"]
        assert len(doc.meta["classification"]["details"]) == 2  # n_labels = 2


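# batched_document_classifier is likewise a conftest fixture; presumably the same kind of
# classifier configured with an explicit batch size. A hypothetical sketch (model name and
# batch_size are assumptions):
@pytest.fixture
def example_batched_document_classifier():
    from haystack.nodes import TransformersDocumentClassifier

    return TransformersDocumentClassifier(
        model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
        use_gpu=False,
        batch_size=16,
    )

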
@pytest.mark.integration
def test_document_classifier_batch_size(batched_document_classifier):
    assert isinstance(batched_document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
    ]
    results = batched_document_classifier.predict(documents=docs)
    expected_labels = ["joy", "sadness"]
    for i, doc in enumerate(results):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]


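# indexing_document_classifier is also a conftest fixture. Because the test below passes
# dictionaries carrying a "class_field" key and expects the classification to follow that
# field rather than "content", it is presumably configured with
# classification_field="class_field". Model name and batch size are assumptions:
@pytest.fixture
def example_indexing_document_classifier():
    from haystack.nodes import TransformersDocumentClassifier

    return TransformersDocumentClassifier(
        model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
        use_gpu=False,
        batch_size=16,
        classification_field="class_field",
    )

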
@pytest.mark.integration
def test_document_classifier_as_index_node(indexing_document_classifier):
    assert isinstance(indexing_document_classifier, BaseDocumentClassifier)

    docs = [
        {
            "content": """That's good. I like it.""" * 700,  # extra long text to check truncation
            "meta": {"name": "0"},
            "id": "1",
            "class_field": "That's bad.",
        },
        {"content": """That's bad. I like it.""", "meta": {"name": "1"}, "id": "2", "class_field": "That's good."},
    ]
    output, output_name = indexing_document_classifier.run(documents=docs, root_node="File")
    expected_labels = ["sadness", "joy"]
    for i, doc in enumerate(output["documents"]):
        assert doc["meta"]["classification"]["label"] == expected_labels[i]


@pytest.mark.integration
def test_document_classifier_as_query_node(document_classifier):
    assert isinstance(document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
    ]
    output, output_name = document_classifier.run(documents=docs, root_node="Query")
    expected_labels = ["joy", "sadness"]
    for i, doc in enumerate(output["documents"]):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]