diff --git a/haystack/preview/components/preprocessors/__init__.py b/haystack/preview/components/preprocessors/__init__.py
index 33a0e2cd1..b01167573 100644
--- a/haystack/preview/components/preprocessors/__init__.py
+++ b/haystack/preview/components/preprocessors/__init__.py
@@ -1,3 +1,4 @@
 from haystack.preview.components.preprocessors.text_document_splitter import TextDocumentSplitter
+from haystack.preview.components.preprocessors.text_language_classifier import TextLanguageClassifier
 
-__all__ = ["TextDocumentSplitter"]
+__all__ = ["TextDocumentSplitter", "TextLanguageClassifier"]
diff --git a/haystack/preview/components/preprocessors/text_language_classifier.py b/haystack/preview/components/preprocessors/text_language_classifier.py
new file mode 100644
index 000000000..fe010d22c
--- /dev/null
+++ b/haystack/preview/components/preprocessors/text_language_classifier.py
@@ -0,0 +1,85 @@
+import logging
+from typing import List, Dict, Any, Optional
+
+from haystack.preview import component, default_from_dict, default_to_dict
+from haystack.preview.lazy_imports import LazyImport
+
+logger = logging.getLogger(__name__)
+
+with LazyImport("Run 'pip install langdetect'") as langdetect_import:
+    import langdetect
+
+
+@component
+class TextLanguageClassifier:
+    """
+    Routes a text input to one of several output connections, depending on its language.
+    This is useful for routing queries to different models in a pipeline depending on their language.
+    The set of supported languages can be specified.
+    For routing Documents based on their language, use the related DocumentLanguageClassifier component.
+
+    Example usage in a retrieval pipeline that passes only English-language queries to the retriever:
+
+    ```python
+    document_store = MemoryDocumentStore()
+    p = Pipeline()
+    p.add_component(instance=TextLanguageClassifier(), name="text_language_classifier")
+    p.add_component(instance=MemoryBM25Retriever(document_store=document_store), name="retriever")
+    p.connect("text_language_classifier.en", "retriever.query")
+    p.run({"text_language_classifier": {"text": "What's your query?"}})
+    ```
+    """
+
+    def __init__(self, languages: Optional[List[str]] = None):
+        """
+        :param languages: A list of ISO language codes, each corresponding to a different output connection (see the [`langdetect` documentation](https://github.com/Mimino666/langdetect#languages)). By default, only ["en"] is supported and texts in any other language are routed to "unmatched".
+        """
+        langdetect_import.check()
+        if not languages:
+            languages = ["en"]
+        self.languages = languages
+        component.set_output_types(self, unmatched=str, **{language: str for language in languages})
+
+    def run(self, text: str) -> Dict[str, str]:
+        """
+        Run the TextLanguageClassifier. This method routes the text to one of the output connections based on its language.
+        If the text does not match any of the languages specified at initialization, it is routed to
+        a connection named "unmatched".
+
+        :param text: A str to route to one of the output connections.
+        """
+        if not isinstance(text, str):
+            raise TypeError(
+                "TextLanguageClassifier expects a str as input. In case you want to classify a document, please use the DocumentLanguageClassifier."
+            )
+
+        output: Dict[str, str] = {}
+
+        detected_language = self.detect_language(text)
+        if detected_language in self.languages:
+            output[detected_language] = text
+        else:
+            output["unmatched"] = text
+
+        return output
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        return default_to_dict(self, languages=self.languages)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "TextLanguageClassifier":
+        """
+        Deserialize this component from a dictionary.
+        """
+        return default_from_dict(cls, data)
+
+    def detect_language(self, text: str) -> Optional[str]:
+        try:
+            language = langdetect.detect(text)
+        except langdetect.LangDetectException:
+            logger.warning("Langdetect cannot detect the language of text: %s", text)
+            language = None
+        return language
diff --git a/releasenotes/notes/text-language-classifier-0d1e1a97f1bb8ac6.yaml b/releasenotes/notes/text-language-classifier-0d1e1a97f1bb8ac6.yaml
new file mode 100644
index 000000000..84430525e
--- /dev/null
+++ b/releasenotes/notes/text-language-classifier-0d1e1a97f1bb8ac6.yaml
@@ -0,0 +1,4 @@
+---
+preview:
+  - |
+    Add TextLanguageClassifier component so that an input string, for example a query, can be routed to different components based on the detected language.
diff --git a/test/preview/components/preprocessors/test_text_language_classifier.py b/test/preview/components/preprocessors/test_text_language_classifier.py
new file mode 100644
index 000000000..e2b1d5437
--- /dev/null
+++ b/test/preview/components/preprocessors/test_text_language_classifier.py
@@ -0,0 +1,64 @@
+import logging
+import pytest
+
+from haystack.preview import Document
+from haystack.preview.components.preprocessors import TextLanguageClassifier
+
+
+class TestTextLanguageClassifier:
+    @pytest.mark.unit
+    def test_to_dict(self):
+        component = TextLanguageClassifier(languages=["en", "de"])
+        data = component.to_dict()
+        assert data == {"type": "TextLanguageClassifier", "init_parameters": {"languages": ["en", "de"]}}
+
+    @pytest.mark.unit
+    def test_from_dict(self):
+        data = {"type": "TextLanguageClassifier", "init_parameters": {"languages": ["en", "de"]}}
+        component = TextLanguageClassifier.from_dict(data)
+        assert component.languages == ["en", "de"]
+
+    @pytest.mark.unit
+    def test_non_string_input(self):
+        with pytest.raises(TypeError, match="TextLanguageClassifier expects a str as input."):
+            classifier = TextLanguageClassifier()
+            classifier.run(text=Document(text="This is an English sentence."))
+
+    @pytest.mark.unit
+    def test_list_of_string(self):
+        with pytest.raises(TypeError, match="TextLanguageClassifier expects a str as input."):
+            classifier = TextLanguageClassifier()
+            classifier.run(text=["This is an English sentence."])
+
+    @pytest.mark.unit
+    def test_empty_string(self):
+        classifier = TextLanguageClassifier()
+        result = classifier.run(text="")
+        assert result == {"unmatched": ""}
+
+    @pytest.mark.unit
+    def test_detect_language(self):
+        classifier = TextLanguageClassifier()
+        detected_language = classifier.detect_language("This is an English sentence.")
+        assert detected_language == "en"
+
+    @pytest.mark.unit
+    def test_route_to_en(self):
+        classifier = TextLanguageClassifier()
+        english_sentence = "This is an English sentence."
+        result = classifier.run(text=english_sentence)
+        assert result == {"en": english_sentence}
+
+    @pytest.mark.unit
+    def test_route_to_unmatched(self):
+        classifier = TextLanguageClassifier()
+        german_sentence = "Ein deutscher Satz ohne Verb."
+        result = classifier.run(text=german_sentence)
+        assert result == {"unmatched": german_sentence}
+
+    @pytest.mark.unit
+    def test_warning_if_no_language_detected(self, caplog):
+        with caplog.at_level(logging.WARNING):
+            classifier = TextLanguageClassifier()
+            classifier.run(text=".")
+        assert "Langdetect cannot detect the language of text: ." in caplog.text
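A minimal standalone sketch of how the new component routes text, using only the import path and `run()` behavior exercised in the tests above. The example sentences and the languages noted in the comments are illustrative assumptions; `langdetect` typically classifies them as "en", "de", and "it" respectively.

```python
from haystack.preview.components.preprocessors import TextLanguageClassifier

# Each configured language gets its own output connection named after its ISO code;
# anything else falls through to the "unmatched" connection.
classifier = TextLanguageClassifier(languages=["en", "de"])

print(classifier.run(text="What is the capital of Italy?"))
# expected (assuming langdetect returns "en"): {"en": "What is the capital of Italy?"}

print(classifier.run(text="Was ist die Hauptstadt von Italien?"))
# expected (assuming langdetect returns "de"): {"de": "Was ist die Hauptstadt von Italien?"}

print(classifier.run(text="Qual è la capitale d'Italia?"))
# expected (assuming langdetect returns "it", which is not configured): {"unmatched": "Qual è la capitale d'Italia?"}
```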
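A sketch of the multi-language routing the class docstring motivates, extending its single-retriever example to one retriever per language. The `Pipeline`, `MemoryDocumentStore`, and `MemoryBM25Retriever` names come from that docstring; the import paths for the latter two are assumptions about the preview package layout.

```python
from haystack.preview import Pipeline
from haystack.preview.components.preprocessors import TextLanguageClassifier
from haystack.preview.components.retrievers import MemoryBM25Retriever  # assumed import path
from haystack.preview.document_stores import MemoryDocumentStore  # assumed import path

en_store = MemoryDocumentStore()
de_store = MemoryDocumentStore()

p = Pipeline()
p.add_component(instance=TextLanguageClassifier(languages=["en", "de"]), name="text_language_classifier")
p.add_component(instance=MemoryBM25Retriever(document_store=en_store), name="en_retriever")
p.add_component(instance=MemoryBM25Retriever(document_store=de_store), name="de_retriever")

# Connect each language-specific output connection to the matching retriever's query input.
p.connect("text_language_classifier.en", "en_retriever.query")
p.connect("text_language_classifier.de", "de_retriever.query")

# Only the branch that receives the routed query is expected to run.
result = p.run({"text_language_classifier": {"text": "Was ist die Hauptstadt von Italien?"}})
```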