haystack/test/nodes/test_doc_language_classifier.py
ZanSara fd3f3143d4
feat: LanguageClassifier (#2994)
* add language classifier node

* Fix a few bugs and general code style

* whitespace

* first draft and refactoring

* draft of classes separation

* improve base class

* fix invisible character; add some tests

* fix and more tests

* more docs and tests

* move __init__ to base

* add transformers node; improve tests

* incorporate feedback; little fix to other node

* labels_to_languages mapping

* better docstrings

* use logger instead of logging

---------

Co-authored-by: Stanislav Zamecnik <stanislav.zamecnik@telekom.com>
Co-authored-by: anakin87 <44616784+anakin87@users.noreply.github.com>
Co-authored-by: stazam <zamecnik.stanislav@gmail.com>
2023-03-13 10:30:03 +01:00

122 lines
5.5 KiB
Python

import pytest
import logging
from haystack.schema import Document
from haystack.nodes.doc_language_classifier import (
LangdetectDocumentLanguageClassifier,
TransformersDocumentLanguageClassifier,
)
# Languages handed to the classifiers as ``languages_to_route``.
LANGUAGES_TO_ROUTE = ["en", "es", "it"]

# One sample sentence per routed language; the order here is English, Italian,
# Spanish and is mirrored by the two EXPECTED_* lists below.
DOCUMENTS = [
    Document(content=text)
    for text in (
        "My name is Matteo and I live in Rome",
        "Mi chiamo Matteo e vivo a Roma",
        "Mi nombre es Matteo y vivo en Roma",
    )
]

# Language code expected in ``meta["language"]`` for each entry of DOCUMENTS.
EXPECTED_LANGUAGES = ["en", "it", "es"]

# Output edge expected for each entry of DOCUMENTS. The edge number matches the
# 1-based position of the document's language in LANGUAGES_TO_ROUTE.
EXPECTED_OUTPUT_EDGES = ["output_1", "output_3", "output_2"]
@pytest.fixture(params=["langdetect", "transformers"])
def doclangclassifier(request):
    """Build the document language classifier selected by ``request.param``.

    Both variants are created with ``route_by_language=True`` and route on
    ``LANGUAGES_TO_ROUTE``.

    :param request: pytest fixture request; ``request.param`` is either
        ``"langdetect"`` or ``"transformers"``.
    :return: the configured classifier node.
    :raises ValueError: if ``request.param`` is not one of the supported names.
    """
    if request.param == "langdetect":
        return LangdetectDocumentLanguageClassifier(route_by_language=True, languages_to_route=LANGUAGES_TO_ROUTE)

    if request.param == "transformers":
        return TransformersDocumentLanguageClassifier(
            route_by_language=True,
            languages_to_route=LANGUAGES_TO_ROUTE,
            model_name_or_path="jb2k/bert-base-multilingual-cased-language-detection",
            # Maps the labels emitted by the model to language codes.
            labels_to_languages_mapping={"LABEL_11": "en", "LABEL_22": "it", "LABEL_38": "es"},
        )

    # Fail loudly instead of silently handing ``None`` to the test, which would
    # only surface later as an unrelated AttributeError.
    raise ValueError(f"Unsupported doclangclassifier fixture parameter: {request.param!r}")
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_predict(doclangclassifier):
    """``predict`` stores the expected language in each document's ``meta["language"]``."""
    results = doclangclassifier.predict(documents=DOCUMENTS)
    # Compare whole lists rather than zipping: ``zip`` stops at the shorter
    # input, so an empty or truncated result would otherwise pass unnoticed.
    predicted_languages = [doc.to_dict()["meta"]["language"] for doc in results]
    assert predicted_languages == EXPECTED_LANGUAGES
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["transformers"], indirect=True)
def test_transformers_doclangclassifier_predict_wo_mapping(doclangclassifier):
    """With an empty label mapping, ``predict`` stores the raw model labels as the language."""
    doclangclassifier.labels_to_languages_mapping = {}
    expected_labels = ["LABEL_11", "LABEL_22", "LABEL_38"]
    results = doclangclassifier.predict(documents=DOCUMENTS)
    # Compare whole lists rather than zipping: ``zip`` stops at the shorter
    # input, so an empty or truncated result would otherwise pass unnoticed.
    predicted_labels = [doc.to_dict()["meta"]["language"] for doc in results]
    assert predicted_labels == expected_labels
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_predict_batch(doclangclassifier):
    """``predict_batch`` labels every document of every input list with the expected language."""
    results = doclangclassifier.predict_batch(documents=[DOCUMENTS, DOCUMENTS[:2]])
    expected_languages = [EXPECTED_LANGUAGES, EXPECTED_LANGUAGES[:2]]
    # Build the full nested structure and compare it at once: nested ``zip``
    # loops would silently skip missing lists or missing documents.
    predicted_languages = [[doc.to_dict()["meta"]["language"] for doc in lst_docs] for lst_docs in results]
    assert predicted_languages == expected_languages
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_not_route(doclangclassifier):
    """With routing disabled, ``run`` returns all documents on ``output_1`` with their language set."""
    doclangclassifier.route_by_language = False
    results, edge = doclangclassifier.run(documents=DOCUMENTS)
    assert edge == "output_1"
    # Compare whole lists rather than zipping: ``zip`` stops at the shorter
    # input, so an empty or truncated result would otherwise pass unnoticed.
    predicted_languages = [doc.to_dict()["meta"]["language"] for doc in results["documents"]]
    assert predicted_languages == EXPECTED_LANGUAGES
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route(doclangclassifier):
    """Running one document at a time sends each to the edge expected for its language."""
    cases = zip(DOCUMENTS, EXPECTED_LANGUAGES, EXPECTED_OUTPUT_EDGES)
    for single_doc, language, output_edge in cases:
        output, edge = doclangclassifier.run(documents=[single_doc])
        routed_doc = output["documents"][0]
        assert edge == output_edge
        assert routed_doc.to_dict()["meta"]["language"] == language
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route_fail_on_mixed_languages(doclangclassifier):
    """Routing a single call that contains documents in several languages raises ``ValueError``."""
    mixed_languages_error = pytest.raises(ValueError, match="Documents of multiple languages")
    with mixed_languages_error:
        doclangclassifier.run(documents=DOCUMENTS)
