feat: Add TextLanguageClassifier 2.0 (#6026)

* draft TextLanguageClassifier

* implement language detection with langdetect

* add unit test for logging message

* reno

* pylint

* change input from List[str] to str

* remove empty output connections

* add from_dict/to_dict tests

* mark example usage as python code
This commit is contained in:
Julian Risch 2023-10-13 10:30:49 +02:00 committed by GitHub
parent 110aacdc35
commit b507f1a124
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 155 additions and 1 deletions

View File

@ -1,3 +1,4 @@
from haystack.preview.components.preprocessors.text_document_splitter import TextDocumentSplitter
from haystack.preview.components.preprocessors.text_language_classifier import TextLanguageClassifier
__all__ = ["TextDocumentSplitter"]
__all__ = ["TextDocumentSplitter", "TextLanguageClassifier"]

View File

@ -0,0 +1,85 @@
import logging
from typing import List, Dict, Any, Optional
from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
with LazyImport("Run 'pip install langdetect'") as langdetect_import:
import langdetect
@component
class TextLanguageClassifier:
    """
    Send a text to one of several output connections, chosen by the text's detected language.

    A typical use is steering queries to different models in a pipeline according to the language
    they are written in. The accepted languages are configurable; any text whose detected language
    is not among them is emitted on the "unmatched" connection.
    To route Documents rather than plain strings, use the related DocumentLanguageClassifier component.

    Example usage in a retrieval pipeline that passes only English language queries to the retriever:

    ```python
    document_store = MemoryDocumentStore()
    p = Pipeline()
    p.add_component(instance=TextLanguageClassifier(), name="text_language_classifier")
    p.add_component(instance=MemoryBM25Retriever(document_store=document_store), name="retriever")
    p.connect("text_language_classifier.en", "retriever.query")
    p.run({"text_language_classifier": {"text": "What's your query?"}})
    ```
    """

    def __init__(self, languages: Optional[List[str]] = None):
        """
        :param languages: ISO language codes, each of which becomes its own output connection
            (see the [`langdetect` documentation](https://github.com/Mimino666/langdetect#languages)).
            When omitted or empty, only ["en"] is used and every other language goes to "unmatched".
        """
        # Fail early with the install hint if the optional langdetect dependency is missing.
        langdetect_import.check()
        self.languages = languages if languages else ["en"]
        # One str-typed output per configured language, plus the "unmatched" fallback.
        component.set_output_types(self, unmatched=str, **dict.fromkeys(self.languages, str))

    def run(self, text: str) -> Dict[str, str]:
        """
        Route the text to the output connection that corresponds to its detected language.

        :param text: The string to classify and route.
        :return: A single-entry dictionary. The key is the detected language code when that code is
            one of the languages given at initialization, otherwise "unmatched"; the value is the
            unchanged input text.
        :raises TypeError: If `text` is not a str.
        """
        if not isinstance(text, str):
            raise TypeError(
                "TextLanguageClassifier expects a str as input. In case you want to classify a document, please use the DocumentLanguageClassifier."
            )

        detected = self.detect_language(text)
        # A failed detection yields None, which is not a configured language and so lands on "unmatched".
        connection = detected if detected in self.languages else "unmatched"
        return {connection: text}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component, including its configured languages, to a dictionary.
        """
        return default_to_dict(self, languages=self.languages)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TextLanguageClassifier":
        """
        Build a component instance from its dictionary representation.
        """
        return default_from_dict(cls, data)

    def detect_language(self, text: str) -> Optional[str]:
        """
        Detect the language of `text` with langdetect.

        :param text: The string whose language should be detected.
        :return: The language code returned by `langdetect.detect`, or None when langdetect raises
            a `LangDetectException`; in that case a warning containing the text is logged.
        """
        try:
            return langdetect.detect(text)
        except langdetect.LangDetectException:
            logger.warning("Langdetect cannot detect the language of text: %s", text)
            return None

View File

@ -0,0 +1,4 @@
---
preview:
- |
Add TextLanguageClassifier component so that an input string, for example a query, can be routed to different components based on the detected language.

View File

@ -0,0 +1,64 @@
import logging
import pytest
from haystack.preview import Document
from haystack.preview.components.preprocessors import TextLanguageClassifier
class TestTextLanguageClassifier:
    """Unit tests covering serialization, input validation, detection and routing of TextLanguageClassifier."""

    @pytest.mark.unit
    def test_to_dict(self):
        """to_dict reports the component type and the languages passed at init."""
        expected = {"type": "TextLanguageClassifier", "init_parameters": {"languages": ["en", "de"]}}
        serialized = TextLanguageClassifier(languages=["en", "de"]).to_dict()
        assert serialized == expected

    @pytest.mark.unit
    def test_from_dict(self):
        """from_dict restores the languages stored under init_parameters."""
        restored = TextLanguageClassifier.from_dict(
            {"type": "TextLanguageClassifier", "init_parameters": {"languages": ["en", "de"]}}
        )
        assert restored.languages == ["en", "de"]

    @pytest.mark.unit
    def test_non_string_input(self):
        """Passing a Document instead of a str is rejected with a TypeError."""
        classifier = TextLanguageClassifier()
        with pytest.raises(TypeError, match="TextLanguageClassifier expects a str as input."):
            classifier.run(text=Document(text="This is an english sentence."))

    @pytest.mark.unit
    def test_list_of_string(self):
        """Passing a list of strings instead of a single str is rejected with a TypeError."""
        classifier = TextLanguageClassifier()
        with pytest.raises(TypeError, match="TextLanguageClassifier expects a str as input."):
            classifier.run(text=["This is an english sentence."])

    @pytest.mark.unit
    def test_empty_string(self):
        """An empty string has no detectable language and is routed to "unmatched"."""
        outcome = TextLanguageClassifier().run(text="")
        assert outcome == {"unmatched": ""}

    @pytest.mark.unit
    def test_detect_language(self):
        """detect_language returns the ISO code for an English sentence."""
        language = TextLanguageClassifier().detect_language("This is an english sentence.")
        assert language == "en"

    @pytest.mark.unit
    def test_route_to_en(self):
        """With default settings an English sentence is emitted on the "en" connection."""
        english_sentence = "This is an english sentence."
        outcome = TextLanguageClassifier().run(text=english_sentence)
        assert outcome == {"en": english_sentence}

    @pytest.mark.unit
    def test_route_to_unmatched(self):
        """With default settings a German sentence is emitted on the "unmatched" connection."""
        german_sentence = "Ein deutscher Satz ohne Verb."
        outcome = TextLanguageClassifier().run(text=german_sentence)
        assert outcome == {"unmatched": german_sentence}

    @pytest.mark.unit
    def test_warning_if_no_language_detected(self, caplog):
        """A warning is logged when langdetect cannot determine the language of the input."""
        expected_message = "Langdetect cannot detect the language of text: ."
        with caplog.at_level(logging.WARNING):
            TextLanguageClassifier().run(text=".")
            assert expected_message in caplog.text