Steve Canny 50d75c47d3
rfctr(part): add new decorator to replace four (#3650)
**Summary**
In preparation for pluggable auto-partitioners, add a new metadata
decorator to replace the four existing ones.

**Additional Context**
"Global" metadata items, those applied to all elements on all
partitioners, are applied using a decorator.

Currently there are four decorators where there only needs to be one.
Consolidate those into a single metadata decorator.
One or two additional behaviors of the new decorator will allow us to
remove decorators from delegating partitioners which is a prerequisite
for pluggable auto-partitioners.
2024-09-25 23:15:50 +00:00

283 lines
8.8 KiB
Python

# pyright: reportPrivateUsage=false
"""Unit-test suite for the `unstructured.partition.lang` module."""
from __future__ import annotations
import os
import pathlib
import pytest
from test_unstructured.unit_utils import LogCaptureFixture
from unstructured.documents.elements import (
NarrativeText,
PageBreak,
)
from unstructured.partition.common.lang import (
_clean_ocr_languages_arg,
_convert_language_code_to_pytesseract_lang_code,
apply_lang_metadata,
check_language_args,
detect_languages,
prepare_languages_for_tesseract,
tesseract_to_paddle_language,
)
# -- resolved (absolute) path of the directory that contains this test module --
DIRECTORY = pathlib.Path(__file__).parent.resolve()
# -- `example-docs` directory two levels above `DIRECTORY`; a plain `str`, left un-normalized --
EXAMPLE_DOCS_DIRECTORY = os.path.join(DIRECTORY, os.pardir, os.pardir, "example-docs")
def test_prepare_languages_for_tesseract_with_one_language():
    """A single `"en"` entry is expected to come back as the Tesseract code `"eng"`."""
    tesseract_langs = prepare_languages_for_tesseract(["en"])

    assert tesseract_langs == "eng"
def test_prepare_languages_for_tesseract_with_duplicated_languages():
    """`"en"` and `"eng"` given together are expected to produce a single `"eng"`."""
    tesseract_langs = prepare_languages_for_tesseract(["en", "eng"])

    assert tesseract_langs == "eng"
def test_prepare_languages_for_tesseract_special_case():
    """The codes `"osd"` and `"equ"` are each expected to be returned unchanged."""
    for special_code in ("osd", "equ"):
        assert prepare_languages_for_tesseract([special_code]) == special_code
def test_prepare_languages_for_tesseract_removes_empty_inputs():
    """`"kbd"` contributes nothing to the result; only the codes for `"es"` remain."""
    tesseract_langs = prepare_languages_for_tesseract(["kbd", "es"])

    assert tesseract_langs == "spa+spa_old"
def test_prepare_languages_for_tesseract_includes_variants():
    """`"chi"` is expected to expand to four `chi_*` codes joined with `+`."""
    tesseract_langs = prepare_languages_for_tesseract(["chi"])

    assert tesseract_langs == "chi_sim+chi_sim_vert+chi_tra+chi_tra_vert"
def test_prepare_languages_for_tesseract_with_multiple_languages():
    """Several inputs are expected to be converted and joined with `+` in input order."""
    tesseract_langs = prepare_languages_for_tesseract(["ja", "afr", "en", "equ"])

    assert tesseract_langs == "jpn+jpn_vert+afr+eng+equ"
def test_prepare_languages_for_tesseract_warns_nonstandard_language(caplog: LogCaptureFixture):
    """`"zzz"` is absent from the result and a "not a valid standard language code" log appears."""
    tesseract_langs = prepare_languages_for_tesseract(["zzz", "chi"])

    assert tesseract_langs == "chi_sim+chi_sim_vert+chi_tra+chi_tra_vert"
    assert "not a valid standard language code" in caplog.text
def test_prepare_languages_for_tesseract_warns_non_tesseract_language(caplog: LogCaptureFixture):
    """`"kbd"` is absent from the result and a "not ... supported by Tesseract" log appears."""
    tesseract_langs = prepare_languages_for_tesseract(["kbd", "eng"])

    assert tesseract_langs == "eng"
    assert "not a language supported by Tesseract" in caplog.text
def test_prepare_languages_for_tesseract_None_languages():
    """Passing `None` is expected to raise `ValueError` with the "can not be `None`" message."""
    languages = None

    # -- only the call under test sits inside the `raises` context --
    with pytest.raises(ValueError, match="`languages` can not be `None`"):
        prepare_languages_for_tesseract(languages)
def test_prepare_languages_for_tesseract_no_valid_languages(caplog: LogCaptureFixture):
    """A lone empty-string entry is expected to yield `"eng"` along with a log message."""
    tesseract_langs = prepare_languages_for_tesseract([""])

    assert tesseract_langs == "eng"
    assert "Failed to find any valid standard language code from languages" in caplog.text
@pytest.mark.parametrize(
    ("tesseract_lang", "expected_lang"),
    [
        ("eng", "en"),
        ("chi_sim", "ch"),
        ("chi_tra", "chinese_cht"),
        ("deu", "german"),
        ("jpn", "japan"),
        ("kor", "korean"),
    ],
)
def test_tesseract_to_paddle_language_valid_codes(tesseract_lang: str, expected_lang: str):
    """Each listed Tesseract code is expected to map to the paired PaddleOCR language name."""
    paddle_lang = tesseract_to_paddle_language(tesseract_lang)

    assert paddle_lang == expected_lang
def test_tesseract_to_paddle_language_invalid_codes(caplog: LogCaptureFixture):
    """An unknown code is expected to map to `"en"` and be reported in the log."""
    paddle_lang = tesseract_to_paddle_language("unsupported_lang")

    assert paddle_lang == "en"
    assert "unsupported_lang is not a language code supported by PaddleOCR," in caplog.text
@pytest.mark.parametrize(
    ("tesseract_lang", "expected_lang"),
    [
        ("ENG", "en"),
        ("Fra", "fr"),
        ("DEU", "german"),
    ],
)
def test_tesseract_to_paddle_language_case_sensitivity(tesseract_lang: str, expected_lang: str):
    """Upper- and mixed-case Tesseract codes are expected to map like their lowercase forms."""
    paddle_lang = tesseract_to_paddle_language(tesseract_lang)

    assert paddle_lang == expected_lang
def test_detect_languages_english_auto():
    """With no `languages` argument, a short English sentence is expected to give `["eng"]`."""
    detected = detect_languages("This is a short sentence.")

    assert detected == ["eng"]
def test_detect_languages_english_provided():
    """With `["en"]` passed as languages, the result is expected to be `["eng"]`."""
    detected = detect_languages("This is another short sentence.", ["en"])

    assert detected == ["eng"]
def test_detect_languages_korean_auto():
    """With no `languages` argument, a Korean greeting is expected to give `["kor"]`."""
    detected = detect_languages("안녕하세요")

    assert detected == ["kor"]
def test_detect_languages_gets_multiple_languages():
    """This sentence is expected to produce three language codes in the order asserted."""
    detected = detect_languages("My lubimy mleko i chleb.")

    assert detected == ["ces", "pol", "slk"]
def test_detect_languages_warns_for_auto_and_other_input(caplog: LogCaptureFixture):
    """Mixing `"auto"` with other codes is expected to log that the other codes are ignored."""
    detected = detect_languages("This is another short sentence.", ["en", "auto", "rus"])

    assert detected == ["eng"]
    assert "rest of the inputted languages will be ignored" in caplog.text
def test_detect_languages_raises_TypeError_for_invalid_languages():
    """`detect_languages()` is expected to raise `TypeError` when `languages` is a bare `str`."""
    text = "This is a short sentence."

    with pytest.raises(TypeError):
        # -- The call itself must raise, so there is no return value to compare. A trailing
        # -- `== [...]` here would be a discarded expression that is never evaluated.
        detect_languages(text, languages="eng")  # type: ignore
def test_apply_lang_metadata_has_no_warning_for_PageBreak(caplog: LogCaptureFixture):
    """Per-element detection over a `PageBreak("")` must not log "No features in text."."""
    source_elements = [NarrativeText("Sample text."), PageBreak("")]

    # -- `list()` consumes the returned iterable before the log is inspected --
    # -- NOTE(review): presumably `apply_lang_metadata()` is lazy; confirm --
    _ = list(
        apply_lang_metadata(
            elements=source_elements,
            languages=["auto"],
            detect_language_per_element=True,
        )
    )

    logged_messages = [record.message for record in caplog.records]
    assert "No features in text." not in logged_messages
@pytest.mark.parametrize(
    ("lang_in", "expected_lang"),
    [
        ("en", "eng"),
        ("fr", "fra"),
    ],
)
def test_convert_language_code_to_pytesseract_lang_code(lang_in: str, expected_lang: str):
    """Each two-letter code is expected to convert to the paired three-letter code."""
    pytesseract_code = _convert_language_code_to_pytesseract_lang_code(lang_in)

    assert pytesseract_code == expected_lang
@pytest.mark.parametrize(
    ("input_ocr_langs", "expected"),
    [
        (["eng"], "eng"),  # list
        ('"deu"', "deu"),  # extra quotation marks
        ("[deu]", "deu"),  # brackets
        ("['deu']", "deu"),  # brackets and quotation marks
        (["[deu]"], "deu"),  # list and brackets
        (['"deu"'], "deu"),  # list and quotation marks
        ("deu+spa", "deu+spa"),  # correct input
    ],
)
def test_clean_ocr_languages_arg(input_ocr_langs: str | list[str], expected: str):
    """`_clean_ocr_languages_arg()` is expected to strip list, bracket and quote wrapping.

    The cases pass both `str` and `list[str]` inputs, so `input_ocr_langs` is annotated as the
    union of the two.
    """
    assert _clean_ocr_languages_arg(input_ocr_langs) == expected
def test_detect_languages_handles_spelled_out_languages():
    """A spelled-out language name (`"Spanish"`) is expected to resolve to `["spa"]`."""
    detected = detect_languages(text="Sample text longer than 5 words.", languages=["Spanish"])

    assert detected == ["spa"]
@pytest.mark.parametrize(
    ("languages", "ocr_languages", "expected_langs"),
    [
        (["spa"], "deu", ["spa"]),
        (["spanish"], "english", ["spa"]),
        (["spa"], "[deu]", ["spa"]),
        (["spa"], '"deu"', ["spa"]),
        (["spa"], ["deu"], ["spa"]),
        (["spa"], ["[deu]"], ["spa"]),
        (["spa+deu"], "eng+deu", ["spa", "deu"]),
    ],
)
def test_check_language_args_uses_languages_when_ocr_languages_and_languages_are_both_defined(
    languages: list[str],
    ocr_languages: list[str] | str,
    expected_langs: list[str],
    caplog: LogCaptureFixture,
):
    """When both arguments are given, every returned code must come from `expected_langs`.

    The log text is also expected to mention `ocr_languages`.
    """
    returned_langs = check_language_args(languages=languages, ocr_languages=ocr_languages)

    assert all(lang in expected_langs for lang in returned_langs)  # type: ignore
    assert "ocr_languages" in caplog.text
@pytest.mark.parametrize(
    ("languages", "ocr_languages", "expected_langs"),
    [
        # -- log a message and fall back to `ocr_languages` when `languages` is `[]` or `[""]` --
        ([], "deu", ["deu"]),
        ([""], '"deu"', ["deu"]),
        ([""], "deu", ["deu"]),
        ([""], "[deu]", ["deu"]),
    ],
)
def test_check_language_args_uses_ocr_languages_when_languages_is_empty_or_None(
    languages: list[str],
    ocr_languages: str,
    expected_langs: list[str],
    caplog: LogCaptureFixture,
):
    """With `languages` empty, every returned code must come from `expected_langs`.

    The log text is also expected to mention `ocr_languages`.
    """
    returned_langs = check_language_args(languages=languages, ocr_languages=ocr_languages)

    assert all(lang in expected_langs for lang in returned_langs)  # type: ignore
    assert "ocr_languages" in caplog.text
@pytest.mark.parametrize(
    ("languages", "ocr_languages"),
    [
        ([], None),  # how check_language_args is called from auto.partition()
        ([""], None),
    ],
)
def test_check_language_args_returns_None(
    languages: list[str],
    ocr_languages: None,
):
    """With `languages` empty and `ocr_languages` `None`, the result is expected to be `None`."""
    assert check_language_args(languages=languages, ocr_languages=ocr_languages) is None
def test_check_language_args_returns_auto():
    """When `"auto"` appears among `languages`, only `["auto"]` is expected back."""
    returned_langs = check_language_args(languages=["eng", "spa", "auto"], ocr_languages=None)

    assert returned_langs == ["auto"]
@pytest.mark.parametrize(
    ("languages", "ocr_languages"),
    [
        ([], ["auto"]),
        ([""], "eng+auto"),
    ],
)
def test_check_language_args_raises_error_when_ocr_languages_contains_auto(
    languages: list[str],
    ocr_languages: str | list[str],
):
    """An `ocr_languages` value containing `"auto"` is expected to raise `ValueError`."""
    with pytest.raises(ValueError):
        check_language_args(languages=languages, ocr_languages=ocr_languages)