Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-07-28 03:12:54 +00:00

* basic example of document classifier in preprocessing logic
* add batch_size to TransformersDocumentClassifier
* complete tutorial16
* Add latest docstring and tutorial changes
* fix missing batch_size
* add notebook
* test for batch_size use added
* add tutorial 16 to headers.py
* Add latest docstring and tutorial changes
* make DocumentClassifier indexing pipeline ready
* Add latest docstring and tutorial changes
* flexibility improvements for DocumentClassifier in Pipelines
* Add latest docstring and tutorial changes
* fix index-time usage
* remove query from documentclassifier tests
* improve classification_field resolving + minor fixes
* Add latest docstring and tutorial changes
* tutorial 16 extended with zero-shot and pipelines
* Add latest docstring and tutorial changes
* install graphviz in notebook
* Add latest docstring and tutorial changes
* remove convert_to_dicts
* Add latest docstring and tutorial changes
* Fix typo
* Add latest docstring and tutorial changes
* remove retriever from indexing pipeline
* Add latest docstring and tutorial changes
* fix save_to_yaml when using FileTypeClassifier
* emphasize the impact with zero-shot classification
* Add latest docstring and tutorial changes
* adjust use_gpu to boolean in test

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
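The commits above add batch_size support, zero-shot classification, and index-time flexibility (classification_field) to the document classifier node. As a rough orientation, here is a minimal usage sketch; it is not part of the change set, and the model names, labels, and parameter values are assumed examples only.

# Minimal sketch; model names, labels, and values below are illustrative assumptions.
from haystack.nodes import TransformersDocumentClassifier
from haystack.schema import Document

# Supervised classifier: predictions are written into doc.meta["classification"].
doc_classifier = TransformersDocumentClassifier(
    model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",  # assumed example model
    batch_size=16,  # batching is one of the features introduced by this change set
    use_gpu=False,
)

# Zero-shot variant: candidate labels are supplied instead of a fine-tuned classification head.
zero_shot_classifier = TransformersDocumentClassifier(
    model_name_or_path="cross-encoder/nli-distilroberta-base",  # assumed example model
    task="zero-shot-classification",
    labels=["positive", "negative"],
)

docs = [
    Document(content="That's good. I like it."),
    Document(content="That's bad. I don't like it."),
]
classified = doc_classifier.predict(documents=docs)
print(classified[0].meta["classification"]["label"])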
114 lines
3.7 KiB
Python
import pytest

from haystack.schema import Document
from haystack.nodes.document_classifier.base import BaseDocumentClassifier
@pytest.mark.slow
def test_document_classifier(document_classifier):
    # The classifier fixture must implement the BaseDocumentClassifier interface.
    assert isinstance(document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(
            content="""That's bad. I don't like it.""",
            meta={"name": "1"},
            id="2",
        ),
    ]
    # predict() should write its result into each document's meta["classification"].
    results = document_classifier.predict(documents=docs)
    expected_labels = ["joy", "sadness"]
    for i, doc in enumerate(results):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]

@pytest.mark.slow
def test_zero_shot_document_classifier(zero_shot_document_classifier):
    assert isinstance(zero_shot_document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(
            content="""That's bad. I don't like it.""",
            meta={"name": "1"},
            id="2",
        ),
    ]
    # The zero-shot fixture classifies against candidate labels rather than a fine-tuned head.
    results = zero_shot_document_classifier.predict(documents=docs)
    expected_labels = ["positive", "negative"]
    for i, doc in enumerate(results):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]

@pytest.mark.slow
def test_document_classifier_batch_size(batched_document_classifier):
    assert isinstance(batched_document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(
            content="""That's bad. I don't like it.""",
            meta={"name": "1"},
            id="2",
        ),
    ]
    # Same expectations as above, but the fixture is configured with an explicit batch_size.
    results = batched_document_classifier.predict(documents=docs)
    expected_labels = ["joy", "sadness"]
    for i, doc in enumerate(results):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]

@pytest.mark.slow
def test_document_classifier_as_index_node(indexing_document_classifier):
    assert isinstance(indexing_document_classifier, BaseDocumentClassifier)

    # At index time documents may arrive as plain dicts; here classification happens on the
    # configured "class_field" value rather than on the main content.
    docs = [
        {
            "content": """That's good. I like it.""" * 700,  # extra long text to check truncation
            "meta": {"name": "0"},
            "id": "1",
            "class_field": "That's bad.",
        },
        {
            "content": """That's bad. I like it.""",
            "meta": {"name": "1"},
            "id": "2",
            "class_field": "That's good.",
        },
    ]
    output, output_name = indexing_document_classifier.run(documents=docs, root_node="File")
    # Expected labels follow the "class_field" values, not the "content" field.
    expected_labels = ["sadness", "joy"]
    for i, doc in enumerate(output["documents"]):
        assert doc["meta"]["classification"]["label"] == expected_labels[i]

@pytest.mark.slow
def test_document_classifier_as_query_node(document_classifier):
    assert isinstance(document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(
            content="""That's bad. I don't like it.""",
            meta={"name": "1"},
            id="2",
        ),
    ]
    # run() is the pipeline entry point; root_node="Query" exercises query-time usage.
    output, output_name = document_classifier.run(documents=docs, root_node="Query")
    expected_labels = ["joy", "sadness"]
    for i, doc in enumerate(output["documents"]):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]
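
The fixtures used above (document_classifier, zero_shot_document_classifier, batched_document_classifier, indexing_document_classifier) are not defined in this file; they come from the test suite's conftest.py. A hypothetical sketch of what such fixtures might look like, assuming they wrap TransformersDocumentClassifier, is shown below; the actual definitions in the repository may differ.

# Hypothetical conftest.py fixtures (assumed; not part of this file).
import pytest
from haystack.nodes import TransformersDocumentClassifier


@pytest.fixture
def document_classifier():
    return TransformersDocumentClassifier(
        model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",  # assumed example model
        use_gpu=False,
    )


@pytest.fixture
def indexing_document_classifier():
    # classification_field lets the node classify a custom field at index time, which is why
    # the index-node test above passes plain dicts containing a "class_field" key.
    return TransformersDocumentClassifier(
        model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",  # assumed example model
        use_gpu=False,
        batch_size=16,
        classification_field="class_field",
    )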