mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-14 08:37:42 +00:00
feat: Add TextLanguageClassifier 2.0 (#6026)
* draft TextLanguageClassifier * implement language detection with langdetect * add unit test for logging message * reno * pylint * change input from List[str] to str * remove empty output connections * add from_dict/to_dict tests * mark example usage as python code
This commit is contained in:
parent
110aacdc35
commit
b507f1a124
@ -1,3 +1,4 @@
|
|||||||
from haystack.preview.components.preprocessors.text_document_splitter import TextDocumentSplitter
from haystack.preview.components.preprocessors.text_language_classifier import TextLanguageClassifier

# Public API of the preprocessors subpackage.
__all__ = ["TextDocumentSplitter", "TextLanguageClassifier"]
@ -0,0 +1,85 @@
|
|||||||
|
import logging
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
from haystack.preview import component, default_from_dict, default_to_dict
|
||||||
|
from haystack.preview.lazy_imports import LazyImport
|
||||||
|
|
||||||
|
# Module-level logger, named after this module per the standard convention.
logger = logging.getLogger(__name__)

# langdetect is an optional dependency: the import is deferred and only
# verified (via langdetect_import.check()) when a classifier is constructed.
with LazyImport("Run 'pip install langdetect'") as langdetect_import:
    import langdetect
|
@component
class TextLanguageClassifier:
    """
    Routes a text input onto one of different output connections depending on its language.

    This is useful for routing queries to different models in a pipeline depending on their language.
    The set of supported languages can be specified.
    For routing Documents based on their language use the related DocumentLanguageClassifier component.

    Example usage in a retrieval pipeline that passes only English language queries to the retriever:

    ```python
    document_store = MemoryDocumentStore()
    p = Pipeline()
    p.add_component(instance=TextLanguageClassifier(), name="text_language_classifier")
    p.add_component(instance=MemoryBM25Retriever(document_store=document_store), name="retriever")
    p.connect("text_language_classifier.en", "retriever.query")
    p.run({"text_language_classifier": {"text": "What's your query?"}})
    ```
    """

    def __init__(self, languages: Optional[List[str]] = None):
        """
        :param languages: A list of languages in ISO code, each corresponding to a different output connection
            (see [`langdetect` documentation](https://github.com/Mimino666/langdetect#languages)).
            By default, only ["en"] is supported and texts of any other language are routed to "unmatched".
        """
        langdetect_import.check()
        if not languages:
            languages = ["en"]
        # Copy the list so that later mutation of the caller's list cannot drift
        # from the output sockets, which are fixed here at construction time.
        self.languages = list(languages)
        component.set_output_types(self, unmatched=str, **{language: str for language in self.languages})

    def run(self, text: str) -> Dict[str, str]:
        """
        Run the TextLanguageClassifier. This method routes the text to one of different edges based on its language.
        If the text does not match any of the languages specified at initialization, it is routed to
        a connection named "unmatched".

        :param text: A str to route to one of different edges.
        :raises TypeError: If the input is not a str (e.g. a Document or a list of strings).
        """
        if not isinstance(text, str):
            raise TypeError(
                "TextLanguageClassifier expects a str as input. In case you want to classify a document, please use the DocumentLanguageClassifier."
            )

        output: Dict[str, str] = {}

        detected_language = self.detect_language(text)
        if detected_language in self.languages:
            output[detected_language] = text
        else:
            # Unsupported or undetectable language: route to the catch-all socket.
            output["unmatched"] = text

        return output

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        return default_to_dict(self, languages=self.languages)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TextLanguageClassifier":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    def detect_language(self, text: str) -> Optional[str]:
        """
        Detect the language of the given text using langdetect.

        :param text: The text whose language should be detected.
        :return: The ISO code of the detected language, or None if langdetect
            could not detect any language (e.g. for empty or non-alphabetic input).
        """
        try:
            language = langdetect.detect(text)
        except langdetect.LangDetectException:
            language = None
            logger.warning("Langdetect cannot detect the language of text: %s", text)
        return language
|
||||||
@ -0,0 +1,4 @@
|
|||||||
|
---
preview:
  - |
    Add TextLanguageClassifier component so that an input string, for example a query, can be routed to different components based on the detected language.
@ -0,0 +1,64 @@
|
|||||||
|
import logging
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from haystack.preview import Document
|
||||||
|
from haystack.preview.components.preprocessors import TextLanguageClassifier
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextLanguageClassifier:
    @pytest.mark.unit
    def test_to_dict(self):
        classifier = TextLanguageClassifier(languages=["en", "de"])
        assert classifier.to_dict() == {
            "type": "TextLanguageClassifier",
            "init_parameters": {"languages": ["en", "de"]},
        }

    @pytest.mark.unit
    def test_from_dict(self):
        serialized = {"type": "TextLanguageClassifier", "init_parameters": {"languages": ["en", "de"]}}
        classifier = TextLanguageClassifier.from_dict(serialized)
        assert classifier.languages == ["en", "de"]

    @pytest.mark.unit
    def test_non_string_input(self):
        # A Document is not a valid input: the component must reject it.
        with pytest.raises(TypeError, match="TextLanguageClassifier expects a str as input."):
            TextLanguageClassifier().run(text=Document(text="This is an english sentence."))

    @pytest.mark.unit
    def test_list_of_string(self):
        # A list of strings is also rejected: only a single str is accepted.
        with pytest.raises(TypeError, match="TextLanguageClassifier expects a str as input."):
            TextLanguageClassifier().run(text=["This is an english sentence."])

    @pytest.mark.unit
    def test_empty_string(self):
        # No language can be detected for "", so it goes to "unmatched".
        result = TextLanguageClassifier().run(text="")
        assert result == {"unmatched": ""}

    @pytest.mark.unit
    def test_detect_language(self):
        classifier = TextLanguageClassifier()
        assert classifier.detect_language("This is an english sentence.") == "en"

    @pytest.mark.unit
    def test_route_to_en(self):
        sentence = "This is an english sentence."
        result = TextLanguageClassifier().run(text=sentence)
        assert result == {"en": sentence}

    @pytest.mark.unit
    def test_route_to_unmatched(self):
        # German is not among the default languages (["en"]).
        sentence = "Ein deutscher Satz ohne Verb."
        result = TextLanguageClassifier().run(text=sentence)
        assert result == {"unmatched": sentence}

    @pytest.mark.unit
    def test_warning_if_no_language_detected(self, caplog):
        with caplog.at_level(logging.WARNING):
            TextLanguageClassifier().run(text=".")
            assert "Langdetect cannot detect the language of text: ." in caplog.text
Loading…
x
Reference in New Issue
Block a user