rfctr: improve typing in OCR modules (#2893)

**Summary**
In preparation for using OCR for partitioners other than PDF, clean up
typing in the OCR modules.
This commit is contained in:
Steve Canny 2024-04-15 20:55:35 -07:00 committed by GitHub
parent cb1e91058e
commit f752849c41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 60 additions and 72 deletions

View File

@ -1,4 +1,4 @@
## 0.13.3-dev5 ## 0.13.3-dev6
### Enhancements ### Enhancements

View File

@ -1 +1 @@
__version__ = "0.13.3-dev5" # pragma: no cover __version__ = "0.13.3-dev6" # pragma: no cover

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import os import os
import tempfile import tempfile
from typing import TYPE_CHECKING, BinaryIO, Dict, List, Optional, Union, cast from typing import IO, TYPE_CHECKING, Any, List, Optional, cast
import pdf2image import pdf2image
@ -39,7 +41,7 @@ if "OMP_THREAD_LIMIT" not in os.environ:
def process_data_with_ocr( def process_data_with_ocr(
data: Union[bytes, BinaryIO], data: bytes | IO[bytes],
out_layout: "DocumentLayout", out_layout: "DocumentLayout",
extracted_layout: List[List["TextRegion"]], extracted_layout: List[List["TextRegion"]],
is_image: bool = False, is_image: bool = False,
@ -76,7 +78,8 @@ def process_data_with_ocr(
DocumentLayout: The merged layout information obtained after OCR processing. DocumentLayout: The merged layout information obtained after OCR processing.
""" """
with tempfile.NamedTemporaryFile() as tmp_file: with tempfile.NamedTemporaryFile() as tmp_file:
tmp_file.write(data.read() if hasattr(data, "read") else data) data_bytes = data if isinstance(data, bytes) else data.read()
tmp_file.write(data_bytes)
tmp_file.flush() tmp_file.flush()
merged_layouts = process_file_with_ocr( merged_layouts = process_file_with_ocr(
filename=tmp_file.name, filename=tmp_file.name,
@ -131,7 +134,7 @@ def process_file_with_ocr(
from unstructured_inference.inference.layout import DocumentLayout from unstructured_inference.inference.layout import DocumentLayout
merged_page_layouts = [] merged_page_layouts: list[PageLayout] = []
try: try:
if is_image: if is_image:
with PILImage.open(filename) as images: with PILImage.open(filename) as images:
@ -182,7 +185,7 @@ def process_file_with_ocr(
@requires_dependencies("unstructured_inference") @requires_dependencies("unstructured_inference")
def supplement_page_layout_with_ocr( def supplement_page_layout_with_ocr(
page_layout: "PageLayout", page_layout: "PageLayout",
image: PILImage, image: PILImage.Image,
infer_table_structure: bool = False, infer_table_structure: bool = False,
ocr_languages: str = "eng", ocr_languages: str = "eng",
ocr_mode: str = OCRMode.FULL_PAGE.value, ocr_mode: str = OCRMode.FULL_PAGE.value,
@ -254,7 +257,7 @@ def supplement_page_layout_with_ocr(
def supplement_element_with_table_extraction( def supplement_element_with_table_extraction(
elements: List["LayoutElement"], elements: List["LayoutElement"],
image: PILImage, image: PILImage.Image,
tables_agent: "UnstructuredTableTransformerModel", tables_agent: "UnstructuredTableTransformerModel",
ocr_languages: str = "eng", ocr_languages: str = "eng",
ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT), ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
@ -289,12 +292,12 @@ def supplement_element_with_table_extraction(
def get_table_tokens( def get_table_tokens(
table_element_image: PILImage, table_element_image: PILImage.Image,
ocr_languages: str = "eng", ocr_languages: str = "eng",
ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT), ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
extracted_regions: Optional[List["TextRegion"]] = None, extracted_regions: Optional[List["TextRegion"]] = None,
table_element: Optional["LayoutElement"] = None, table_element: Optional["LayoutElement"] = None,
) -> List[Dict]: ) -> List[dict[str, Any]]:
"""Get OCR tokens from either paddleocr or tesseract""" """Get OCR tokens from either paddleocr or tesseract"""
ocr_layout = ocr_agent.get_layout_from_image( ocr_layout = ocr_agent.get_layout_from_image(
@ -417,7 +420,7 @@ def supplement_layout_with_ocr_elements(
build_layout_elements_from_ocr_regions, build_layout_elements_from_ocr_regions,
) )
ocr_regions_to_remove = [] ocr_regions_to_remove: list[TextRegion] = []
for ocr_region in ocr_layout: for ocr_region in ocr_layout:
for el in layout: for el in layout:
ocr_region_is_subregion_of_out_el = ocr_region.bbox.is_almost_subregion_of( ocr_region_is_subregion_of_out_el = ocr_region.bbox.is_almost_subregion_of(

View File

@ -1,12 +1,14 @@
from __future__ import annotations
import functools import functools
import importlib import importlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List from typing import TYPE_CHECKING
from unstructured.partition.utils.constants import OCR_AGENT_MODULES_WHITELIST from unstructured.partition.utils.constants import OCR_AGENT_MODULES_WHITELIST
if TYPE_CHECKING: if TYPE_CHECKING:
from PIL import PILImage from PIL import Image as PILImage
from unstructured_inference.inference.elements import TextRegion from unstructured_inference.inference.elements import TextRegion
from unstructured_inference.inference.layoutelement import ( from unstructured_inference.inference.layoutelement import (
LayoutElement, LayoutElement,
@ -14,31 +16,26 @@ if TYPE_CHECKING:
class OCRAgent(ABC): class OCRAgent(ABC):
def __init__(self): """Defines the interface for an Optical Character Recognition (OCR) service."""
self.agent = self.load_agent()
@abstractmethod
def load_agent(self, language: str) -> Any:
pass
@abstractmethod @abstractmethod
def is_text_sorted(self) -> bool: def is_text_sorted(self) -> bool:
pass pass
@abstractmethod @abstractmethod
def get_text_from_image(self, image: "PILImage", ocr_languages: str = "eng") -> str: def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
pass pass
@abstractmethod @abstractmethod
def get_layout_from_image( def get_layout_from_image(
self, image: "PILImage", ocr_languages: str = "eng" self, image: PILImage.Image, ocr_languages: str = "eng"
) -> List["TextRegion"]: ) -> list[TextRegion]:
pass pass
@abstractmethod @abstractmethod
def get_layout_elements_from_image( def get_layout_elements_from_image(
self, image: "PILImage", ocr_languages: str = "eng" self, image: PILImage.Image, ocr_languages: str = "eng"
) -> List["LayoutElement"]: ) -> list[LayoutElement]:
pass pass
@staticmethod @staticmethod
@ -51,6 +48,6 @@ class OCRAgent(ABC):
return loaded_class() return loaded_class()
else: else:
raise ValueError( raise ValueError(
f"Environment variable OCR_AGENT module name {module_name}", f"Environment variable OCR_AGENT module name {module_name}, must be set to a"
f" must be set to a whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}.", f" whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}.",
) )

View File

@ -1,14 +1,13 @@
from typing import TYPE_CHECKING, List from __future__ import annotations
from typing import TYPE_CHECKING, Any
import numpy as np import numpy as np
from PIL import Image as PILImage from PIL import Image as PILImage
from unstructured.documents.elements import ElementType from unstructured.documents.elements import ElementType
from unstructured.logger import logger from unstructured.logger import logger
from unstructured.partition.utils.constants import ( from unstructured.partition.utils.constants import DEFAULT_PADDLE_LANG, Source
DEFAULT_PADDLE_LANG,
Source,
)
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.utils import requires_dependencies from unstructured.utils import requires_dependencies
@ -18,12 +17,17 @@ if TYPE_CHECKING:
class OCRAgentPaddle(OCRAgent): class OCRAgentPaddle(OCRAgent):
"""OCR service implementation for PaddleOCR."""
def __init__(self):
self.agent = self.load_agent()
def load_agent(self, language: str = DEFAULT_PADDLE_LANG): def load_agent(self, language: str = DEFAULT_PADDLE_LANG):
"""Loads the PaddleOCR agent as a global variable to ensure that we only load it once."""
import paddle import paddle
from unstructured_paddleocr import PaddleOCR from unstructured_paddleocr import PaddleOCR
"""Loads the PaddleOCR agent as a global variable to ensure that we only load it once."""
# Disable signal handlers at C++ level upon failing # Disable signal handlers at C++ level upon failing
# ref: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/ # ref: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/
# disable_signal_handler_en.html#disable-signal-handler # disable_signal_handler_en.html#disable-signal-handler
@ -55,7 +59,7 @@ class OCRAgentPaddle(OCRAgent):
) )
return paddle_ocr return paddle_ocr
def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> str: def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
ocr_regions = self.get_layout_from_image(image) ocr_regions = self.get_layout_from_image(image)
return "\n\n".join([r.text for r in ocr_regions]) return "\n\n".join([r.text for r in ocr_regions])
@ -63,8 +67,8 @@ class OCRAgentPaddle(OCRAgent):
return False return False
def get_layout_from_image( def get_layout_from_image(
self, image: PILImage, ocr_languages: str = "eng" self, image: PILImage.Image, ocr_languages: str = "eng"
) -> List["TextRegion"]: ) -> list[TextRegion]:
"""Get the OCR regions from image as a list of text regions with paddle.""" """Get the OCR regions from image as a list of text regions with paddle."""
logger.info("Processing entire page OCR with paddle...") logger.info("Processing entire page OCR with paddle...")
@ -79,8 +83,8 @@ class OCRAgentPaddle(OCRAgent):
@requires_dependencies("unstructured_inference") @requires_dependencies("unstructured_inference")
def get_layout_elements_from_image( def get_layout_elements_from_image(
self, image: PILImage, ocr_languages: str = "eng" self, image: PILImage.Image, ocr_languages: str = "eng"
) -> List["LayoutElement"]: ) -> list[LayoutElement]:
from unstructured.partition.pdf_image.inference_utils import build_layout_element from unstructured.partition.pdf_image.inference_utils import build_layout_element
ocr_regions = self.get_layout_from_image( ocr_regions = self.get_layout_from_image(
@ -102,10 +106,8 @@ class OCRAgentPaddle(OCRAgent):
] ]
@requires_dependencies("unstructured_inference") @requires_dependencies("unstructured_inference")
def parse_data(self, ocr_data: list) -> List["TextRegion"]: def parse_data(self, ocr_data: list[Any]) -> list[TextRegion]:
""" """Parse the OCR result data to extract a list of TextRegion objects from paddle.
Parse the OCR result data to extract a list of TextRegion objects from
paddle.
The function processes the OCR result dictionary, looking for bounding The function processes the OCR result dictionary, looking for bounding
box information and associated text to create instances of the TextRegion box information and associated text to create instances of the TextRegion
@ -115,7 +117,7 @@ class OCRAgentPaddle(OCRAgent):
- ocr_data (list): A list containing the OCR result data - ocr_data (list): A list containing the OCR result data
Returns: Returns:
- List[TextRegion]: A list of TextRegion objects, each representing a - list[TextRegion]: A list of TextRegion objects, each representing a
detected text region within the OCR-ed image. detected text region within the OCR-ed image.
Note: Note:
@ -125,7 +127,7 @@ class OCRAgentPaddle(OCRAgent):
from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords
text_regions = [] text_regions: list[TextRegion] = []
for idx in range(len(ocr_data)): for idx in range(len(ocr_data)):
res = ocr_data[idx] res = ocr_data[idx]
if not res: if not res:
@ -142,12 +144,7 @@ class OCRAgentPaddle(OCRAgent):
cleaned_text = text.strip() cleaned_text = text.strip()
if cleaned_text: if cleaned_text:
text_region = build_text_region_from_coords( text_region = build_text_region_from_coords(
x1, x1, y1, x2, y2, text=cleaned_text, source=Source.OCR_PADDLE
y1,
x2,
y2,
text=cleaned_text,
source=Source.OCR_PADDLE,
) )
text_regions.append(text_region) text_regions.append(text_region)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
import cv2 import cv2
@ -26,21 +28,17 @@ if TYPE_CHECKING:
class OCRAgentTesseract(OCRAgent): class OCRAgentTesseract(OCRAgent):
def load_agent(self): """OCR service implementation for Tesseract."""
pass
def is_text_sorted(self): def is_text_sorted(self):
return True return True
def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> str: def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
return unstructured_pytesseract.image_to_string( return unstructured_pytesseract.image_to_string(np.array(image), lang=ocr_languages)
np.array(image),
lang=ocr_languages,
)
def get_layout_from_image( def get_layout_from_image(
self, image: PILImage, ocr_languages: str = "eng" self, image: PILImage.Image, ocr_languages: str = "eng"
) -> List["TextRegion"]: ) -> List[TextRegion]:
"""Get the OCR regions from image as a list of text regions with tesseract.""" """Get the OCR regions from image as a list of text regions with tesseract."""
logger.info("Processing entire page OCR with tesseract...") logger.info("Processing entire page OCR with tesseract...")
@ -58,7 +56,7 @@ class OCRAgentTesseract(OCRAgent):
# depend on type of characters (font, language, etc); be careful about this # depend on type of characters (font, language, etc); be careful about this
# functionality # functionality
text_height = ocr_df[TESSERACT_TEXT_HEIGHT].quantile( text_height = ocr_df[TESSERACT_TEXT_HEIGHT].quantile(
env_config.TESSERACT_TEXT_HEIGHT_QUANTILE, env_config.TESSERACT_TEXT_HEIGHT_QUANTILE
) )
if ( if (
text_height < env_config.TESSERACT_MIN_TEXT_HEIGHT text_height < env_config.TESSERACT_MIN_TEXT_HEIGHT
@ -87,7 +85,7 @@ class OCRAgentTesseract(OCRAgent):
@requires_dependencies("unstructured_inference") @requires_dependencies("unstructured_inference")
def get_layout_elements_from_image( def get_layout_elements_from_image(
self, image: PILImage, ocr_languages: str = "eng" self, image: PILImage.Image, ocr_languages: str = "eng"
) -> List["LayoutElement"]: ) -> List["LayoutElement"]:
from unstructured.partition.pdf_image.inference_utils import ( from unstructured.partition.pdf_image.inference_utils import (
build_layout_elements_from_ocr_regions, build_layout_elements_from_ocr_regions,
@ -118,9 +116,7 @@ class OCRAgentTesseract(OCRAgent):
@requires_dependencies("unstructured_inference") @requires_dependencies("unstructured_inference")
def parse_data(self, ocr_data: pd.DataFrame, zoom: float = 1) -> List["TextRegion"]: def parse_data(self, ocr_data: pd.DataFrame, zoom: float = 1) -> List["TextRegion"]:
""" """Parse the OCR result data to extract a list of TextRegion objects from tesseract.
Parse the OCR result data to extract a list of TextRegion objects from
tesseract.
The function processes the OCR result data frame, looking for bounding The function processes the OCR result data frame, looking for bounding
box information and associated text to create instances of the TextRegion box information and associated text to create instances of the TextRegion
@ -150,7 +146,7 @@ class OCRAgentTesseract(OCRAgent):
if zoom <= 0: if zoom <= 0:
zoom = 1 zoom = 1
text_regions = [] text_regions: list[TextRegion] = []
for idtx in ocr_data.itertuples(): for idtx in ocr_data.itertuples():
text = idtx.text text = idtx.text
if not text: if not text:
@ -164,19 +160,14 @@ class OCRAgentTesseract(OCRAgent):
x2 = (idtx.left + idtx.width) / zoom x2 = (idtx.left + idtx.width) / zoom
y2 = (idtx.top + idtx.height) / zoom y2 = (idtx.top + idtx.height) / zoom
text_region = build_text_region_from_coords( text_region = build_text_region_from_coords(
x1, x1, y1, x2, y2, text=cleaned_text, source=Source.OCR_TESSERACT
y1,
x2,
y2,
text=cleaned_text,
source=Source.OCR_TESSERACT,
) )
text_regions.append(text_region) text_regions.append(text_region)
return text_regions return text_regions
def zoom_image(image: PILImage, zoom: float = 1) -> PILImage: def zoom_image(image: PILImage.Image, zoom: float = 1) -> PILImage.Image:
"""scale an image based on the zoom factor using cv2; the scaled image is post processed by """scale an image based on the zoom factor using cv2; the scaled image is post processed by
dilation then erosion to improve edge sharpness for OCR tasks""" dilation then erosion to improve edge sharpness for OCR tasks"""
if zoom <= 0: if zoom <= 0: