diff --git a/CHANGELOG.md b/CHANGELOG.md index f38b81744..7ed9414e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ ### Features +* **Add support for specifying OCR language to `partition_pdf()`.** Extend language specification capability to `PaddleOCR` in addition to `TesseractOCR`. Users can now specify OCR languages for both OCR engines when using `partition_pdf()`. * **Add AstraDB source connector** Adds support for ingesting documents from AstraDB. ### Fixes diff --git a/test_unstructured/partition/pdf_image/test_ocr.py b/test_unstructured/partition/pdf_image/test_ocr.py index 175682156..e07fb23d3 100644 --- a/test_unstructured/partition/pdf_image/test_ocr.py +++ b/test_unstructured/partition/pdf_image/test_ocr.py @@ -21,6 +21,7 @@ from unstructured.partition.utils.constants import ( Source, ) from unstructured.partition.utils.ocr_models.google_vision_ocr import OCRAgentGoogleVision +from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent from unstructured.partition.utils.ocr_models.paddle_ocr import OCRAgentPaddle from unstructured.partition.utils.ocr_models.tesseract_ocr import ( OCRAgentTesseract, @@ -85,10 +86,7 @@ def test_get_ocr_layout_from_image_tesseract(monkeypatch): image = Image.new("RGB", (100, 100)) ocr_agent = OCRAgentTesseract() - ocr_layout = ocr_agent.get_layout_from_image( - image, - ocr_languages="eng", - ) + ocr_layout = ocr_agent.get_layout_from_image(image) expected_layout = [ TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_TESSERACT), @@ -128,7 +126,7 @@ def mock_ocr(*args, **kwargs): ] -def monkeypatch_load_agent(language: str): +def monkeypatch_load_agent(*args): class MockAgent: def __init__(self): self.ocr = mock_ocr @@ -145,10 +143,7 @@ def test_get_ocr_layout_from_image_paddle(monkeypatch): image = Image.new("RGB", (100, 100)) - ocr_layout = OCRAgentPaddle().get_layout_from_image( - image, - ocr_languages="eng", - ) + ocr_layout = OCRAgentPaddle().get_layout_from_image(image) 
expected_layout = [ TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_PADDLE), @@ -168,10 +163,7 @@ def test_get_ocr_text_from_image_tesseract(monkeypatch): image = Image.new("RGB", (100, 100)) ocr_agent = OCRAgentTesseract() - ocr_text = ocr_agent.get_text_from_image( - image, - ocr_languages="eng", - ) + ocr_text = ocr_agent.get_text_from_image(image) assert ocr_text == "Hello World" @@ -186,10 +178,7 @@ def test_get_ocr_text_from_image_paddle(monkeypatch): image = Image.new("RGB", (100, 100)) ocr_agent = OCRAgentPaddle() - ocr_text = ocr_agent.get_text_from_image( - image, - ocr_languages="eng", - ) + ocr_text = ocr_agent.get_text_from_image(image) assert ocr_text == "Hello\n\nWorld\n\n!" @@ -251,7 +240,7 @@ def test_get_ocr_from_image_google_vision(google_vision_client): image = Image.new("RGB", (100, 100)) ocr_agent = google_vision_client - ocr_text = ocr_agent.get_text_from_image(image, ocr_languages="eng") + ocr_text = ocr_agent.get_text_from_image(image) assert ocr_text == "Hello World!" 
@@ -428,7 +417,8 @@ def mock_ocr_layout(): def test_get_table_tokens(mock_ocr_layout): with patch.object(OCRAgentTesseract, "get_layout_from_image", return_value=mock_ocr_layout): - table_tokens = ocr.get_table_tokens(table_element_image=None) + ocr_agent = OCRAgent.get_agent(language="eng") + table_tokens = ocr.get_table_tokens(table_element_image=None, ocr_agent=ocr_agent) expected_tokens = [ { "bbox": [15, 25, 35, 45], diff --git a/test_unstructured/partition/test_lang.py b/test_unstructured/partition/test_lang.py index e8bb5fedf..f1d743a8b 100644 --- a/test_unstructured/partition/test_lang.py +++ b/test_unstructured/partition/test_lang.py @@ -21,6 +21,7 @@ from unstructured.partition.lang import ( check_language_args, detect_languages, prepare_languages_for_tesseract, + tesseract_to_paddle_language, ) DIRECTORY = pathlib.Path(__file__).parent.resolve() @@ -84,6 +85,39 @@ def test_prepare_languages_for_tesseract_no_valid_languages(caplog): assert "Failed to find any valid standard language code from languages" in caplog.text +@pytest.mark.parametrize( + ("tesseract_lang", "expected_lang"), + [ + ("eng", "en"), + ("chi_sim", "ch"), + ("chi_tra", "chinese_cht"), + ("deu", "german"), + ("jpn", "japan"), + ("kor", "korean"), + ], +) +def test_tesseract_to_paddle_language_valid_codes(tesseract_lang, expected_lang): + assert expected_lang == tesseract_to_paddle_language(tesseract_lang) + + +def test_tesseract_to_paddle_language_invalid_codes(caplog): + tesseract_lang = "unsupported_lang" + assert tesseract_to_paddle_language(tesseract_lang) == "en" + assert "unsupported_lang is not a language code supported by PaddleOCR," in caplog.text + + +@pytest.mark.parametrize( + ("tesseract_lang", "expected_lang"), + [ + ("ENG", "en"), + ("Fra", "fr"), + ("DEU", "german"), + ], +) +def test_tesseract_to_paddle_language_case_sensitivity(tesseract_lang, expected_lang): + assert expected_lang == tesseract_to_paddle_language(tesseract_lang) + + def 
test_detect_languages_english_auto(): text = "This is a short sentence." assert detect_languages(text) == ["eng"] diff --git a/test_unstructured/partition/utils/ocr_models/test_ocr_interface.py b/test_unstructured/partition/utils/ocr_models/test_ocr_interface.py index 381e9546e..28623372a 100644 --- a/test_unstructured/partition/utils/ocr_models/test_ocr_interface.py +++ b/test_unstructured/partition/utils/ocr_models/test_ocr_interface.py @@ -35,10 +35,10 @@ class DescribeOCRAgent: _get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT get_instance_.return_value = ocr_agent_ - ocr_agent = OCRAgent.get_agent() + ocr_agent = OCRAgent.get_agent(language="eng") _get_ocr_agent_cls_qname_.assert_called_once_with() - get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT) + get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT, "eng") assert ocr_agent is ocr_agent_ def but_it_raises_when_the_requested_agent_is_not_whitelisted( @@ -46,7 +46,7 @@ class DescribeOCRAgent: ): _get_ocr_agent_cls_qname_.return_value = "Invalid.Ocr.Agent.Qname" with pytest.raises(ValueError, match="must be set to a whitelisted module"): - OCRAgent.get_agent() + OCRAgent.get_agent(language="eng") @pytest.mark.parametrize("exception_cls", [ImportError, AttributeError]) def and_it_raises_when_the_requested_agent_cannot_be_loaded( @@ -57,7 +57,7 @@ class DescribeOCRAgent: "unstructured.partition.utils.ocr_models.ocr_interface.importlib.import_module", side_effect=exception_cls, ), pytest.raises(RuntimeError, match="Could not get the OCRAgent instance"): - OCRAgent.get_agent() + OCRAgent.get_agent(language="eng") @pytest.mark.parametrize( ("OCR_AGENT", "expected_value"), diff --git a/unstructured/metrics/table_structure.py b/unstructured/metrics/table_structure.py index 53f9171ad..e139f47ca 100644 --- a/unstructured/metrics/table_structure.py +++ b/unstructured/metrics/table_structure.py @@ -4,6 +4,7 @@ from PIL import Image from unstructured.partition.pdf import convert_pdf_to_images 
from unstructured.partition.pdf_image.ocr import get_table_tokens +from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent from unstructured.utils import requires_dependencies @@ -21,8 +22,10 @@ def image_or_pdf_to_dataframe(filename: str) -> pd.DataFrame: else: image = Image.open(filename).convert("RGB") + ocr_agent = OCRAgent.get_agent(language="eng") + return tables_agent.run_prediction( - image, ocr_tokens=get_table_tokens(image), result_format="dataframe" + image, ocr_tokens=get_table_tokens(image, ocr_agent), result_format="dataframe" ) diff --git a/unstructured/partition/image.py b/unstructured/partition/image.py index 8fa98db38..a9a9ea963 100644 --- a/unstructured/partition/image.py +++ b/unstructured/partition/image.py @@ -98,7 +98,7 @@ def partition_image( """ exactly_one(filename=filename, file=file) - languages = check_language_args(languages or [], ocr_languages) or ["eng"] + languages = check_language_args(languages or [], ocr_languages) return partition_pdf_or_image( filename=filename, diff --git a/unstructured/partition/lang.py b/unstructured/partition/lang.py index 18fc6c05d..d91d64882 100644 --- a/unstructured/partition/lang.py +++ b/unstructured/partition/lang.py @@ -144,6 +144,63 @@ PYTESSERACT_LANG_CODES = [ "yor", ] +PYTESSERACT_TO_PADDLE_LANG_CODE_MAP = { + "afr": "af", # Afrikaans + "ara": "ar", # Arabic + "aze": "az", # Azerbaijani + "bel": "be", # Belarusian + "bos": "bs", # Bosnian + "bul": "bg", # Bulgarian + "ces": "cs", # Czech + "chi_sim": "ch", # Simplified Chinese + "chi_tra": "chinese_cht", # Traditional Chinese + "cym": "cy", # Welsh + "dan": "da", # Danish + "deu": "german", # German + "eng": "en", # English + "est": "et", # Estonian + "fas": "fa", # Persian + "fra": "fr", # French + "gle": "ga", # Irish + "hin": "hi", # Hindi + "hrv": "hr", # Croatian + "hun": "hu", # Hungarian + "ind": "id", # Indonesian + "isl": "is", # Icelandic + "ita": "it", # Italian + "jpn": "japan", # Japanese + "kor": "korean", # 
Korean + "kmr": "ku", # Kurdish + "lat": "rs_latin", # Latin + "lav": "lv", # Latvian + "lit": "lt", # Lithuanian + "mar": "mr", # Marathi + "mlt": "mt", # Maltese + "msa": "ms", # Malay + "nep": "ne", # Nepali + "nld": "nl", # Dutch + "nor": "no", # Norwegian + "pol": "pl", # Polish + "por": "pt", # Portuguese + "ron": "ro", # Romanian + "rus": "ru", # Russian + "slk": "sk", # Slovak + "slv": "sl", # Slovenian + "spa": "es", # Spanish + "sqi": "sq", # Albanian + "srp": "rs_cyrillic", # Serbian + "swa": "sw", # Swahili + "swe": "sv", # Swedish + "tam": "ta", # Tamil + "tel": "te", # Telugu + "tur": "tr", # Turkish + "uig": "ug", # Uyghur + "ukr": "uk", # Ukrainian + "urd": "ur", # Urdu + "uzb": "uz", # Uzbek + "vie": "vi", # Vietnamese +} + def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) -> str: """ @@ -169,6 +226,25 @@ def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) -> return TESSERACT_LANGUAGES_SPLITTER.join(converted_languages) +def tesseract_to_paddle_language(tesseract_language: str) -> str: + """ + Convert TesseractOCR language code to PaddleOCR language code. + + :param tesseract_language: str, language code used in TesseractOCR + :return: str, corresponding language code for PaddleOCR, or `en` if not found + """ + + lang = PYTESSERACT_TO_PADDLE_LANG_CODE_MAP.get(tesseract_language.lower()) + if not lang: + logger.warning( + f"{tesseract_language} is not a language code supported by PaddleOCR, " + f"proceeding with `en` instead." + ) + return "en" + + return lang + + def check_language_args(languages: list[str], ocr_languages: Optional[str]) -> Optional[list[str]]: """Handle users defining both `ocr_languages` and `languages`, giving preference to `languages` and converting `ocr_languages` if needed, but defaulting to `None. 
diff --git a/unstructured/partition/pdf.py b/unstructured/partition/pdf.py index b808b4e4a..e52fe0f8c 100644 --- a/unstructured/partition/pdf.py +++ b/unstructured/partition/pdf.py @@ -45,7 +45,11 @@ from unstructured.partition.common import ( ocr_data_to_elements, spooled_to_bytes_io_if_needed, ) -from unstructured.partition.lang import check_language_args, prepare_languages_for_tesseract +from unstructured.partition.lang import ( + check_language_args, + prepare_languages_for_tesseract, + tesseract_to_paddle_language, +) from unstructured.partition.pdf_image.analysis.bbox_visualisation import ( AnalysisDrawer, FinalLayoutDrawer, @@ -77,6 +81,7 @@ from unstructured.partition.strategies import determine_pdf_or_image_strategy, v from unstructured.partition.text import element_from_text from unstructured.partition.utils.config import env_config from unstructured.partition.utils.constants import ( + OCR_AGENT_PADDLE, SORT_MODE_BASIC, SORT_MODE_DONT, SORT_MODE_XY_CUT, @@ -197,7 +202,7 @@ def partition_pdf( exactly_one(filename=filename, file=file) - languages = check_language_args(languages or [], ocr_languages) or ["eng"] + languages = check_language_args(languages or [], ocr_languages) return partition_pdf_or_image( filename=filename, @@ -227,7 +232,6 @@ def partition_pdf_or_image( include_page_breaks: bool = False, strategy: str = PartitionStrategy.AUTO, infer_table_structure: bool = False, - ocr_languages: Optional[str] = None, languages: Optional[list[str]] = None, metadata_last_modified: Optional[str] = None, hi_res_model_name: Optional[str] = None, @@ -247,6 +251,9 @@ def partition_pdf_or_image( # that task so as routing design changes, those changes are implemented in a single # function. 
+ if languages is None: + languages = ["eng"] + # init ability to process .heic files register_heif_opener() @@ -291,6 +298,10 @@ def partition_pdf_or_image( if file is not None: file.seek(0) + ocr_languages = prepare_languages_for_tesseract(languages) + if env_config.OCR_AGENT == OCR_AGENT_PADDLE: + ocr_languages = tesseract_to_paddle_language(ocr_languages) + if strategy == PartitionStrategy.HI_RES: # NOTE(robinson): Catches a UserWarning that occurs when detection is called with warnings.catch_warnings(): @@ -302,6 +313,7 @@ def partition_pdf_or_image( infer_table_structure=infer_table_structure, include_page_breaks=include_page_breaks, languages=languages, + ocr_languages=ocr_languages, metadata_last_modified=metadata_last_modified or last_modification_date, hi_res_model_name=hi_res_model_name, pdf_text_extractable=pdf_text_extractable, @@ -333,6 +345,7 @@ def partition_pdf_or_image( file=file, include_page_breaks=include_page_breaks, languages=languages, + ocr_languages=ocr_languages, is_image=is_image, metadata_last_modified=metadata_last_modified or last_modification_date, starting_page_number=starting_page_number, @@ -500,6 +513,7 @@ def _partition_pdf_or_image_local( infer_table_structure: bool = False, include_page_breaks: bool = False, languages: Optional[list[str]] = None, + ocr_languages: Optional[str] = None, ocr_mode: str = OCRMode.FULL_PAGE.value, model_name: Optional[str] = None, # to be deprecated in favor of `hi_res_model_name` hi_res_model_name: Optional[str] = None, @@ -529,11 +543,6 @@ def _partition_pdf_or_image_local( process_file_with_pdfminer, ) - if languages is None: - languages = ["eng"] - - ocr_languages = prepare_languages_for_tesseract(languages) - hi_res_model_name = hi_res_model_name or model_name or default_hi_res_model() if pdf_image_dpi is None: pdf_image_dpi = 300 if hi_res_model_name.startswith("chipper") else 200 @@ -819,7 +828,8 @@ def _partition_pdf_or_image_with_ocr( filename: str = "", file: Optional[bytes | IO[bytes]] = 
None, include_page_breaks: bool = False, - languages: Optional[list[str]] = ["eng"], + languages: Optional[list[str]] = None, + ocr_languages: Optional[str] = None, is_image: bool = False, metadata_last_modified: Optional[str] = None, starting_page_number: int = 1, @@ -838,6 +848,7 @@ def _partition_pdf_or_image_with_ocr( page_elements = _partition_pdf_or_image_with_ocr_from_image( image=image, languages=languages, + ocr_languages=ocr_languages, page_number=page_number, include_page_breaks=include_page_breaks, metadata_last_modified=metadata_last_modified, @@ -851,6 +862,7 @@ def _partition_pdf_or_image_with_ocr( page_elements = _partition_pdf_or_image_with_ocr_from_image( image=image, languages=languages, + ocr_languages=ocr_languages, page_number=page_number, include_page_breaks=include_page_breaks, metadata_last_modified=metadata_last_modified, @@ -864,6 +876,7 @@ def _partition_pdf_or_image_with_ocr( def _partition_pdf_or_image_with_ocr_from_image( image: PILImage.Image, languages: Optional[list[str]] = None, + ocr_languages: Optional[str] = None, page_number: int = 1, include_page_breaks: bool = False, metadata_last_modified: Optional[str] = None, @@ -874,17 +887,13 @@ def _partition_pdf_or_image_with_ocr_from_image( from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent - ocr_agent = OCRAgent.get_agent() - ocr_languages = prepare_languages_for_tesseract(languages) + ocr_agent = OCRAgent.get_agent(language=ocr_languages) # NOTE(christine): `unstructured_pytesseract.image_to_string()` returns sorted text if ocr_agent.is_text_sorted(): sort_mode = SORT_MODE_DONT - ocr_data = ocr_agent.get_layout_elements_from_image( - image=image, - ocr_languages=ocr_languages, - ) + ocr_data = ocr_agent.get_layout_elements_from_image(image=image) metadata = ElementMetadata( last_modified=metadata_last_modified, diff --git a/unstructured/partition/pdf_image/ocr.py b/unstructured/partition/pdf_image/ocr.py index b64955159..92a46c0e0 100644 --- 
a/unstructured/partition/pdf_image/ocr.py +++ b/unstructured/partition/pdf_image/ocr.py @@ -16,7 +16,7 @@ from unstructured.metrics.table.table_formats import SimpleTableCell from unstructured.partition.pdf_image.analysis.bbox_visualisation import OCRLayoutDrawer from unstructured.partition.pdf_image.pdf_image_utils import pad_element_bboxes, valid_text from unstructured.partition.utils.config import env_config -from unstructured.partition.utils.constants import OCR_AGENT_TESSERACT, OCRMode +from unstructured.partition.utils.constants import OCRMode from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent from unstructured.utils import requires_dependencies @@ -200,12 +200,9 @@ def supplement_page_layout_with_ocr( with no text and add text from OCR to each element. """ - ocr_agent = OCRAgent.get_agent() + ocr_agent = OCRAgent.get_agent(language=ocr_languages) if ocr_mode == OCRMode.FULL_PAGE.value: - ocr_layout = ocr_agent.get_layout_from_image( - image, - ocr_languages=ocr_languages, - ) + ocr_layout = ocr_agent.get_layout_from_image(image) if ocr_drawer: ocr_drawer.add_ocred_page(ocr_layout) page_layout.elements[:] = merge_out_layout_with_ocr_layout( @@ -227,10 +224,7 @@ def supplement_page_layout_with_ocr( ) # Note(yuming): instead of getting OCR layout, we just need # the text extraced from OCR for individual elements - text_from_ocr = ocr_agent.get_text_from_image( - cropped_image, - ocr_languages=ocr_languages, - ) + text_from_ocr = ocr_agent.get_text_from_image(cropped_image) element.text = text_from_ocr else: raise ValueError( @@ -250,7 +244,6 @@ def supplement_page_layout_with_ocr( elements=cast(List["LayoutElement"], page_layout.elements), image=image, tables_agent=tables.tables_agent, - ocr_languages=ocr_languages, ocr_agent=ocr_agent, extracted_regions=extracted_regions, ) @@ -263,8 +256,7 @@ def supplement_element_with_table_extraction( elements: List["LayoutElement"], image: PILImage.Image, tables_agent: 
"UnstructuredTableTransformerModel", - ocr_languages: str = "eng", - ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT), + ocr_agent, extracted_regions: Optional[List["TextRegion"]] = None, ) -> List["LayoutElement"]: """Supplement the existing layout with table extraction. Any Table elements @@ -288,7 +280,6 @@ def supplement_element_with_table_extraction( ) table_tokens = get_table_tokens( table_element_image=cropped_image, - ocr_languages=ocr_languages, ocr_agent=ocr_agent, extracted_regions=extracted_regions, table_element=padded_element, @@ -312,17 +303,13 @@ def supplement_element_with_table_extraction( def get_table_tokens( table_element_image: PILImage.Image, - ocr_languages: str = "eng", - ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT), + ocr_agent: OCRAgent, extracted_regions: Optional[List["TextRegion"]] = None, table_element: Optional["LayoutElement"] = None, ) -> List[dict[str, Any]]: """Get OCR tokens from either paddleocr or tesseract""" - ocr_layout = ocr_agent.get_layout_from_image( - image=table_element_image, - ocr_languages=ocr_languages, - ) + ocr_layout = ocr_agent.get_layout_from_image(image=table_element_image) table_tokens = [] for ocr_region in ocr_layout: table_tokens.append( diff --git a/unstructured/partition/utils/constants.py b/unstructured/partition/utils/constants.py index c1864a9e5..9d802080c 100644 --- a/unstructured/partition/utils/constants.py +++ b/unstructured/partition/utils/constants.py @@ -43,10 +43,6 @@ OCR_AGENT_MODULES_WHITELIST = os.getenv( UNSTRUCTURED_INCLUDE_DEBUG_METADATA = os.getenv("UNSTRUCTURED_INCLUDE_DEBUG_METADATA", False) -# Note(yuming): Default language for paddle OCR -# soon will be able to specify the language down through partition() as well -DEFAULT_PADDLE_LANG = os.getenv("DEFAULT_PADDLE_LANG", "en") - # this field is defined by pytesseract/unstructured.pytesseract TESSERACT_TEXT_HEIGHT = "height" diff --git a/unstructured/partition/utils/ocr_models/ocr_interface.py 
b/unstructured/partition/utils/ocr_models/ocr_interface.py index 3efb6986a..6808d5aad 100644 --- a/unstructured/partition/utils/ocr_models/ocr_interface.py +++ b/unstructured/partition/utils/ocr_models/ocr_interface.py @@ -25,17 +25,17 @@ class OCRAgent(ABC): """Defines the interface for an Optical Character Recognition (OCR) service.""" @classmethod - def get_agent(cls) -> OCRAgent: + def get_agent(cls, language: str) -> OCRAgent: """Get the configured OCRAgent instance. The OCR package used by the agent is determined by the `OCR_AGENT` environment variable. """ ocr_agent_cls_qname = cls._get_ocr_agent_cls_qname() - return cls.get_instance(ocr_agent_cls_qname) + return cls.get_instance(ocr_agent_cls_qname, language) @staticmethod @functools.lru_cache(maxsize=None) - def get_instance(ocr_agent_module: str) -> "OCRAgent": + def get_instance(ocr_agent_module: str, language: str) -> "OCRAgent": module_name, class_name = ocr_agent_module.rsplit(".", 1) if module_name not in OCR_AGENT_MODULES_WHITELIST: raise ValueError( @@ -46,7 +46,7 @@ class OCRAgent(ABC): try: module = importlib.import_module(module_name) loaded_class = getattr(module, class_name) - return loaded_class() + return loaded_class(language) except (ImportError, AttributeError) as e: logger.error(f"Failed to get OCRAgent instance: {e}") raise RuntimeError( @@ -55,19 +55,15 @@ class OCRAgent(ABC): ) @abstractmethod - def get_layout_elements_from_image( - self, image: PILImage.Image, ocr_languages: str = "eng" - ) -> list[LayoutElement]: + def get_layout_elements_from_image(self, image: PILImage.Image) -> list[LayoutElement]: pass @abstractmethod - def get_layout_from_image( - self, image: PILImage.Image, ocr_languages: str = "eng" - ) -> list[TextRegion]: + def get_layout_from_image(self, image: PILImage.Image) -> list[TextRegion]: pass @abstractmethod - def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str: + def get_text_from_image(self, image: PILImage.Image) -> str: 
pass @abstractmethod diff --git a/unstructured/partition/utils/ocr_models/paddle_ocr.py b/unstructured/partition/utils/ocr_models/paddle_ocr.py index def01a99e..7e57a1f8a 100644 --- a/unstructured/partition/utils/ocr_models/paddle_ocr.py +++ b/unstructured/partition/utils/ocr_models/paddle_ocr.py @@ -7,7 +7,7 @@ from PIL import Image as PILImage from unstructured.documents.elements import ElementType from unstructured.logger import logger, trace_logger -from unstructured.partition.utils.constants import DEFAULT_PADDLE_LANG, Source +from unstructured.partition.utils.constants import Source from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent from unstructured.utils import requires_dependencies @@ -19,10 +19,10 @@ if TYPE_CHECKING: class OCRAgentPaddle(OCRAgent): """OCR service implementation for PaddleOCR.""" - def __init__(self): - self.agent = self.load_agent() + def __init__(self, language: str = "en"): + self.agent = self.load_agent(language) - def load_agent(self, language: str = DEFAULT_PADDLE_LANG): + def load_agent(self, language: str): """Loads the PaddleOCR agent as a global variable to ensure that we only load it once.""" import paddle @@ -59,16 +59,14 @@ class OCRAgentPaddle(OCRAgent): ) return paddle_ocr - def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str: + def get_text_from_image(self, image: PILImage.Image) -> str: ocr_regions = self.get_layout_from_image(image) return "\n\n".join([r.text for r in ocr_regions]) def is_text_sorted(self): return False - def get_layout_from_image( - self, image: PILImage.Image, ocr_languages: str = "eng" - ) -> list[TextRegion]: + def get_layout_from_image(self, image: PILImage.Image) -> list[TextRegion]: """Get the OCR regions from image as a list of text regions with paddle.""" trace_logger.detail("Processing entire page OCR with paddle...") @@ -82,15 +80,10 @@ class OCRAgentPaddle(OCRAgent): return ocr_regions @requires_dependencies("unstructured_inference") 
- def get_layout_elements_from_image( - self, image: PILImage.Image, ocr_languages: str = "eng" - ) -> list[LayoutElement]: + def get_layout_elements_from_image(self, image: PILImage.Image) -> list[LayoutElement]: from unstructured.partition.pdf_image.inference_utils import build_layout_element - ocr_regions = self.get_layout_from_image( - image, - ocr_languages=ocr_languages, - ) + ocr_regions = self.get_layout_from_image(image) # NOTE(christine): For paddle, there is no difference in `ocr_layout` and `ocr_text` in # terms of grouping because we get ocr_text from `ocr_layout, so the first two grouping diff --git a/unstructured/partition/utils/ocr_models/tesseract_ocr.py b/unstructured/partition/utils/ocr_models/tesseract_ocr.py index 7f2c87424..46eb8a0cb 100644 --- a/unstructured/partition/utils/ocr_models/tesseract_ocr.py +++ b/unstructured/partition/utils/ocr_models/tesseract_ocr.py @@ -33,22 +33,23 @@ if "OMP_THREAD_LIMIT" not in os.environ: class OCRAgentTesseract(OCRAgent): """OCR service implementation for Tesseract.""" + def __init__(self, language: str = "eng"): + self.language = language + def is_text_sorted(self): return True - def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str: - return unstructured_pytesseract.image_to_string(np.array(image), lang=ocr_languages) + def get_text_from_image(self, image: PILImage.Image) -> str: + return unstructured_pytesseract.image_to_string(np.array(image), lang=self.language) - def get_layout_from_image( - self, image: PILImage.Image, ocr_languages: str = "eng" - ) -> List[TextRegion]: + def get_layout_from_image(self, image: PILImage.Image) -> List[TextRegion]: """Get the OCR regions from image as a list of text regions with tesseract.""" trace_logger.detail("Processing entire page OCR with tesseract...") zoom = 1 ocr_df: pd.DataFrame = unstructured_pytesseract.image_to_data( np.array(image), - lang=ocr_languages, + lang=self.language, output_type=Output.DATAFRAME, ) ocr_df = 
ocr_df.dropna() @@ -77,7 +78,7 @@ class OCRAgentTesseract(OCRAgent): ) ocr_df = unstructured_pytesseract.image_to_data( np.array(zoom_image(image, zoom)), - lang=ocr_languages, + lang=self.language, output_type=Output.DATAFRAME, ) ocr_df = ocr_df.dropna() @@ -87,17 +88,12 @@ class OCRAgentTesseract(OCRAgent): return ocr_regions @requires_dependencies("unstructured_inference") - def get_layout_elements_from_image( - self, image: PILImage.Image, ocr_languages: str = "eng" - ) -> List["LayoutElement"]: + def get_layout_elements_from_image(self, image: PILImage.Image) -> List["LayoutElement"]: from unstructured.partition.pdf_image.inference_utils import ( build_layout_elements_from_ocr_regions, ) - ocr_regions = self.get_layout_from_image( - image, - ocr_languages=ocr_languages, - ) + ocr_regions = self.get_layout_from_image(image) # NOTE(christine): For tesseract, the ocr_text returned by # `unstructured_pytesseract.image_to_string()` doesn't contain bounding box data but is @@ -106,10 +102,7 @@ class OCRAgentTesseract(OCRAgent): # grouped. Therefore, we need to first group the `ocr_layout` by `ocr_text` and then merge # the text regions in each group to create a list of layout elements. - ocr_text = self.get_text_from_image( - image, - ocr_languages=ocr_languages, - ) + ocr_text = self.get_text_from_image(image) return build_layout_elements_from_ocr_regions( ocr_regions=ocr_regions,