feat: LanguageClassifier (#2994)

* add language classifier node

* Fix a few bugs and general code style

* whitespace

* first draft and refactoring

* draft of classes separation

* improve base class

* fix invisible character; add some tests

* fix and more tests

* more docs and tests

* move __init__ to base

* add transformers node; improve tests

* incorporate feedback; little fix to other node

* labels_to_languages mapping

* better docstrings

* use logger instead of logging

---------

Co-authored-by: Stanislav Zamecnik <stanislav.zamecnik@telekom.com>
Co-authored-by: anakin87 <44616784+anakin87@users.noreply.github.com>
Co-authored-by: stazam <zamecnik.stanislav@gmail.com>
ZanSara 2023-03-13 10:30:03 +01:00 committed by GitHub
parent 405aee0cfa
commit fd3f3143d4
5 changed files with 532 additions and 0 deletions

haystack/nodes/doc_language_classifier/__init__.py

@@ -0,0 +1,2 @@
from haystack.nodes.doc_language_classifier.langdetect import LangdetectDocumentLanguageClassifier
from haystack.nodes.doc_language_classifier.transformers import TransformersDocumentLanguageClassifier
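Both classifiers are exposed from this single package path. A minimal sketch of the non-routing usage (assuming `haystack` and `langdetect` are installed; the printed value is what `langdetect` typically predicts for this sentence, not a guarantee):

```python
from haystack.schema import Document
from haystack.nodes.doc_language_classifier import LangdetectDocumentLanguageClassifier

# With route_by_language=False, all Documents leave on a single edge ("output_1").
classifier = LangdetectDocumentLanguageClassifier(route_by_language=False)
docs = classifier.predict(documents=[Document(content="The black dog runs across the meadow")])
print(docs[0].meta["language"])  # expected: "en"
```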

haystack/nodes/doc_language_classifier/base.py

@@ -0,0 +1,137 @@
import logging
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple, Any
from haystack.nodes.base import BaseComponent, Document
logger = logging.getLogger(__name__)
DEFAULT_LANGUAGES = ["en", "de", "es", "cs", "nl"]
class BaseDocumentLanguageClassifier(BaseComponent):
"""
Abstract class for Document Language Classifiers
"""
outgoing_edges = len(DEFAULT_LANGUAGES)
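    # Overrides BaseComponent: at pipeline-build time the node exposes one outgoing
    # edge per routed language, or a single edge when routing is disabled.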
@classmethod
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
route_by_language = component_params.get("route_by_language", True)
if route_by_language is False:
return 1
languages_to_route = component_params.get("languages_to_route", DEFAULT_LANGUAGES)
return len(languages_to_route)
def __init__(self, route_by_language: bool = True, languages_to_route: Optional[List[str]] = None):
"""
:param route_by_language: whether to send Documents on a different output edge depending on their language.
:param languages_to_route: list of languages, each corresponding to a different output edge (ISO codes; see the [`langdetect` documentation](https://github.com/Mimino666/langdetect#languages)).
"""
super().__init__()
if languages_to_route is None:
languages_to_route = DEFAULT_LANGUAGES
if route_by_language is True:
logger.info(
"languages_to_route list has not been defined. The default list will be used: %s",
languages_to_route,
)
if len(set(languages_to_route)) != len(languages_to_route):
duplicates = {lang for lang in languages_to_route if languages_to_route.count(lang) > 1}
raise ValueError(f"languages_to_route parameter can't contain duplicate values ({duplicates}).")
self.route_by_language = route_by_language
self.languages_to_route = languages_to_route
@abstractmethod
def predict(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Document]:
pass
@abstractmethod
def predict_batch(self, documents: List[List[Document]], batch_size: Optional[int] = None) -> List[List[Document]]:
pass
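    # Edges are 1-indexed: with languages_to_route=["en", "de"], an "en" Document
    # leaves on output_1 and a "de" Document on output_2.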
def _get_edge_from_language(self, language: str) -> str:
return f"output_{self.languages_to_route.index(language) + 1}"
def run(self, documents: List[Document]) -> Tuple[Dict[str, List[Document]], str]: # type: ignore
"""
Run the document language classifier on a list of Documents.
:param documents: list of Documents whose language you want to detect.
"""
docs_with_languages = self.predict(documents=documents)
output = {"documents": docs_with_languages}
if self.route_by_language is False:
return output, "output_1"
# self.route_by_language is True
languages = [doc.meta["language"] for doc in docs_with_languages]
unique_languages = list(set(languages))
if len(unique_languages) > 1:
raise ValueError(
f"If route_by_language parameter is True, Documents of multiple languages ({unique_languages}) are not allowed together. "
"If you want to route documents by language, you can call Pipeline.run() once for each Document."
)
language = unique_languages[0]
if language is None:
logger.warning(
"The model cannot detect the language of any of the documents."
"The first language in the list of supported languages will be used to route the document: %s",
self.languages_to_route[0],
)
language = self.languages_to_route[0]
if language not in self.languages_to_route:
raise ValueError(
f"'{language}' is not in the list of languages to route ({', '.join(self.languages_to_route)})."
f"You should specify them when initializing the node, using the parameter languages_to_route."
)
return output, self._get_edge_from_language(str(language))
def run_batch(self, documents: List[List[Document]], batch_size: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
"""
Run the document language classifier on batches of Documents.
:param documents: list of lists of Documents whose language you want to detect.
"""
docs_lists_with_languages = self.predict_batch(documents=documents, batch_size=batch_size)
if self.route_by_language is False:
output = {"documents": docs_lists_with_languages}
return output, "output_1"
# self.route_by_language is True
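        # Prepare one bucket per output edge; each bucket collects whole Document lists.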
split: Dict[str, Dict[str, List[List[Document]]]] = {
f"output_{pos}": {"documents": []} for pos in range(1, len(self.languages_to_route) + 1)
}
for docs_list in docs_lists_with_languages:
languages = [doc.meta["language"] for doc in docs_list]
unique_languages = list(set(languages))
if len(unique_languages) > 1:
raise ValueError(
f"If route_by_language parameter is True, Documents of multiple languages ({unique_languages}) are not allowed together. "
"If you want to route documents by language, you can call Pipeline.run() once for each Document."
)
language: Optional[str] = unique_languages[0]
if language is None:
    logger.warning(
        "The model cannot detect the language of some of the documents. "
        "The first language in the list of supported languages will be used to route the document: %s",
        self.languages_to_route[0],
    )
    language = self.languages_to_route[0]
if language not in self.languages_to_route:
raise ValueError(
f"'{language}' is not in the list of languages to route ({', '.join(self.languages_to_route)})."
f"You should specify them when initializing the node, using the parameter languages_to_route."
)
edge_name = self._get_edge_from_language(str(language))
split[edge_name]["documents"].append(docs_list)
return split, "split"
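To make the routing contract concrete, here is a small sketch of a subclass; `FixedLanguageClassifier` and its hard-coded prediction are inventions for illustration only:

```python
from typing import List, Optional

from haystack.nodes.base import Document
from haystack.nodes.doc_language_classifier.base import BaseDocumentLanguageClassifier


class FixedLanguageClassifier(BaseDocumentLanguageClassifier):
    """Toy classifier that labels every Document as English."""

    def predict(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Document]:
        for doc in documents:
            doc.meta["language"] = "en"
        return documents

    def predict_batch(
        self, documents: List[List[Document]], batch_size: Optional[int] = None
    ) -> List[List[Document]]:
        return [self.predict(documents=docs) for docs in documents]


classifier = FixedLanguageClassifier(route_by_language=True, languages_to_route=["en", "de"])
output, edge = classifier.run(documents=[Document(content="Hello")])
print(edge)  # "output_1", because "en" is the first entry in languages_to_route
```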

haystack/nodes/doc_language_classifier/langdetect.py

@@ -0,0 +1,94 @@
import logging
from typing import List, Optional
from langdetect import LangDetectException, detect
from haystack.nodes.base import Document
from haystack.nodes.doc_language_classifier.base import BaseDocumentLanguageClassifier
logger = logging.getLogger(__name__)
class LangdetectDocumentLanguageClassifier(BaseDocumentLanguageClassifier):
"""
Node based on the lightweight and fast [langdetect library](https://github.com/Mimino666/langdetect) for document language classification.
This node detects the language of Documents and adds the output to the Documents' metadata.
The meta field of the Document is a dictionary with the following format:
``'meta': {'name': '450_Baelor.txt', 'language': 'en'}``
- Using the document language classifier, you can directly get predictions via predict()
- You can route the Documents to different branches depending on their language,
by setting the `route_by_language` parameter to True and specifying the `languages_to_route` parameter.
**Usage example**
```python
...
docs = [Document(content="The black dog runs across the meadow")]
doclangclassifier = LangdetectDocumentLanguageClassifier()
results = doclangclassifier.predict(documents=docs)
# print the predicted language
print(results[0].to_dict()["meta"]["language"])
```
**Usage example for routing**
```python
...
docs = [Document(content="My name is Ryan and I live in London"),
Document(content="Mi chiamo Matteo e vivo a Roma")]
doclangclassifier = LangdetectDocumentLanguageClassifier(
route_by_language = True,
languages_to_route = ['en','it','es']
)
for doc in docs:
    doclangclassifier.run(documents=[doc])
```
"""
def __init__(self, route_by_language: bool = True, languages_to_route: Optional[List[str]] = None):
"""
:param route_by_language: whether to send Documents on a different output edge depending on their language.
:param languages_to_route: list of languages, each corresponding to a different output edge (ISO codes; see the [`langdetect` documentation](https://github.com/Mimino666/langdetect#languages)).
"""
super().__init__(route_by_language=route_by_language, languages_to_route=languages_to_route)
def predict(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Document]:
"""
Detect the language of Documents and add the output to the Documents' metadata.
:param documents: list of Documents to detect language.
:return: List of Documents, where Document.meta["language"] contains the predicted language
"""
if len(documents) == 0:
raise ValueError(
"LangdetectDocumentLanguageClassifier needs at least one document to predict the language."
)
if batch_size is not None:
logger.warning(
"LangdetectDocumentLanguageClassifier does not support batch_size. This parameter is ignored."
)
documents_with_language = []
for document in documents:
try:
language = detect(document.content)
except LangDetectException:
logger.warning("Langdetect cannot detect the language of document: %s", document)
language = None
document.meta["language"] = language
documents_with_language.append(document)
return documents_with_language
def predict_batch(self, documents: List[List[Document]], batch_size: Optional[int] = None) -> List[List[Document]]:
"""
Detect the Documents' language and add the output to their metadata.
:param documents: list of lists of Documents to detect language.
:return: List of lists of Documents, where Document.meta["language"] contains the predicted language
"""
if len(documents) == 0 or all(len(docs_list) == 0 for docs_list in documents):
raise ValueError(
"LangdetectDocumentLanguageClassifier needs at least one document to predict the language."
)
if batch_size is not None:
logger.warning(
"LangdetectDocumentLanguageClassifier does not support batch_size. This parameter is ignored."
)
return [self.predict(documents=docs_list) for docs_list in documents]
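The `split` output of `run_batch` can be seen end to end with this node; a short sketch (language predictions depend on `langdetect`, so the exact edges are expectations rather than guarantees):

```python
from haystack.schema import Document
from haystack.nodes.doc_language_classifier import LangdetectDocumentLanguageClassifier

classifier = LangdetectDocumentLanguageClassifier(route_by_language=True, languages_to_route=["en", "it"])
batches = [
    [Document(content="My name is Ryan and I live in London")],
    [Document(content="Mi chiamo Matteo e vivo a Roma")],
]
split, edge = classifier.run_batch(documents=batches)
print(edge)                            # "split"
print(split["output_1"]["documents"])  # expected: the English batch
print(split["output_2"]["documents"])  # expected: the Italian batch
```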

haystack/nodes/doc_language_classifier/transformers.py

@@ -0,0 +1,178 @@
import logging
from typing import List, Optional, Union, Dict
import itertools
import torch
from tqdm.auto import tqdm
from transformers import pipeline
from haystack.nodes.base import Document
from haystack.nodes.doc_language_classifier.base import BaseDocumentLanguageClassifier
from haystack.modeling.utils import initialize_device_settings
logger = logging.getLogger(__name__)
class TransformersDocumentLanguageClassifier(BaseDocumentLanguageClassifier):
"""
Transformer-based model for document language classification, using Hugging Face's transformers framework
(https://github.com/huggingface/transformers).
While the underlying model can vary (BERT, RoBERTa, DistilBERT, ...), the interface remains the same.
This node detects the language of Documents and adds the output to the Documents' metadata.
The meta field of the Document is a dictionary with the following format:
``'meta': {'name': '450_Baelor.txt', 'language': 'en'}``
- Using the document language classifier, you can directly get predictions via predict()
- You can route the Documents to different branches depending on their language,
by setting the `route_by_language` parameter to True and specifying the `languages_to_route` parameter.
**Usage example**
```python
...
docs = [Document(content="The black dog runs across the meadow")]
doclangclassifier = TransformersDocumentLanguageClassifier()
results = doclangclassifier.predict(documents=docs)
# print the predicted language
print(results[0].to_dict()["meta"]["language"])
```
**Usage example for routing**
```python
...
docs = [Document(content="My name is Ryan and I live in London"),
Document(content="Mi chiamo Matteo e vivo a Roma")]
doclangclassifier = TransformersDocumentLanguageClassifier(
route_by_language = True,
languages_to_route = ['en','it','es']
)
for doc in docs:
    doclangclassifier.run(documents=[doc])
```
"""
def __init__(
self,
route_by_language: bool = True,
languages_to_route: Optional[List[str]] = None,
labels_to_languages_mapping: Optional[Dict[str, str]] = None,
model_name_or_path: str = "papluca/xlm-roberta-base-language-detection",
model_version: Optional[str] = None,
tokenizer: Optional[str] = None,
use_gpu: bool = True,
batch_size: int = 16,
progress_bar: bool = True,
use_auth_token: Optional[Union[str, bool]] = None,
devices: Optional[List[Union[str, torch.device]]] = None,
):
"""
Load a language detection model from Transformers.
See https://huggingface.co/models for full list of available models.
Language detection models: https://huggingface.co/models?search=language%20detection
:param route_by_language: whether to send Documents on a different output edge depending on their language.
:param languages_to_route: list of languages, each corresponding to a different output edge (for the list of the supported languages, see the model card of the chosen model).
:param labels_to_languages_mapping: some Transformers models do not return language names but generic labels. In this case, you can provide a mapping indicating a language for each label. For example: {"LABEL_1": "ar", "LABEL_2": "bg", ...}.
:param model_name_or_path: Directory of a saved model or the name of a public model e.g. 'papluca/xlm-roberta-base-language-detection'.
See https://huggingface.co/models for full list of available models.
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
:param tokenizer: Name of the tokenizer (usually the same as model)
:param use_gpu: Whether to use GPU (if available).
:param batch_size: Number of Documents to be processed at a time.
:param progress_bar: Whether to show a progress bar while processing.
:param use_auth_token: The API token used to download private models from Huggingface.
If this parameter is set to `True`, then the token generated when running
`transformers-cli login` (stored in ~/.huggingface) will be used.
Additional information can be found here
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
:param devices: List of torch devices (e.g. cuda, cpu, mps) to limit inference to specific devices.
A list containing torch device objects and/or strings is supported (For example
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
parameter is not used and a single cpu device is used for inference.
"""
super().__init__(route_by_language=route_by_language, languages_to_route=languages_to_route)
resolved_devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
if len(resolved_devices) > 1:
logger.warning(
"Multiple devices are not supported in %s inference, using the first device %s.",
self.__class__.__name__,
resolved_devices[0],
)
if tokenizer is None:
tokenizer = model_name_or_path
self.model = pipeline(
task="text-classification",
model=model_name_or_path,
tokenizer=tokenizer,
device=resolved_devices[0],
revision=model_version,
top_k=1,
use_auth_token=use_auth_token,
)
self.batch_size = batch_size
self.progress_bar = progress_bar
self.labels_to_languages_mapping = labels_to_languages_mapping or {}
def predict(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Document]:
"""
Detect the language of Documents and add the output to the Documents' metadata.
:param documents: list of Documents to detect language.
:param batch_size: The number of Documents to classify at a time.
:return: List of Documents, where Document.meta["language"] contains the predicted language
"""
if len(documents) == 0:
raise ValueError(
"TransformersDocumentLanguageClassifier needs at least one document to predict the language."
)
if batch_size is None:
batch_size = self.batch_size
texts = [doc.content for doc in documents]
batches = self._get_batches(texts, batch_size=batch_size)
predictions = []
pb = tqdm(total=len(texts), disable=not self.progress_bar, desc="Predicting the language of documents")
for batch in batches:
batched_prediction = self.model(batch, top_k=1, truncation=True)
predictions.extend(batched_prediction)
pb.update(len(batch))
pb.close()
for prediction, doc in zip(predictions, documents):
label = prediction[0]["label"]
# replace the label with the language, if present in the mapping
language = self.labels_to_languages_mapping.get(label, label)
doc.meta["language"] = language
return documents
def predict_batch(self, documents: List[List[Document]], batch_size: Optional[int] = None) -> List[List[Document]]:
"""
Detect the Documents' language and add the output to their metadata.
:param documents: list of lists of Documents to detect language.
:param batch_size: The number of Documents to classify at a time.
:return: List of lists of Documents, where Document.meta["language"] contains the predicted language
"""
if len(documents) == 0 or all(len(docs_list) == 0 for docs_list in documents):
raise ValueError(
"TransformersDocumentLanguageClassifier needs at least one document to predict the language."
)
if batch_size is None:
batch_size = self.batch_size
flattened_documents = list(itertools.chain.from_iterable(documents))
docs_with_preds = self.predict(flattened_documents, batch_size=batch_size)
# Group documents together
grouped_documents = []
for docs_list in documents:
grouped_documents.append(docs_with_preds[: len(docs_list)])
docs_with_preds = docs_with_preds[len(docs_list) :]
return grouped_documents
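    # Yield successive slices of `items`; with batch_size=None, emit the whole list as one batch.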
def _get_batches(self, items, batch_size):
if batch_size is None:
yield items
return
for index in range(0, len(items), batch_size):
yield items[index : index + batch_size]
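For models that return generic labels rather than ISO codes, `labels_to_languages_mapping` does the translation. A sketch reusing the model and label values from the tests below (running it downloads the model; verify the labels against the model card before relying on them):

```python
from haystack.schema import Document
from haystack.nodes.doc_language_classifier import TransformersDocumentLanguageClassifier

classifier = TransformersDocumentLanguageClassifier(
    route_by_language=False,
    model_name_or_path="jb2k/bert-base-multilingual-cased-language-detection",
    labels_to_languages_mapping={"LABEL_11": "en", "LABEL_22": "it", "LABEL_38": "es"},
)
docs = classifier.predict(documents=[Document(content="My name is Ryan and I live in London")])
print(docs[0].meta["language"])  # expected: "en", via the LABEL_11 -> "en" mapping
```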

test/nodes/test_doc_language_classifier.py

@@ -0,0 +1,121 @@
import pytest
import logging
from haystack.schema import Document
from haystack.nodes.doc_language_classifier import (
LangdetectDocumentLanguageClassifier,
TransformersDocumentLanguageClassifier,
)
LANGUAGES_TO_ROUTE = ["en", "es", "it"]
DOCUMENTS = [
Document(content="My name is Matteo and I live in Rome"),
Document(content="Mi chiamo Matteo e vivo a Roma"),
Document(content="Mi nombre es Matteo y vivo en Roma"),
]
EXPECTED_LANGUAGES = ["en", "it", "es"]
EXPECTED_OUTPUT_EDGES = ["output_1", "output_3", "output_2"]
@pytest.fixture(params=["langdetect", "transformers"])
def doclangclassifier(request):
if request.param == "langdetect":
return LangdetectDocumentLanguageClassifier(route_by_language=True, languages_to_route=LANGUAGES_TO_ROUTE)
elif request.param == "transformers":
return TransformersDocumentLanguageClassifier(
route_by_language=True,
languages_to_route=LANGUAGES_TO_ROUTE,
model_name_or_path="jb2k/bert-base-multilingual-cased-language-detection",
labels_to_languages_mapping={"LABEL_11": "en", "LABEL_22": "it", "LABEL_38": "es"},
)
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_predict(doclangclassifier):
results = doclangclassifier.predict(documents=DOCUMENTS)
for doc, expected_language in zip(results, EXPECTED_LANGUAGES):
assert doc.to_dict()["meta"]["language"] == expected_language
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["transformers"], indirect=True)
def test_transformers_doclangclassifier_predict_wo_mapping(doclangclassifier):
doclangclassifier.labels_to_languages_mapping = {}
expected_labels = ["LABEL_11", "LABEL_22", "LABEL_38"]
results = doclangclassifier.predict(documents=DOCUMENTS)
for doc, expected_label in zip(results, expected_labels):
assert doc.to_dict()["meta"]["language"] == expected_label
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_predict_batch(doclangclassifier):
results = doclangclassifier.predict_batch(documents=[DOCUMENTS, DOCUMENTS[:2]])
expected_languages = [EXPECTED_LANGUAGES, EXPECTED_LANGUAGES[:2]]
for lst_docs, lst_expected_languages in zip(results, expected_languages):
for doc, expected_language in zip(lst_docs, lst_expected_languages):
assert doc.to_dict()["meta"]["language"] == expected_language
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_not_route(doclangclassifier):
doclangclassifier.route_by_language = False
results, edge = doclangclassifier.run(documents=DOCUMENTS)
assert edge == "output_1"
for doc, expected_language in zip(results["documents"], EXPECTED_LANGUAGES):
assert doc.to_dict()["meta"]["language"] == expected_language
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route(doclangclassifier):
for doc, expected_language, expected_edge in zip(DOCUMENTS, EXPECTED_LANGUAGES, EXPECTED_OUTPUT_EDGES):
result, edge = doclangclassifier.run(documents=[doc])
document = result["documents"][0]
assert edge == expected_edge
assert document.to_dict()["meta"]["language"] == expected_language
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route_fail_on_mixed_languages(doclangclassifier):
with pytest.raises(ValueError, match="Documents of multiple languages"):
doclangclassifier.run(documents=DOCUMENTS)
# not testing transformers because current models always predict a language
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect"], indirect=True)
def test_doclangclassifier_run_route_cannot_detect_language(doclangclassifier, caplog):
doc_unidentifiable_lang = Document("01234, 56789, ")
with caplog.at_level(logging.INFO):
results, edge = doclangclassifier.run(documents=[doc_unidentifiable_lang])
assert "The model cannot detect the language of any of the documents." in caplog.text
assert edge == "output_1"
assert results["documents"][0].to_dict()["meta"]["language"] is None
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route_fail_on_language_not_in_list(doclangclassifier, caplog):
doc_other_lang = Document("Meu nome é Matteo e moro em Roma")
with pytest.raises(ValueError, match="is not in the list of languages to route"):
doclangclassifier.run(documents=[doc_other_lang])
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_batch(doclangclassifier):
docs = [[doc] for doc in DOCUMENTS]
results, split_edge = doclangclassifier.run_batch(documents=docs)
assert split_edge == "split"
for edge, result in results.items():
document = result["documents"][0][0]
num_document = DOCUMENTS.index(document)
expected_language = EXPECTED_LANGUAGES[num_document]
assert edge == EXPECTED_OUTPUT_EDGES[num_document]
assert document.to_dict()["meta"]["language"] == expected_language