diff --git a/CHANGELOG.md b/CHANGELOG.md index 96460f04d..67473dfae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +## 0.12.2-dev0 + +### Enhancements + +### Features + +### Fixes + ## 0.12.1 ### Enhancements diff --git a/test_unstructured/partition/test_lang.py b/test_unstructured/partition/test_lang.py index b85fbab8f..c262a7c7a 100644 --- a/test_unstructured/partition/test_lang.py +++ b/test_unstructured/partition/test_lang.py @@ -5,6 +5,7 @@ from unstructured.documents.elements import ( PageBreak, ) from unstructured.partition.lang import ( + _clean_ocr_languages_arg, _convert_language_code_to_pytesseract_lang_code, apply_lang_metadata, detect_languages, @@ -126,6 +127,22 @@ def test_convert_language_code_to_pytesseract_lang_code(lang_in, expected_lang): assert expected_lang == _convert_language_code_to_pytesseract_lang_code(lang_in) +@pytest.mark.parametrize( + ("input_ocr_langs", "expected"), + [ + (["eng"], "eng"), # list + ('"deu"', "deu"), # extra quotation marks + ("[deu]", "deu"), # brackets + ("['deu']", "deu"), # brackets and quotation marks + (["[deu]"], "deu"), # list, brackets and quotation marks + (['"deu"'], "deu"), # list and quotation marks + ("deu+spa", "deu+spa"), # correct input + ], +) +def test_clean_ocr_languages_arg(input_ocr_langs, expected): + assert _clean_ocr_languages_arg(input_ocr_langs) == expected + + def test_detect_languages_handles_spelled_out_languages(): languages = detect_languages(text="Sample text longer than 5 words.", languages=["Spanish"]) assert languages == ["spa"] diff --git a/unstructured/__version__.py b/unstructured/__version__.py index 17e72186f..f1055c5ab 100644 --- a/unstructured/__version__.py +++ b/unstructured/__version__.py @@ -1 +1 @@ -__version__ = "0.12.1" # pragma: no cover +__version__ = "0.12.2-dev0" # pragma: no cover diff --git a/unstructured/partition/lang.py b/unstructured/partition/lang.py index 0a59ac333..d2aa6a6a2 100644 --- a/unstructured/partition/lang.py +++ b/unstructured/partition/lang.py @@ -1,5 +1,5 @@ import re -from typing import Iterable, Iterator, List, Optional +from typing import Iterable, Iterator, List, Optional, Union import iso639 from langdetect import DetectorFactory, detect_langs, lang_detect_exception @@ -380,3 +380,19 @@ def apply_lang_metadata( yield e else: yield e + + +def _clean_ocr_languages_arg(ocr_languages: Union[List[str], str]) -> str: + """Fix common incorrect definitions for ocr_languages: + defining it as a list, adding extra quotation marks, adding brackets. + Returns a single string of ocr_languages""" + # extract from list + if isinstance(ocr_languages, list): + ocr_languages = "+".join(ocr_languages) + + # remove extra quotations + ocr_languages = re.sub(r"[\"']", "", ocr_languages) + # remove brackets + ocr_languages = re.sub(r"[\[\]]", "", ocr_languages) + + return ocr_languages