# The transformers variant is not covered here because current models always predict a language.
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect"], indirect=True)
def test_doclangclassifier_run_route_cannot_detect_language(doclangclassifier, caplog):
    """A document with no detectable language is logged, sent to ``output_1`` and labelled ``None``."""
    undetectable_doc = Document("01234, 56789, ")
    with caplog.at_level(logging.INFO):
        output, edge = doclangclassifier.run(documents=[undetectable_doc])

    expected_message = "The model cannot detect the language of any of the documents."
    assert expected_message in caplog.text
    assert edge == "output_1"
    classified_doc = output["documents"][0]
    assert classified_doc.to_dict()["meta"]["language"] is None
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route_fail_on_language_not_in_list(doclangclassifier, caplog):
    """Routing a document whose language is outside ``LANGUAGES_TO_ROUTE`` raises ``ValueError``."""
    # Portuguese sentence: not one of the languages configured for routing.
    unrouted_doc = Document("Meu nome é Matteo e moro em Roma")
    unrouted_language_error = pytest.raises(ValueError, match="is not in the list of languages to route")
    with unrouted_language_error:
        doclangclassifier.run(documents=[unrouted_doc])
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_batch(doclangclassifier):
    """``run_batch`` splits single-document lists across the edges expected for their languages."""
    docs = [[doc] for doc in DOCUMENTS]
    results, split_edge = doclangclassifier.run_batch(documents=docs)
    assert split_edge == "split"
    # Check the set of edges up front: the loop below only inspects edges that
    # are present, so a missing edge (or an empty mapping) would go unnoticed.
    assert sorted(results) == sorted(EXPECTED_OUTPUT_EDGES)
    for edge, result in results.items():
        # Each edge holds a list of document lists; take the first document of the first list.
        document = result["documents"][0][0]
        num_document = DOCUMENTS.index(document)
        expected_language = EXPECTED_LANGUAGES[num_document]
        assert edge == EXPECTED_OUTPUT_EDGES[num_document]
        assert document.to_dict()["meta"]["language"] == expected_language