mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-12 07:17:41 +00:00
feat: Add TextLanguageClassifier 2.0 (#6026)
* draft TextLanguageClassifier * implement language detection with langdetect * add unit test for logging message * reno * pylint * change input from List[str] to str * remove empty output connections * add from_dict/to_dict tests * mark example usage as python code
This commit is contained in:
parent
110aacdc35
commit
b507f1a124
@ -1,3 +1,4 @@
|
||||
from haystack.preview.components.preprocessors.text_document_splitter import TextDocumentSplitter
|
||||
from haystack.preview.components.preprocessors.text_language_classifier import TextLanguageClassifier
|
||||
|
||||
__all__ = ["TextDocumentSplitter"]
|
||||
__all__ = ["TextDocumentSplitter", "TextLanguageClassifier"]
|
||||
|
||||
@ -0,0 +1,85 @@
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from haystack.preview import component, default_from_dict, default_to_dict
|
||||
from haystack.preview.lazy_imports import LazyImport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
with LazyImport("Run 'pip install langdetect'") as langdetect_import:
|
||||
import langdetect
|
||||
|
||||
|
||||
@component
class TextLanguageClassifier:
    """
    Routes a text input onto one of different output connections depending on its language.

    This is useful for routing queries to different models in a pipeline depending on their language.
    The set of supported languages can be specified.
    For routing Documents based on their language use the related DocumentLanguageClassifier component.

    Example usage in a retrieval pipeline that passes only English language queries to the retriever:

    ```python
    document_store = MemoryDocumentStore()
    p = Pipeline()
    p.add_component(instance=TextLanguageClassifier(), name="text_language_classifier")
    p.add_component(instance=MemoryBM25Retriever(document_store=document_store), name="retriever")
    p.connect("text_language_classifier.en", "retriever.query")
    p.run({"text_language_classifier": {"text": "What's your query?"}})
    ```
    """

    def __init__(self, languages: Optional[List[str]] = None):
        """
        :param languages: A list of languages in ISO code, each corresponding to a different output connection
            (see [langdetect documentation](https://github.com/Mimino666/langdetect#languages)).
            By default, only ["en"] is supported and texts of any other language are routed to "unmatched".
        """
        langdetect_import.check()
        if not languages:
            languages = ["en"]
        self.languages = languages
        # One dynamic output connection per requested language, plus a catch-all
        # "unmatched" connection for texts whose detected language is not listed.
        component.set_output_types(self, unmatched=str, **{language: str for language in languages})

    def run(self, text: str) -> Dict[str, str]:
        """
        Run the TextLanguageClassifier. This method routes the text to one of different edges based on its language.
        If the text does not match any of the languages specified at initialization, it is routed to
        a connection named "unmatched".

        :param text: A str to route to one of different edges.
        :raises TypeError: If the input is not a str (e.g. a Document or a list of strings).
        """
        if not isinstance(text, str):
            raise TypeError(
                "TextLanguageClassifier expects a str as input. In case you want to classify a document, please use the DocumentLanguageClassifier."
            )

        output: Dict[str, str] = {}

        detected_language = self.detect_language(text)
        # detect_language may return None (detection failed); None is never in
        # self.languages, so such texts fall through to "unmatched".
        if detected_language in self.languages:
            output[detected_language] = text
        else:
            output["unmatched"] = text

        return output

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        return default_to_dict(self, languages=self.languages)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TextLanguageClassifier":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    def detect_language(self, text: str) -> Optional[str]:
        """
        Detect the language of the given text with langdetect.

        :param text: The text whose language should be detected.
        :return: The ISO code of the detected language, or None if detection fails
            (e.g. for empty or non-linguistic input); a warning is logged in that case.
        """
        try:
            language = langdetect.detect(text)
        except langdetect.LangDetectException:
            logger.warning("Langdetect cannot detect the language of text: %s", text)
            language = None
        return language
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
preview:
|
||||
- |
|
||||
Add TextLanguageClassifier component so that an input string, for example a query, can be routed to different components based on the detected language.
|
||||
@ -0,0 +1,64 @@
|
||||
import logging
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Document
|
||||
from haystack.preview.components.preprocessors import TextLanguageClassifier
|
||||
|
||||
|
||||
class TestTextLanguageClassifier:
    @pytest.mark.unit
    def test_to_dict(self):
        # Serialization must record the type and the init parameters.
        classifier = TextLanguageClassifier(languages=["en", "de"])
        serialized = classifier.to_dict()
        assert serialized == {"type": "TextLanguageClassifier", "init_parameters": {"languages": ["en", "de"]}}

    @pytest.mark.unit
    def test_from_dict(self):
        # Deserialization must restore the configured languages.
        serialized = {"type": "TextLanguageClassifier", "init_parameters": {"languages": ["en", "de"]}}
        classifier = TextLanguageClassifier.from_dict(serialized)
        assert classifier.languages == ["en", "de"]

    @pytest.mark.unit
    def test_non_string_input(self):
        # A Document is rejected with a pointer to DocumentLanguageClassifier.
        classifier = TextLanguageClassifier()
        with pytest.raises(TypeError, match="TextLanguageClassifier expects a str as input."):
            classifier.run(text=Document(text="This is an english sentence."))

    @pytest.mark.unit
    def test_list_of_string(self):
        # Lists of strings are rejected too; only a single str is accepted.
        classifier = TextLanguageClassifier()
        with pytest.raises(TypeError, match="TextLanguageClassifier expects a str as input."):
            classifier.run(text=["This is an english sentence."])

    @pytest.mark.unit
    def test_empty_string(self):
        # Empty input cannot be classified and goes to "unmatched".
        result = TextLanguageClassifier().run(text="")
        assert result == {"unmatched": ""}

    @pytest.mark.unit
    def test_detect_language(self):
        detected = TextLanguageClassifier().detect_language("This is an english sentence.")
        assert detected == "en"

    @pytest.mark.unit
    def test_route_to_en(self):
        # With the default languages=["en"], English text goes to the "en" edge.
        sentence = "This is an english sentence."
        result = TextLanguageClassifier().run(text=sentence)
        assert result == {"en": sentence}

    @pytest.mark.unit
    def test_route_to_unmatched(self):
        # German is not in the default language list, so it goes to "unmatched".
        sentence = "Ein deutscher Satz ohne Verb."
        result = TextLanguageClassifier().run(text=sentence)
        assert result == {"unmatched": sentence}

    @pytest.mark.unit
    def test_warning_if_no_language_detected(self, caplog):
        # Undetectable input (a lone period) triggers a warning log entry.
        with caplog.at_level(logging.WARNING):
            TextLanguageClassifier().run(text=".")
            assert "Langdetect cannot detect the language of text: ." in caplog.text
||||
Loading…
x
Reference in New Issue
Block a user