mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2026-01-08 05:10:11 +00:00
feat: partition_pdf() support language specification for PaddleOCR (#3400)
Closes #3159. This PR extends language specification capability to `PaddleOCR` in addition to `TesseractOCR`. Users can now specify OCR languages for both OCR engines when using `partition_pdf()`. ### Testing ``` os.environ["OCR_AGENT"] = "unstructured.partition.utils.ocr_models.paddle_ocr.OCRAgentPaddle" elements = partition_pdf( filename=<file_path>, strategy=strategy, languages=["chi_sim"], # chinese - simplified infer_table_structure=True, ) ```
This commit is contained in:
parent
6b1d5f28bb
commit
48bdf94656
@ -9,6 +9,7 @@
|
||||
|
||||
### Features
|
||||
|
||||
* **Add support for specifying OCR language to `partition_pdf()`.** Extend language specification capability to `PaddleOCR` in addition to `TesseractOCR`. Users can now specify OCR languages for both OCR engines when using `partition_pdf()`.
|
||||
* **Add AstraDB source connector** Adds support for ingesting documents from AstraDB.
|
||||
|
||||
### Fixes
|
||||
|
||||
@ -21,6 +21,7 @@ from unstructured.partition.utils.constants import (
|
||||
Source,
|
||||
)
|
||||
from unstructured.partition.utils.ocr_models.google_vision_ocr import OCRAgentGoogleVision
|
||||
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
|
||||
from unstructured.partition.utils.ocr_models.paddle_ocr import OCRAgentPaddle
|
||||
from unstructured.partition.utils.ocr_models.tesseract_ocr import (
|
||||
OCRAgentTesseract,
|
||||
@ -85,10 +86,7 @@ def test_get_ocr_layout_from_image_tesseract(monkeypatch):
|
||||
image = Image.new("RGB", (100, 100))
|
||||
|
||||
ocr_agent = OCRAgentTesseract()
|
||||
ocr_layout = ocr_agent.get_layout_from_image(
|
||||
image,
|
||||
ocr_languages="eng",
|
||||
)
|
||||
ocr_layout = ocr_agent.get_layout_from_image(image)
|
||||
|
||||
expected_layout = [
|
||||
TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_TESSERACT),
|
||||
@ -128,7 +126,7 @@ def mock_ocr(*args, **kwargs):
|
||||
]
|
||||
|
||||
|
||||
def monkeypatch_load_agent(language: str):
|
||||
def monkeypatch_load_agent(*args):
|
||||
class MockAgent:
|
||||
def __init__(self):
|
||||
self.ocr = mock_ocr
|
||||
@ -145,10 +143,7 @@ def test_get_ocr_layout_from_image_paddle(monkeypatch):
|
||||
|
||||
image = Image.new("RGB", (100, 100))
|
||||
|
||||
ocr_layout = OCRAgentPaddle().get_layout_from_image(
|
||||
image,
|
||||
ocr_languages="eng",
|
||||
)
|
||||
ocr_layout = OCRAgentPaddle().get_layout_from_image(image)
|
||||
|
||||
expected_layout = [
|
||||
TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_PADDLE),
|
||||
@ -168,10 +163,7 @@ def test_get_ocr_text_from_image_tesseract(monkeypatch):
|
||||
image = Image.new("RGB", (100, 100))
|
||||
|
||||
ocr_agent = OCRAgentTesseract()
|
||||
ocr_text = ocr_agent.get_text_from_image(
|
||||
image,
|
||||
ocr_languages="eng",
|
||||
)
|
||||
ocr_text = ocr_agent.get_text_from_image(image)
|
||||
|
||||
assert ocr_text == "Hello World"
|
||||
|
||||
@ -186,10 +178,7 @@ def test_get_ocr_text_from_image_paddle(monkeypatch):
|
||||
image = Image.new("RGB", (100, 100))
|
||||
|
||||
ocr_agent = OCRAgentPaddle()
|
||||
ocr_text = ocr_agent.get_text_from_image(
|
||||
image,
|
||||
ocr_languages="eng",
|
||||
)
|
||||
ocr_text = ocr_agent.get_text_from_image(image)
|
||||
|
||||
assert ocr_text == "Hello\n\nWorld\n\n!"
|
||||
|
||||
@ -251,7 +240,7 @@ def test_get_ocr_from_image_google_vision(google_vision_client):
|
||||
image = Image.new("RGB", (100, 100))
|
||||
|
||||
ocr_agent = google_vision_client
|
||||
ocr_text = ocr_agent.get_text_from_image(image, ocr_languages="eng")
|
||||
ocr_text = ocr_agent.get_text_from_image(image)
|
||||
|
||||
assert ocr_text == "Hello World!"
|
||||
|
||||
@ -428,7 +417,8 @@ def mock_ocr_layout():
|
||||
|
||||
def test_get_table_tokens(mock_ocr_layout):
|
||||
with patch.object(OCRAgentTesseract, "get_layout_from_image", return_value=mock_ocr_layout):
|
||||
table_tokens = ocr.get_table_tokens(table_element_image=None)
|
||||
ocr_agent = OCRAgent.get_agent(language="eng")
|
||||
table_tokens = ocr.get_table_tokens(table_element_image=None, ocr_agent=ocr_agent)
|
||||
expected_tokens = [
|
||||
{
|
||||
"bbox": [15, 25, 35, 45],
|
||||
|
||||
@ -21,6 +21,7 @@ from unstructured.partition.lang import (
|
||||
check_language_args,
|
||||
detect_languages,
|
||||
prepare_languages_for_tesseract,
|
||||
tesseract_to_paddle_language,
|
||||
)
|
||||
|
||||
DIRECTORY = pathlib.Path(__file__).parent.resolve()
|
||||
@ -84,6 +85,39 @@ def test_prepare_languages_for_tesseract_no_valid_languages(caplog):
|
||||
assert "Failed to find any valid standard language code from languages" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("tesseract_lang", "expected_lang"),
    [
        ("eng", "en"),
        ("chi_sim", "ch"),
        ("chi_tra", "chinese_cht"),
        ("deu", "german"),
        ("jpn", "japan"),
        ("kor", "korean"),
    ],
)
def test_tesseract_to_paddle_language_valid_codes(tesseract_lang, expected_lang):
    """Each known Tesseract code is translated to its PaddleOCR counterpart."""
    paddle_lang = tesseract_to_paddle_language(tesseract_lang)

    assert paddle_lang == expected_lang
|
||||
|
||||
|
||||
def test_tesseract_to_paddle_language_invalid_codes(caplog):
    """An unknown Tesseract code falls back to `en` and the fallback is logged."""
    paddle_lang = tesseract_to_paddle_language("unsupported_lang")

    assert paddle_lang == "en"
    assert "unsupported_lang is not a language code supported by PaddleOCR," in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("tesseract_lang", "expected_lang"),
    [
        ("ENG", "en"),
        ("Fra", "fr"),
        ("DEU", "german"),
    ],
)
def test_tesseract_to_paddle_language_case_sensitivity(tesseract_lang, expected_lang):
    """The lookup ignores the letter case of the Tesseract code."""
    paddle_lang = tesseract_to_paddle_language(tesseract_lang)

    assert paddle_lang == expected_lang
|
||||
|
||||
|
||||
def test_detect_languages_english_auto():
|
||||
text = "This is a short sentence."
|
||||
assert detect_languages(text) == ["eng"]
|
||||
|
||||
@ -35,10 +35,10 @@ class DescribeOCRAgent:
|
||||
_get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
|
||||
get_instance_.return_value = ocr_agent_
|
||||
|
||||
ocr_agent = OCRAgent.get_agent()
|
||||
ocr_agent = OCRAgent.get_agent(language="eng")
|
||||
|
||||
_get_ocr_agent_cls_qname_.assert_called_once_with()
|
||||
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT)
|
||||
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT, "eng")
|
||||
assert ocr_agent is ocr_agent_
|
||||
|
||||
def but_it_raises_when_the_requested_agent_is_not_whitelisted(
|
||||
@ -46,7 +46,7 @@ class DescribeOCRAgent:
|
||||
):
|
||||
_get_ocr_agent_cls_qname_.return_value = "Invalid.Ocr.Agent.Qname"
|
||||
with pytest.raises(ValueError, match="must be set to a whitelisted module"):
|
||||
OCRAgent.get_agent()
|
||||
OCRAgent.get_agent(language="eng")
|
||||
|
||||
@pytest.mark.parametrize("exception_cls", [ImportError, AttributeError])
|
||||
def and_it_raises_when_the_requested_agent_cannot_be_loaded(
|
||||
@ -57,7 +57,7 @@ class DescribeOCRAgent:
|
||||
"unstructured.partition.utils.ocr_models.ocr_interface.importlib.import_module",
|
||||
side_effect=exception_cls,
|
||||
), pytest.raises(RuntimeError, match="Could not get the OCRAgent instance"):
|
||||
OCRAgent.get_agent()
|
||||
OCRAgent.get_agent(language="eng")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("OCR_AGENT", "expected_value"),
|
||||
|
||||
@ -4,6 +4,7 @@ from PIL import Image
|
||||
|
||||
from unstructured.partition.pdf import convert_pdf_to_images
|
||||
from unstructured.partition.pdf_image.ocr import get_table_tokens
|
||||
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
|
||||
@ -21,8 +22,10 @@ def image_or_pdf_to_dataframe(filename: str) -> pd.DataFrame:
|
||||
else:
|
||||
image = Image.open(filename).convert("RGB")
|
||||
|
||||
ocr_agent = OCRAgent.get_agent(language="eng")
|
||||
|
||||
return tables_agent.run_prediction(
|
||||
image, ocr_tokens=get_table_tokens(image), result_format="dataframe"
|
||||
image, ocr_tokens=get_table_tokens(image, ocr_agent), result_format="dataframe"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -98,7 +98,7 @@ def partition_image(
|
||||
"""
|
||||
exactly_one(filename=filename, file=file)
|
||||
|
||||
languages = check_language_args(languages or [], ocr_languages) or ["eng"]
|
||||
languages = check_language_args(languages or [], ocr_languages)
|
||||
|
||||
return partition_pdf_or_image(
|
||||
filename=filename,
|
||||
|
||||
@ -144,6 +144,63 @@ PYTESSERACT_LANG_CODES = [
|
||||
"yor",
|
||||
]
|
||||
|
||||
# Maps lower-case Tesseract language codes to the `lang` value PaddleOCR expects.
# Codes missing from this table are treated as unsupported by
# `tesseract_to_paddle_language()`, which falls back to English for them.
PYTESSERACT_TO_PADDLE_LANG_CODE_MAP = {
    "afr": "af",  # Afrikaans
    "ara": "ar",  # Arabic
    "aze": "az",  # Azerbaijani
    "bel": "be",  # Belarusian
    "bos": "bs",  # Bosnian
    "bul": "bg",  # Bulgarian
    "ces": "cs",  # Czech
    "chi_sim": "ch",  # Simplified Chinese
    "chi_tra": "chinese_cht",  # Traditional Chinese
    "cym": "cy",  # Welsh
    "dan": "da",  # Danish
    "deu": "german",  # German
    "eng": "en",  # English
    "est": "et",  # Estonian
    "fas": "fa",  # Persian
    "fra": "fr",  # French
    "gle": "ga",  # Irish
    "hin": "hi",  # Hindi
    "hrv": "hr",  # Croatian
    "hun": "hu",  # Hungarian
    "ind": "id",  # Indonesian
    "isl": "is",  # Icelandic
    "ita": "it",  # Italian
    "jpn": "japan",  # Japanese
    "kor": "korean",  # Korean
    "kmr": "ku",  # Kurdish
    # `rs_latin` is PaddleOCR's code for Serbian written in Latin script, not for the
    # Latin language; Latin itself is `la`.
    "lat": "la",  # Latin
    "lav": "lv",  # Latvian
    "lit": "lt",  # Lithuanian
    "mar": "mr",  # Marathi
    "mlt": "mt",  # Maltese
    "msa": "ms",  # Malay
    "nep": "ne",  # Nepali
    "nld": "nl",  # Dutch
    "nor": "no",  # Norwegian
    "pol": "pl",  # Polish
    "por": "pt",  # Portuguese
    "ron": "ro",  # Romanian
    "rus": "ru",  # Russian
    "slk": "sk",  # Slovak
    "slv": "sl",  # Slovenian
    "spa": "es",  # Spanish
    "sqi": "sq",  # Albanian
    "srp": "rs_cyrillic",  # Serbian (Cyrillic script)
    "srp_latn": "rs_latin",  # Serbian (Latin script)
    "swa": "sw",  # Swahili
    "swe": "sv",  # Swedish
    "tam": "ta",  # Tamil
    "tel": "te",  # Telugu
    "tur": "tr",  # Turkish
    "uig": "ug",  # Uyghur
    "ukr": "uk",  # Ukrainian
    "urd": "ur",  # Urdu
    "uzb": "uz",  # Uzbek
    "vie": "vi",  # Vietnamese
}
|
||||
|
||||
|
||||
def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) -> str:
|
||||
"""
|
||||
@ -169,6 +226,25 @@ def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) ->
|
||||
return TESSERACT_LANGUAGES_SPLITTER.join(converted_languages)
|
||||
|
||||
|
||||
def tesseract_to_paddle_language(tesseract_language: str) -> str:
    """
    Convert TesseractOCR language code to PaddleOCR language code.

    The lookup is case-insensitive. `tesseract_language` may also be a Tesseract
    multi-language spec joined with `TESSERACT_LANGUAGES_SPLITTER` (the form produced by
    `prepare_languages_for_tesseract()`); in that case the first component that has a
    PaddleOCR equivalent is used.

    :param tesseract_language: str, language code (or joined codes) used in TesseractOCR
    :return: str, corresponding language code for PaddleOCR, or "en" (with a logged
        warning) when no component is supported
    """

    codes = [
        code.strip() for code in tesseract_language.lower().split(TESSERACT_LANGUAGES_SPLITTER)
    ]

    for code in codes:
        lang = PYTESSERACT_TO_PADDLE_LANG_CODE_MAP.get(code)
        if lang:
            if len(codes) > 1:
                # NOTE(review): the Paddle agent is loaded with a single `language` string,
                # so only one of the requested languages can be honored -- confirm should
                # PaddleOCR gain multi-language loading.
                logger.warning(
                    f"PaddleOCR is loaded with a single language; using `{lang}` "
                    f"out of {tesseract_language}."
                )
            return lang

    logger.warning(
        f"{tesseract_language} is not a language code supported by PaddleOCR, "
        f"proceeding with `en` instead."
    )
    return "en"
|
||||
|
||||
|
||||
def check_language_args(languages: list[str], ocr_languages: Optional[str]) -> Optional[list[str]]:
|
||||
"""Handle users defining both `ocr_languages` and `languages`, giving preference to `languages`
|
||||
and converting `ocr_languages` if needed, but defaulting to `None`.
|
||||
|
||||
@ -45,7 +45,11 @@ from unstructured.partition.common import (
|
||||
ocr_data_to_elements,
|
||||
spooled_to_bytes_io_if_needed,
|
||||
)
|
||||
from unstructured.partition.lang import check_language_args, prepare_languages_for_tesseract
|
||||
from unstructured.partition.lang import (
|
||||
check_language_args,
|
||||
prepare_languages_for_tesseract,
|
||||
tesseract_to_paddle_language,
|
||||
)
|
||||
from unstructured.partition.pdf_image.analysis.bbox_visualisation import (
|
||||
AnalysisDrawer,
|
||||
FinalLayoutDrawer,
|
||||
@ -77,6 +81,7 @@ from unstructured.partition.strategies import determine_pdf_or_image_strategy, v
|
||||
from unstructured.partition.text import element_from_text
|
||||
from unstructured.partition.utils.config import env_config
|
||||
from unstructured.partition.utils.constants import (
|
||||
OCR_AGENT_PADDLE,
|
||||
SORT_MODE_BASIC,
|
||||
SORT_MODE_DONT,
|
||||
SORT_MODE_XY_CUT,
|
||||
@ -197,7 +202,7 @@ def partition_pdf(
|
||||
|
||||
exactly_one(filename=filename, file=file)
|
||||
|
||||
languages = check_language_args(languages or [], ocr_languages) or ["eng"]
|
||||
languages = check_language_args(languages or [], ocr_languages)
|
||||
|
||||
return partition_pdf_or_image(
|
||||
filename=filename,
|
||||
@ -227,7 +232,6 @@ def partition_pdf_or_image(
|
||||
include_page_breaks: bool = False,
|
||||
strategy: str = PartitionStrategy.AUTO,
|
||||
infer_table_structure: bool = False,
|
||||
ocr_languages: Optional[str] = None,
|
||||
languages: Optional[list[str]] = None,
|
||||
metadata_last_modified: Optional[str] = None,
|
||||
hi_res_model_name: Optional[str] = None,
|
||||
@ -247,6 +251,9 @@ def partition_pdf_or_image(
|
||||
# that task so as routing design changes, those changes are implemented in a single
|
||||
# function.
|
||||
|
||||
if languages is None:
|
||||
languages = ["eng"]
|
||||
|
||||
# init ability to process .heic files
|
||||
register_heif_opener()
|
||||
|
||||
@ -291,6 +298,10 @@ def partition_pdf_or_image(
|
||||
if file is not None:
|
||||
file.seek(0)
|
||||
|
||||
ocr_languages = prepare_languages_for_tesseract(languages)
|
||||
if env_config.OCR_AGENT == OCR_AGENT_PADDLE:
|
||||
ocr_languages = tesseract_to_paddle_language(ocr_languages)
|
||||
|
||||
if strategy == PartitionStrategy.HI_RES:
|
||||
# NOTE(robinson): Catches a UserWarning that occurs when detection is called
|
||||
with warnings.catch_warnings():
|
||||
@ -302,6 +313,7 @@ def partition_pdf_or_image(
|
||||
infer_table_structure=infer_table_structure,
|
||||
include_page_breaks=include_page_breaks,
|
||||
languages=languages,
|
||||
ocr_languages=ocr_languages,
|
||||
metadata_last_modified=metadata_last_modified or last_modification_date,
|
||||
hi_res_model_name=hi_res_model_name,
|
||||
pdf_text_extractable=pdf_text_extractable,
|
||||
@ -333,6 +345,7 @@ def partition_pdf_or_image(
|
||||
file=file,
|
||||
include_page_breaks=include_page_breaks,
|
||||
languages=languages,
|
||||
ocr_languages=ocr_languages,
|
||||
is_image=is_image,
|
||||
metadata_last_modified=metadata_last_modified or last_modification_date,
|
||||
starting_page_number=starting_page_number,
|
||||
@ -500,6 +513,7 @@ def _partition_pdf_or_image_local(
|
||||
infer_table_structure: bool = False,
|
||||
include_page_breaks: bool = False,
|
||||
languages: Optional[list[str]] = None,
|
||||
ocr_languages: Optional[str] = None,
|
||||
ocr_mode: str = OCRMode.FULL_PAGE.value,
|
||||
model_name: Optional[str] = None, # to be deprecated in favor of `hi_res_model_name`
|
||||
hi_res_model_name: Optional[str] = None,
|
||||
@ -529,11 +543,6 @@ def _partition_pdf_or_image_local(
|
||||
process_file_with_pdfminer,
|
||||
)
|
||||
|
||||
if languages is None:
|
||||
languages = ["eng"]
|
||||
|
||||
ocr_languages = prepare_languages_for_tesseract(languages)
|
||||
|
||||
hi_res_model_name = hi_res_model_name or model_name or default_hi_res_model()
|
||||
if pdf_image_dpi is None:
|
||||
pdf_image_dpi = 300 if hi_res_model_name.startswith("chipper") else 200
|
||||
@ -819,7 +828,8 @@ def _partition_pdf_or_image_with_ocr(
|
||||
filename: str = "",
|
||||
file: Optional[bytes | IO[bytes]] = None,
|
||||
include_page_breaks: bool = False,
|
||||
languages: Optional[list[str]] = ["eng"],
|
||||
languages: Optional[list[str]] = None,
|
||||
ocr_languages: Optional[str] = None,
|
||||
is_image: bool = False,
|
||||
metadata_last_modified: Optional[str] = None,
|
||||
starting_page_number: int = 1,
|
||||
@ -838,6 +848,7 @@ def _partition_pdf_or_image_with_ocr(
|
||||
page_elements = _partition_pdf_or_image_with_ocr_from_image(
|
||||
image=image,
|
||||
languages=languages,
|
||||
ocr_languages=ocr_languages,
|
||||
page_number=page_number,
|
||||
include_page_breaks=include_page_breaks,
|
||||
metadata_last_modified=metadata_last_modified,
|
||||
@ -851,6 +862,7 @@ def _partition_pdf_or_image_with_ocr(
|
||||
page_elements = _partition_pdf_or_image_with_ocr_from_image(
|
||||
image=image,
|
||||
languages=languages,
|
||||
ocr_languages=ocr_languages,
|
||||
page_number=page_number,
|
||||
include_page_breaks=include_page_breaks,
|
||||
metadata_last_modified=metadata_last_modified,
|
||||
@ -864,6 +876,7 @@ def _partition_pdf_or_image_with_ocr(
|
||||
def _partition_pdf_or_image_with_ocr_from_image(
|
||||
image: PILImage.Image,
|
||||
languages: Optional[list[str]] = None,
|
||||
ocr_languages: Optional[str] = None,
|
||||
page_number: int = 1,
|
||||
include_page_breaks: bool = False,
|
||||
metadata_last_modified: Optional[str] = None,
|
||||
@ -874,17 +887,13 @@ def _partition_pdf_or_image_with_ocr_from_image(
|
||||
|
||||
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
|
||||
|
||||
ocr_agent = OCRAgent.get_agent()
|
||||
ocr_languages = prepare_languages_for_tesseract(languages)
|
||||
ocr_agent = OCRAgent.get_agent(language=ocr_languages)
|
||||
|
||||
# NOTE(christine): `unstructured_pytesseract.image_to_string()` returns sorted text
|
||||
if ocr_agent.is_text_sorted():
|
||||
sort_mode = SORT_MODE_DONT
|
||||
|
||||
ocr_data = ocr_agent.get_layout_elements_from_image(
|
||||
image=image,
|
||||
ocr_languages=ocr_languages,
|
||||
)
|
||||
ocr_data = ocr_agent.get_layout_elements_from_image(image=image)
|
||||
|
||||
metadata = ElementMetadata(
|
||||
last_modified=metadata_last_modified,
|
||||
|
||||
@ -16,7 +16,7 @@ from unstructured.metrics.table.table_formats import SimpleTableCell
|
||||
from unstructured.partition.pdf_image.analysis.bbox_visualisation import OCRLayoutDrawer
|
||||
from unstructured.partition.pdf_image.pdf_image_utils import pad_element_bboxes, valid_text
|
||||
from unstructured.partition.utils.config import env_config
|
||||
from unstructured.partition.utils.constants import OCR_AGENT_TESSERACT, OCRMode
|
||||
from unstructured.partition.utils.constants import OCRMode
|
||||
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
@ -200,12 +200,9 @@ def supplement_page_layout_with_ocr(
|
||||
with no text and add text from OCR to each element.
|
||||
"""
|
||||
|
||||
ocr_agent = OCRAgent.get_agent()
|
||||
ocr_agent = OCRAgent.get_agent(language=ocr_languages)
|
||||
if ocr_mode == OCRMode.FULL_PAGE.value:
|
||||
ocr_layout = ocr_agent.get_layout_from_image(
|
||||
image,
|
||||
ocr_languages=ocr_languages,
|
||||
)
|
||||
ocr_layout = ocr_agent.get_layout_from_image(image)
|
||||
if ocr_drawer:
|
||||
ocr_drawer.add_ocred_page(ocr_layout)
|
||||
page_layout.elements[:] = merge_out_layout_with_ocr_layout(
|
||||
@ -227,10 +224,7 @@ def supplement_page_layout_with_ocr(
|
||||
)
|
||||
# Note(yuming): instead of getting OCR layout, we just need
|
||||
# the text extracted from OCR for individual elements
|
||||
text_from_ocr = ocr_agent.get_text_from_image(
|
||||
cropped_image,
|
||||
ocr_languages=ocr_languages,
|
||||
)
|
||||
text_from_ocr = ocr_agent.get_text_from_image(cropped_image)
|
||||
element.text = text_from_ocr
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -250,7 +244,6 @@ def supplement_page_layout_with_ocr(
|
||||
elements=cast(List["LayoutElement"], page_layout.elements),
|
||||
image=image,
|
||||
tables_agent=tables.tables_agent,
|
||||
ocr_languages=ocr_languages,
|
||||
ocr_agent=ocr_agent,
|
||||
extracted_regions=extracted_regions,
|
||||
)
|
||||
@ -263,8 +256,7 @@ def supplement_element_with_table_extraction(
|
||||
elements: List["LayoutElement"],
|
||||
image: PILImage.Image,
|
||||
tables_agent: "UnstructuredTableTransformerModel",
|
||||
ocr_languages: str = "eng",
|
||||
ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
|
||||
ocr_agent,
|
||||
extracted_regions: Optional[List["TextRegion"]] = None,
|
||||
) -> List["LayoutElement"]:
|
||||
"""Supplement the existing layout with table extraction. Any Table elements
|
||||
@ -288,7 +280,6 @@ def supplement_element_with_table_extraction(
|
||||
)
|
||||
table_tokens = get_table_tokens(
|
||||
table_element_image=cropped_image,
|
||||
ocr_languages=ocr_languages,
|
||||
ocr_agent=ocr_agent,
|
||||
extracted_regions=extracted_regions,
|
||||
table_element=padded_element,
|
||||
@ -312,17 +303,13 @@ def supplement_element_with_table_extraction(
|
||||
|
||||
def get_table_tokens(
|
||||
table_element_image: PILImage.Image,
|
||||
ocr_languages: str = "eng",
|
||||
ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
|
||||
ocr_agent: OCRAgent,
|
||||
extracted_regions: Optional[List["TextRegion"]] = None,
|
||||
table_element: Optional["LayoutElement"] = None,
|
||||
) -> List[dict[str, Any]]:
|
||||
"""Get OCR tokens from either paddleocr or tesseract"""
|
||||
|
||||
ocr_layout = ocr_agent.get_layout_from_image(
|
||||
image=table_element_image,
|
||||
ocr_languages=ocr_languages,
|
||||
)
|
||||
ocr_layout = ocr_agent.get_layout_from_image(image=table_element_image)
|
||||
table_tokens = []
|
||||
for ocr_region in ocr_layout:
|
||||
table_tokens.append(
|
||||
|
||||
@ -43,10 +43,6 @@ OCR_AGENT_MODULES_WHITELIST = os.getenv(
|
||||
|
||||
UNSTRUCTURED_INCLUDE_DEBUG_METADATA = os.getenv("UNSTRUCTURED_INCLUDE_DEBUG_METADATA", False)
|
||||
|
||||
# Note(yuming): Default language for paddle OCR
|
||||
# soon will be able to specify the language down through partition() as well
|
||||
DEFAULT_PADDLE_LANG = os.getenv("DEFAULT_PADDLE_LANG", "en")
|
||||
|
||||
# this field is defined by pytesseract/unstructured.pytesseract
|
||||
TESSERACT_TEXT_HEIGHT = "height"
|
||||
|
||||
|
||||
@ -25,17 +25,17 @@ class OCRAgent(ABC):
|
||||
"""Defines the interface for an Optical Character Recognition (OCR) service."""
|
||||
|
||||
@classmethod
|
||||
def get_agent(cls) -> OCRAgent:
|
||||
def get_agent(cls, language: str) -> OCRAgent:
|
||||
"""Get the configured OCRAgent instance.
|
||||
|
||||
The OCR package used by the agent is determined by the `OCR_AGENT` environment variable.
|
||||
"""
|
||||
ocr_agent_cls_qname = cls._get_ocr_agent_cls_qname()
|
||||
return cls.get_instance(ocr_agent_cls_qname)
|
||||
return cls.get_instance(ocr_agent_cls_qname, language)
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_instance(ocr_agent_module: str) -> "OCRAgent":
|
||||
def get_instance(ocr_agent_module: str, language: str) -> "OCRAgent":
|
||||
module_name, class_name = ocr_agent_module.rsplit(".", 1)
|
||||
if module_name not in OCR_AGENT_MODULES_WHITELIST:
|
||||
raise ValueError(
|
||||
@ -46,7 +46,7 @@ class OCRAgent(ABC):
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
loaded_class = getattr(module, class_name)
|
||||
return loaded_class()
|
||||
return loaded_class(language)
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.error(f"Failed to get OCRAgent instance: {e}")
|
||||
raise RuntimeError(
|
||||
@ -55,19 +55,15 @@ class OCRAgent(ABC):
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_layout_elements_from_image(
|
||||
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||
) -> list[LayoutElement]:
|
||||
def get_layout_elements_from_image(self, image: PILImage.Image) -> list[LayoutElement]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_layout_from_image(
|
||||
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||
) -> list[TextRegion]:
|
||||
def get_layout_from_image(self, image: PILImage.Image) -> list[TextRegion]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
|
||||
def get_text_from_image(self, image: PILImage.Image) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -7,7 +7,7 @@ from PIL import Image as PILImage
|
||||
|
||||
from unstructured.documents.elements import ElementType
|
||||
from unstructured.logger import logger, trace_logger
|
||||
from unstructured.partition.utils.constants import DEFAULT_PADDLE_LANG, Source
|
||||
from unstructured.partition.utils.constants import Source
|
||||
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
@ -19,10 +19,10 @@ if TYPE_CHECKING:
|
||||
class OCRAgentPaddle(OCRAgent):
|
||||
"""OCR service implementation for PaddleOCR."""
|
||||
|
||||
def __init__(self):
|
||||
self.agent = self.load_agent()
|
||||
def __init__(self, language: str = "en"):
|
||||
self.agent = self.load_agent(language)
|
||||
|
||||
def load_agent(self, language: str = DEFAULT_PADDLE_LANG):
|
||||
def load_agent(self, language: str):
|
||||
"""Loads the PaddleOCR agent as a global variable to ensure that we only load it once."""
|
||||
|
||||
import paddle
|
||||
@ -59,16 +59,14 @@ class OCRAgentPaddle(OCRAgent):
|
||||
)
|
||||
return paddle_ocr
|
||||
|
||||
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
|
||||
def get_text_from_image(self, image: PILImage.Image) -> str:
|
||||
ocr_regions = self.get_layout_from_image(image)
|
||||
return "\n\n".join([r.text for r in ocr_regions])
|
||||
|
||||
def is_text_sorted(self):
|
||||
return False
|
||||
|
||||
def get_layout_from_image(
|
||||
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||
) -> list[TextRegion]:
|
||||
def get_layout_from_image(self, image: PILImage.Image) -> list[TextRegion]:
|
||||
"""Get the OCR regions from image as a list of text regions with paddle."""
|
||||
|
||||
trace_logger.detail("Processing entire page OCR with paddle...")
|
||||
@ -82,15 +80,10 @@ class OCRAgentPaddle(OCRAgent):
|
||||
return ocr_regions
|
||||
|
||||
@requires_dependencies("unstructured_inference")
|
||||
def get_layout_elements_from_image(
|
||||
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||
) -> list[LayoutElement]:
|
||||
def get_layout_elements_from_image(self, image: PILImage.Image) -> list[LayoutElement]:
|
||||
from unstructured.partition.pdf_image.inference_utils import build_layout_element
|
||||
|
||||
ocr_regions = self.get_layout_from_image(
|
||||
image,
|
||||
ocr_languages=ocr_languages,
|
||||
)
|
||||
ocr_regions = self.get_layout_from_image(image)
|
||||
|
||||
# NOTE(christine): For paddle, there is no difference in `ocr_layout` and `ocr_text` in
|
||||
# terms of grouping because we get ocr_text from `ocr_layout, so the first two grouping
|
||||
|
||||
@ -33,22 +33,23 @@ if "OMP_THREAD_LIMIT" not in os.environ:
|
||||
class OCRAgentTesseract(OCRAgent):
|
||||
"""OCR service implementation for Tesseract."""
|
||||
|
||||
def __init__(self, language: str = "eng"):
|
||||
self.language = language
|
||||
|
||||
def is_text_sorted(self):
|
||||
return True
|
||||
|
||||
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
|
||||
return unstructured_pytesseract.image_to_string(np.array(image), lang=ocr_languages)
|
||||
def get_text_from_image(self, image: PILImage.Image) -> str:
|
||||
return unstructured_pytesseract.image_to_string(np.array(image), lang=self.language)
|
||||
|
||||
def get_layout_from_image(
|
||||
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||
) -> List[TextRegion]:
|
||||
def get_layout_from_image(self, image: PILImage.Image) -> List[TextRegion]:
|
||||
"""Get the OCR regions from image as a list of text regions with tesseract."""
|
||||
|
||||
trace_logger.detail("Processing entire page OCR with tesseract...")
|
||||
zoom = 1
|
||||
ocr_df: pd.DataFrame = unstructured_pytesseract.image_to_data(
|
||||
np.array(image),
|
||||
lang=ocr_languages,
|
||||
lang=self.language,
|
||||
output_type=Output.DATAFRAME,
|
||||
)
|
||||
ocr_df = ocr_df.dropna()
|
||||
@ -77,7 +78,7 @@ class OCRAgentTesseract(OCRAgent):
|
||||
)
|
||||
ocr_df = unstructured_pytesseract.image_to_data(
|
||||
np.array(zoom_image(image, zoom)),
|
||||
lang=ocr_languages,
|
||||
lang=self.language,
|
||||
output_type=Output.DATAFRAME,
|
||||
)
|
||||
ocr_df = ocr_df.dropna()
|
||||
@ -87,17 +88,12 @@ class OCRAgentTesseract(OCRAgent):
|
||||
return ocr_regions
|
||||
|
||||
@requires_dependencies("unstructured_inference")
|
||||
def get_layout_elements_from_image(
|
||||
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||
) -> List["LayoutElement"]:
|
||||
def get_layout_elements_from_image(self, image: PILImage.Image) -> List["LayoutElement"]:
|
||||
from unstructured.partition.pdf_image.inference_utils import (
|
||||
build_layout_elements_from_ocr_regions,
|
||||
)
|
||||
|
||||
ocr_regions = self.get_layout_from_image(
|
||||
image,
|
||||
ocr_languages=ocr_languages,
|
||||
)
|
||||
ocr_regions = self.get_layout_from_image(image)
|
||||
|
||||
# NOTE(christine): For tesseract, the ocr_text returned by
|
||||
# `unstructured_pytesseract.image_to_string()` doesn't contain bounding box data but is
|
||||
@ -106,10 +102,7 @@ class OCRAgentTesseract(OCRAgent):
|
||||
# grouped. Therefore, we need to first group the `ocr_layout` by `ocr_text` and then merge
|
||||
# the text regions in each group to create a list of layout elements.
|
||||
|
||||
ocr_text = self.get_text_from_image(
|
||||
image,
|
||||
ocr_languages=ocr_languages,
|
||||
)
|
||||
ocr_text = self.get_text_from_image(image)
|
||||
|
||||
return build_layout_elements_from_ocr_regions(
|
||||
ocr_regions=ocr_regions,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user