Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-11-10 23:04:02 +00:00
feat: LanguageClassifier (#2994)
* add language classifier node
* Fix a few bugs and general code style
* whitespace
* first draft and refactoring
* draft of classes separation
* improve base class
* fix invisible character; add some tests
* fix and more tests
* more docs and tests
* move __init__ to base
* add transformers node; improve tests
* incorporate feedback; little fix to other node
* labels_to_languages mapping
* better docstrings
* use logger instead of logging

---------

Co-authored-by: Stanislav Zamecnik <stanislav.zamecnik@telekom.com>
Co-authored-by: anakin87 <44616784+anakin87@users.noreply.github.com>
Co-authored-by: stazam <zamecnik.stanislav@gmail.com>
This commit is contained in: parent 405aee0cfa, commit fd3f3143d4
haystack/nodes/doc_language_classifier/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from haystack.nodes.doc_language_classifier.langdetect import LangdetectDocumentLanguageClassifier
from haystack.nodes.doc_language_classifier.transformers import TransformersDocumentLanguageClassifier
haystack/nodes/doc_language_classifier/base.py (new file, 137 lines)
@@ -0,0 +1,137 @@
import logging
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple, Any

from haystack.nodes.base import BaseComponent, Document


logger = logging.getLogger(__name__)

DEFAULT_LANGUAGES = ["en", "de", "es", "cs", "nl"]


class BaseDocumentLanguageClassifier(BaseComponent):
    """
    Abstract class for Document Language Classifiers.
    """

    outgoing_edges = len(DEFAULT_LANGUAGES)

    @classmethod
    def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
        route_by_language = component_params.get("route_by_language", True)
        if route_by_language is False:
            return 1
        languages_to_route = component_params.get("languages_to_route", DEFAULT_LANGUAGES)
        return len(languages_to_route)

    def __init__(self, route_by_language: bool = True, languages_to_route: Optional[List[str]] = None):
        """
        :param route_by_language: Whether to send Documents on a different output edge depending on their language.
        :param languages_to_route: A list of languages, each corresponding to a different output edge (ISO codes; see the [langdetect documentation](https://github.com/Mimino666/langdetect#languages)).
        """
        super().__init__()

        if languages_to_route is None:
            languages_to_route = DEFAULT_LANGUAGES
            if route_by_language is True:
                logger.info(
                    "languages_to_route list has not been defined. The default list will be used: %s",
                    languages_to_route,
                )

        if len(set(languages_to_route)) != len(languages_to_route):
            duplicates = {lang for lang in languages_to_route if languages_to_route.count(lang) > 1}
            raise ValueError(f"languages_to_route parameter can't contain duplicate values ({duplicates}).")

        self.route_by_language = route_by_language
        self.languages_to_route = languages_to_route

    @abstractmethod
    def predict(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Document]:
        pass

    @abstractmethod
    def predict_batch(self, documents: List[List[Document]], batch_size: Optional[int] = None) -> List[List[Document]]:
        pass

    def _get_edge_from_language(self, language: str) -> str:
        return f"output_{self.languages_to_route.index(language) + 1}"

    def run(self, documents: List[Document]) -> Tuple[Dict[str, List[Document]], str]:  # type: ignore
        """
        Run the document language classifier on a list of Documents.

        :param documents: A list of Documents whose language should be detected.
        """
        docs_with_languages = self.predict(documents=documents)
        output = {"documents": docs_with_languages}

        if self.route_by_language is False:
            return output, "output_1"

        # self.route_by_language is True
        languages = [doc.meta["language"] for doc in docs_with_languages]
        unique_languages = list(set(languages))
        if len(unique_languages) > 1:
            raise ValueError(
                f"If the route_by_language parameter is True, Documents of multiple languages ({unique_languages}) are not allowed together. "
                "If you want to route Documents by language, call Pipeline.run() once for each Document."
            )
        language = unique_languages[0]
        if language is None:
            logger.warning(
                "The model cannot detect the language of any of the documents. "
                "The first language in the list of supported languages will be used to route the documents: %s",
                self.languages_to_route[0],
            )
            language = self.languages_to_route[0]
        if language not in self.languages_to_route:
            raise ValueError(
                f"'{language}' is not in the list of languages to route ({', '.join(self.languages_to_route)}). "
                "You should specify the languages when initializing the node, using the languages_to_route parameter."
            )
        return output, self._get_edge_from_language(str(language))

    def run_batch(self, documents: List[List[Document]], batch_size: Optional[int] = None) -> Tuple[Dict, str]:  # type: ignore
        """
        Run the document language classifier on batches of Documents.

        :param documents: A list of lists of Documents whose language should be detected.
        """
        docs_lists_with_languages = self.predict_batch(documents=documents, batch_size=batch_size)

        if self.route_by_language is False:
            output = {"documents": docs_lists_with_languages}
            return output, "output_1"

        # self.route_by_language is True
        split: Dict[str, Dict[str, List[List[Document]]]] = {
            f"output_{pos}": {"documents": []} for pos in range(1, len(self.languages_to_route) + 1)
        }

        for docs_list in docs_lists_with_languages:
            languages = [doc.meta["language"] for doc in docs_list]
            unique_languages = list(set(languages))
            if len(unique_languages) > 1:
                raise ValueError(
                    f"If the route_by_language parameter is True, Documents of multiple languages ({unique_languages}) are not allowed together. "
                    "If you want to route Documents by language, call Pipeline.run() once for each Document."
                )
            if unique_languages[0] is None:
                logger.warning(
                    "The model cannot detect the language of some of the documents. "
                    "The first language in the list of supported languages will be used to route the documents: %s",
                    self.languages_to_route[0],
                )
                language: Optional[str] = self.languages_to_route[0]
            else:
                language = unique_languages[0]
            if language not in self.languages_to_route:
                raise ValueError(
                    f"'{language}' is not in the list of languages to route ({', '.join(self.languages_to_route)}). "
                    "You should specify the languages when initializing the node, using the languages_to_route parameter."
                )

            edge_name = self._get_edge_from_language(str(language))
            split[edge_name]["documents"].append(docs_list)
        return split, "split"
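To make the routing contract above concrete: `run()` emits Documents on the edge whose 1-based index matches the position of the detected language in `languages_to_route`, and `run_batch()` returns a `split` dictionary keyed by those same edge names. A standalone sketch of the mapping (illustration only; `edge_for` is a hypothetical helper, not part of the node's API):

```python
# Sketch: reproduces the logic of _get_edge_from_language for the default list.
DEFAULT_LANGUAGES = ["en", "de", "es", "cs", "nl"]

def edge_for(language: str, languages_to_route=DEFAULT_LANGUAGES) -> str:
    # Edges are 1-indexed: the first language routes to "output_1".
    return f"output_{languages_to_route.index(language) + 1}"

assert edge_for("en") == "output_1"
assert edge_for("es") == "output_3"
assert edge_for("nl") == "output_5"
```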
haystack/nodes/doc_language_classifier/langdetect.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import logging
from typing import List, Optional

from langdetect import LangDetectException, detect

from haystack.nodes.base import Document
from haystack.nodes.doc_language_classifier.base import BaseDocumentLanguageClassifier

logger = logging.getLogger(__name__)


class LangdetectDocumentLanguageClassifier(BaseDocumentLanguageClassifier):
    """
    Node based on the lightweight and fast [langdetect library](https://github.com/Mimino666/langdetect) for document language classification.
    This node detects the language of Documents and adds the output to the Documents' metadata.
    The meta field of the Document is a dictionary with the following format:
    ``'meta': {'name': '450_Baelor.txt', 'language': 'en'}``

    - Using the document language classifier, you can directly get predictions via predict().
    - You can route the Documents to different branches depending on their language,
      by setting the `route_by_language` parameter to True and specifying the `languages_to_route` parameter.

    **Usage example**
    ```python
    ...
    docs = [Document(content="The black dog runs across the meadow")]

    doclangclassifier = LangdetectDocumentLanguageClassifier()
    results = doclangclassifier.predict(documents=docs)

    # print the predicted language
    print(results[0].to_dict()["meta"]["language"])
    ```

    **Usage example for routing**
    ```python
    ...
    docs = [Document(content="My name is Ryan and I live in London"),
            Document(content="Mi chiamo Matteo e vivo a Roma")]

    doclangclassifier = LangdetectDocumentLanguageClassifier(
        route_by_language=True,
        languages_to_route=["en", "it", "es"],
    )
    for doc in docs:
        doclangclassifier.run(documents=[doc])
    ```
    """

    def __init__(self, route_by_language: bool = True, languages_to_route: Optional[List[str]] = None):
        """
        :param route_by_language: Whether to send Documents on a different output edge depending on their language.
        :param languages_to_route: A list of languages, each corresponding to a different output edge (ISO codes; see the [langdetect documentation](https://github.com/Mimino666/langdetect#languages)).
        """
        super().__init__(route_by_language=route_by_language, languages_to_route=languages_to_route)

    def predict(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Document]:
        """
        Detect the language of Documents and add the output to the Documents' metadata.

        :param documents: A list of Documents whose language should be detected.
        :return: A list of Documents, where Document.meta["language"] contains the predicted language.
        """
        if len(documents) == 0:
            raise ValueError(
                "LangdetectDocumentLanguageClassifier needs at least one document to predict the language."
            )
        if batch_size is not None:
            logger.warning(
                "LangdetectDocumentLanguageClassifier does not support batch_size. This parameter is ignored."
            )

        documents_with_language = []
        for document in documents:
            try:
                language = detect(document.content)
            except LangDetectException:
                logger.warning("Langdetect cannot detect the language of document: %s", document)
                language = None
            document.meta["language"] = language
            documents_with_language.append(document)
        return documents_with_language

    def predict_batch(self, documents: List[List[Document]], batch_size: Optional[int] = None) -> List[List[Document]]:
        """
        Detect the language of Documents and add the output to the Documents' metadata.

        :param documents: A list of lists of Documents whose language should be detected.
        :return: A list of lists of Documents, where Document.meta["language"] contains the predicted language.
        """
        if len(documents) == 0 or all(len(docs_list) == 0 for docs_list in documents):
            raise ValueError(
                "LangdetectDocumentLanguageClassifier needs at least one document to predict the language."
            )
        if batch_size is not None:
            logger.warning(
                "LangdetectDocumentLanguageClassifier does not support batch_size. This parameter is ignored."
            )
        return [self.predict(documents=docs_list) for docs_list in documents]
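One behavior of `predict` above is worth calling out: when langdetect raises `LangDetectException` (for example, on purely numeric content), the Document still flows through, with `meta["language"]` set to `None`. A minimal sketch, assuming langdetect cannot identify the sample text:

```python
from haystack.schema import Document
from haystack.nodes.doc_language_classifier import LangdetectDocumentLanguageClassifier

classifier = LangdetectDocumentLanguageClassifier(route_by_language=False)
docs = classifier.predict(documents=[Document(content="01234, 56789, ")])
print(docs[0].meta["language"])  # None: langdetect could not identify a language
```

When routing is enabled, `run()` then falls back to the first entry of `languages_to_route`, as the base class shows.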
haystack/nodes/doc_language_classifier/transformers.py (new file, 178 lines)
@@ -0,0 +1,178 @@
import logging
from typing import List, Optional, Union, Dict
import itertools

import torch
from tqdm.auto import tqdm
from transformers import pipeline

from haystack.nodes.base import Document
from haystack.nodes.doc_language_classifier.base import BaseDocumentLanguageClassifier
from haystack.modeling.utils import initialize_device_settings

logger = logging.getLogger(__name__)


class TransformersDocumentLanguageClassifier(BaseDocumentLanguageClassifier):
    """
    Transformer-based model for document language classification, using Hugging Face's transformers framework
    (https://github.com/huggingface/transformers).
    While the underlying model can vary (BERT, RoBERTa, DistilBERT, ...), the interface remains the same.
    This node detects the language of Documents and adds the output to the Documents' metadata.
    The meta field of the Document is a dictionary with the following format:
    ``'meta': {'name': '450_Baelor.txt', 'language': 'en'}``

    - Using the document language classifier, you can directly get predictions via predict().
    - You can route the Documents to different branches depending on their language,
      by setting the `route_by_language` parameter to True and specifying the `languages_to_route` parameter.

    **Usage example**
    ```python
    ...
    docs = [Document(content="The black dog runs across the meadow")]

    doclangclassifier = TransformersDocumentLanguageClassifier()
    results = doclangclassifier.predict(documents=docs)

    # print the predicted language
    print(results[0].to_dict()["meta"]["language"])
    ```

    **Usage example for routing**
    ```python
    ...
    docs = [Document(content="My name is Ryan and I live in London"),
            Document(content="Mi chiamo Matteo e vivo a Roma")]

    doclangclassifier = TransformersDocumentLanguageClassifier(
        route_by_language=True,
        languages_to_route=["en", "it", "es"],
    )
    for doc in docs:
        doclangclassifier.run(documents=[doc])
    ```
    """

    def __init__(
        self,
        route_by_language: bool = True,
        languages_to_route: Optional[List[str]] = None,
        labels_to_languages_mapping: Optional[Dict[str, str]] = None,
        model_name_or_path: str = "papluca/xlm-roberta-base-language-detection",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 16,
        progress_bar: bool = True,
        use_auth_token: Optional[Union[str, bool]] = None,
        devices: Optional[List[Union[str, torch.device]]] = None,
    ):
        """
        Load a language detection model from Transformers.
        See https://huggingface.co/models for a full list of available models.
        Language detection models: https://huggingface.co/models?search=language%20detection

        :param route_by_language: Whether to send Documents on a different output edge depending on their language.
        :param languages_to_route: A list of languages, each corresponding to a different output edge (for the list of supported languages, see the model card of the chosen model).
        :param labels_to_languages_mapping: Some Transformers models do not return language names but generic labels. In this case, you can provide a mapping indicating a language for each label. For example: {"LABEL_1": "ar", "LABEL_2": "bg", ...}.
        :param model_name_or_path: Directory of a saved model or the name of a public model, e.g. 'papluca/xlm-roberta-base-language-detection'.
                                   See https://huggingface.co/models for a full list of available models.
        :param model_version: The version of the model to use from the Hugging Face model hub. Can be a tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as the model).
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of Documents to be processed at a time.
        :param progress_bar: Whether to show a progress bar while processing.
        :param use_auth_token: The API token used to download private models from Hugging Face.
                               If this parameter is set to `True`, then the token generated when running
                               `transformers-cli login` (stored in ~/.huggingface) will be used.
                               Additional information can be found here:
                               https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
        :param devices: List of torch devices (e.g. cuda, cpu, mps) to limit inference to specific devices.
                        A list containing torch device objects and/or strings is supported (for example
                        [torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False`, the devices
                        parameter is not used and a single cpu device is used for inference.
        """
        super().__init__(route_by_language=route_by_language, languages_to_route=languages_to_route)

        resolved_devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
        if len(resolved_devices) > 1:
            logger.warning(
                "Multiple devices are not supported in %s inference, using the first device %s.",
                self.__class__.__name__,
                resolved_devices[0],
            )
        if tokenizer is None:
            tokenizer = model_name_or_path

        self.model = pipeline(
            task="text-classification",
            model=model_name_or_path,
            tokenizer=tokenizer,
            device=resolved_devices[0],
            revision=model_version,
            top_k=1,
            use_auth_token=use_auth_token,
        )
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.labels_to_languages_mapping = labels_to_languages_mapping or {}

    def predict(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Document]:
        """
        Detect the language of Documents and add the output to the Documents' metadata.

        :param documents: A list of Documents whose language should be detected.
        :param batch_size: The number of Documents to classify at a time.
        :return: A list of Documents, where Document.meta["language"] contains the predicted language.
        """
        if len(documents) == 0:
            raise ValueError(
                "TransformersDocumentLanguageClassifier needs at least one document to predict the language."
            )
        if batch_size is None:
            batch_size = self.batch_size

        texts = [doc.content for doc in documents]
        batches = self._get_batches(texts, batch_size=batch_size)
        predictions = []
        pb = tqdm(total=len(texts), disable=not self.progress_bar, desc="Predicting the language of documents")
        for batch in batches:
            batched_prediction = self.model(batch, top_k=1, truncation=True)
            predictions.extend(batched_prediction)
            pb.update(len(batch))
        pb.close()
        for prediction, doc in zip(predictions, documents):
            label = prediction[0]["label"]
            # replace the label with the language, if present in the mapping
            language = self.labels_to_languages_mapping.get(label, label)
            doc.meta["language"] = language
        return documents

    def predict_batch(self, documents: List[List[Document]], batch_size: Optional[int] = None) -> List[List[Document]]:
        """
        Detect the language of Documents and add the output to the Documents' metadata.

        :param documents: A list of lists of Documents whose language should be detected.
        :param batch_size: The number of Documents to classify at a time.
        :return: A list of lists of Documents, where Document.meta["language"] contains the predicted language.
        """
        if len(documents) == 0 or all(len(docs_list) == 0 for docs_list in documents):
            raise ValueError(
                "TransformersDocumentLanguageClassifier needs at least one document to predict the language."
            )
        if batch_size is None:
            batch_size = self.batch_size

        flattened_documents = list(itertools.chain.from_iterable(documents))
        docs_with_preds = self.predict(flattened_documents, batch_size=batch_size)

        # Regroup the flat list of Documents into the original nested structure
        grouped_documents = []
        for docs_list in documents:
            grouped_documents.append(docs_with_preds[: len(docs_list)])
            docs_with_preds = docs_with_preds[len(docs_list) :]

        return grouped_documents

    def _get_batches(self, items, batch_size):
        if batch_size is None:
            yield items
            return
        for index in range(0, len(items), batch_size):
            yield items[index : index + batch_size]
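To make `labels_to_languages_mapping` concrete: some detection models emit generic labels such as "LABEL_11" instead of ISO codes, and the mapping translates them as predictions are written into `meta`. A brief sketch (the model name and label strings follow the test file below and are model-specific):

```python
from haystack.schema import Document
from haystack.nodes.doc_language_classifier import TransformersDocumentLanguageClassifier

classifier = TransformersDocumentLanguageClassifier(
    route_by_language=False,
    model_name_or_path="jb2k/bert-base-multilingual-cased-language-detection",
    # model-specific labels mapped to ISO codes
    labels_to_languages_mapping={"LABEL_11": "en", "LABEL_22": "it", "LABEL_38": "es"},
)
docs = classifier.predict(documents=[Document(content="My name is Ryan and I live in London")])
print(docs[0].meta["language"])  # "en" rather than the raw "LABEL_11"
```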
test/nodes/test_doc_language_classifier.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import pytest
import logging

from haystack.schema import Document
from haystack.nodes.doc_language_classifier import (
    LangdetectDocumentLanguageClassifier,
    TransformersDocumentLanguageClassifier,
)

LANGUAGES_TO_ROUTE = ["en", "es", "it"]
DOCUMENTS = [
    Document(content="My name is Matteo and I live in Rome"),
    Document(content="Mi chiamo Matteo e vivo a Roma"),
    Document(content="Mi nombre es Matteo y vivo en Roma"),
]

EXPECTED_LANGUAGES = ["en", "it", "es"]
EXPECTED_OUTPUT_EDGES = ["output_1", "output_3", "output_2"]


@pytest.fixture(params=["langdetect", "transformers"])
def doclangclassifier(request):
    if request.param == "langdetect":
        return LangdetectDocumentLanguageClassifier(route_by_language=True, languages_to_route=LANGUAGES_TO_ROUTE)
    elif request.param == "transformers":
        return TransformersDocumentLanguageClassifier(
            route_by_language=True,
            languages_to_route=LANGUAGES_TO_ROUTE,
            model_name_or_path="jb2k/bert-base-multilingual-cased-language-detection",
            labels_to_languages_mapping={"LABEL_11": "en", "LABEL_22": "it", "LABEL_38": "es"},
        )


@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_predict(doclangclassifier):
    results = doclangclassifier.predict(documents=DOCUMENTS)
    for doc, expected_language in zip(results, EXPECTED_LANGUAGES):
        assert doc.to_dict()["meta"]["language"] == expected_language


@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["transformers"], indirect=True)
def test_transformers_doclangclassifier_predict_wo_mapping(doclangclassifier):
    doclangclassifier.labels_to_languages_mapping = {}
    expected_labels = ["LABEL_11", "LABEL_22", "LABEL_38"]
    results = doclangclassifier.predict(documents=DOCUMENTS)
    for doc, expected_label in zip(results, expected_labels):
        assert doc.to_dict()["meta"]["language"] == expected_label


@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_predict_batch(doclangclassifier):
    results = doclangclassifier.predict_batch(documents=[DOCUMENTS, DOCUMENTS[:2]])
    expected_languages = [EXPECTED_LANGUAGES, EXPECTED_LANGUAGES[:2]]
    for lst_docs, lst_expected_languages in zip(results, expected_languages):
        for doc, expected_language in zip(lst_docs, lst_expected_languages):
            assert doc.to_dict()["meta"]["language"] == expected_language


@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_not_route(doclangclassifier):
    doclangclassifier.route_by_language = False
    results, edge = doclangclassifier.run(documents=DOCUMENTS)
    assert edge == "output_1"
    for doc, expected_language in zip(results["documents"], EXPECTED_LANGUAGES):
        assert doc.to_dict()["meta"]["language"] == expected_language


@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route(doclangclassifier):
    for doc, expected_language, expected_edge in zip(DOCUMENTS, EXPECTED_LANGUAGES, EXPECTED_OUTPUT_EDGES):
        result, edge = doclangclassifier.run(documents=[doc])
        document = result["documents"][0]

        assert edge == expected_edge
        assert document.to_dict()["meta"]["language"] == expected_language


@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route_fail_on_mixed_languages(doclangclassifier):
    with pytest.raises(ValueError, match="Documents of multiple languages"):
        doclangclassifier.run(documents=DOCUMENTS)


# not testing transformers because current models always predict a language
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect"], indirect=True)
def test_doclangclassifier_run_route_cannot_detect_language(doclangclassifier, caplog):
    doc_unidentifiable_lang = Document("01234, 56789, ")
    with caplog.at_level(logging.INFO):
        results, edge = doclangclassifier.run(documents=[doc_unidentifiable_lang])
    assert "The model cannot detect the language of any of the documents." in caplog.text
    assert edge == "output_1"
    assert results["documents"][0].to_dict()["meta"]["language"] is None


@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route_fail_on_language_not_in_list(doclangclassifier, caplog):
    doc_other_lang = Document("Meu nome é Matteo e moro em Roma")
    with pytest.raises(ValueError, match="is not in the list of languages to route"):
        doclangclassifier.run(documents=[doc_other_lang])


@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_batch(doclangclassifier):
    docs = [[doc] for doc in DOCUMENTS]
    results, split_edge = doclangclassifier.run_batch(documents=docs)
    assert split_edge == "split"
    for edge, result in results.items():
        document = result["documents"][0][0]
        num_document = DOCUMENTS.index(document)
        expected_language = EXPECTED_LANGUAGES[num_document]
        assert edge == EXPECTED_OUTPUT_EDGES[num_document]
        assert document.to_dict()["meta"]["language"] == expected_language
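Beyond these unit-style checks, the node is meant to branch a pipeline. A minimal routing sketch, assuming a Haystack v1 `Pipeline`; the branch names and the `PreProcessor` placeholders are hypothetical stand-ins for whatever language-specific nodes you would attach per edge:

```python
from haystack import Pipeline
from haystack.nodes import PreProcessor
from haystack.nodes.doc_language_classifier import LangdetectDocumentLanguageClassifier
from haystack.schema import Document

classifier = LangdetectDocumentLanguageClassifier(route_by_language=True, languages_to_route=["en", "it"])
pipe = Pipeline()
pipe.add_node(component=classifier, name="LanguageClassifier", inputs=["File"])
# "en" is first in languages_to_route, so English Documents leave on output_1.
pipe.add_node(component=PreProcessor(), name="EnglishBranch", inputs=["LanguageClassifier.output_1"])
pipe.add_node(component=PreProcessor(), name="ItalianBranch", inputs=["LanguageClassifier.output_2"])

# run() allows only one language per call, so feed Documents individually.
result = pipe.run(documents=[Document(content="My name is Ryan and I live in London")])
```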