rfctr: improve typing in OCR modules (#2893)

**Summary**
In preparation for using OCR for partitioners other than PDF, clean up
typing in the OCR module.
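
The changes follow one pattern throughout: add `from __future__ import annotations` so built-in generics and `X | Y` unions work as annotations on all supported Pythons, replace the nonexistent `PIL.PILImage` name with the `PIL.Image` module, and annotate empty-list accumulators. A condensed before/after sketch of that pattern (function and variable names are illustrative, not from the diff):

```python
# Before: typing-module generics and a Union for the data parameter
# from typing import BinaryIO, Dict, List, Union
# def process(data: Union[bytes, BinaryIO]) -> List[Dict]: ...

# After: postponed evaluation permits modern annotation syntax on 3.8/3.9
from __future__ import annotations

from typing import IO, Any

from PIL import Image as PILImage  # the module, so PILImage.Image is the type


def process(data: bytes | IO[bytes], image: PILImage.Image) -> list[dict[str, Any]]:
    results: list[dict[str, Any]] = []  # annotated accumulator, not a bare []
    return results
```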
Steve Canny 2024-04-15 20:55:35 -07:00 committed by GitHub
parent cb1e91058e
commit f752849c41
6 changed files with 60 additions and 72 deletions

View File

@@ -1,4 +1,4 @@
-## 0.13.3-dev5
+## 0.13.3-dev6

 ### Enhancements

View File

@ -1 +1 @@
__version__ = "0.13.3-dev5" # pragma: no cover
__version__ = "0.13.3-dev6" # pragma: no cover

View File

@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import os
 import tempfile
-from typing import TYPE_CHECKING, BinaryIO, Dict, List, Optional, Union, cast
+from typing import IO, TYPE_CHECKING, Any, List, Optional, cast

 import pdf2image
@@ -39,7 +41,7 @@ if "OMP_THREAD_LIMIT" not in os.environ:
 def process_data_with_ocr(
-    data: Union[bytes, BinaryIO],
+    data: bytes | IO[bytes],
     out_layout: "DocumentLayout",
     extracted_layout: List[List["TextRegion"]],
     is_image: bool = False,
@@ -76,7 +78,8 @@ def process_data_with_ocr(
         DocumentLayout: The merged layout information obtained after OCR processing.
     """
     with tempfile.NamedTemporaryFile() as tmp_file:
-        tmp_file.write(data.read() if hasattr(data, "read") else data)
+        data_bytes = data if isinstance(data, bytes) else data.read()
+        tmp_file.write(data_bytes)
         tmp_file.flush()
         merged_layouts = process_file_with_ocr(
             filename=tmp_file.name,
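
The replacement swaps a duck-typed `hasattr()` check for an `isinstance()` check, which type checkers can use to narrow the `bytes | IO[bytes]` union. A minimal sketch of the idiom (`to_bytes` and the file name are illustrative, not part of the module):

```python
from typing import IO


def to_bytes(data: bytes | IO[bytes]) -> bytes:
    """Return raw bytes whether given bytes or a readable binary stream."""
    # isinstance() narrows the union for the type checker, unlike hasattr()
    return data if isinstance(data, bytes) else data.read()


payload = to_bytes(b"raw PDF bytes")  # bytes pass straight through
with open("doc.pdf", "rb") as f:      # any IO[bytes] gets read()
    stream_bytes = to_bytes(f)
```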
@@ -131,7 +134,7 @@ def process_file_with_ocr(
     from unstructured_inference.inference.layout import DocumentLayout

-    merged_page_layouts = []
+    merged_page_layouts: list[PageLayout] = []
     try:
         if is_image:
             with PILImage.open(filename) as images:
@@ -182,7 +185,7 @@ def process_file_with_ocr(
 @requires_dependencies("unstructured_inference")
 def supplement_page_layout_with_ocr(
     page_layout: "PageLayout",
-    image: PILImage,
+    image: PILImage.Image,
     infer_table_structure: bool = False,
     ocr_languages: str = "eng",
     ocr_mode: str = OCRMode.FULL_PAGE.value,
@@ -254,7 +257,7 @@ def supplement_page_layout_with_ocr(
 def supplement_element_with_table_extraction(
     elements: List["LayoutElement"],
-    image: PILImage,
+    image: PILImage.Image,
     tables_agent: "UnstructuredTableTransformerModel",
     ocr_languages: str = "eng",
     ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
@@ -289,12 +292,12 @@ def supplement_element_with_table_extraction(
 def get_table_tokens(
-    table_element_image: PILImage,
+    table_element_image: PILImage.Image,
     ocr_languages: str = "eng",
     ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
     extracted_regions: Optional[List["TextRegion"]] = None,
     table_element: Optional["LayoutElement"] = None,
-) -> List[Dict]:
+) -> List[dict[str, Any]]:
     """Get OCR tokens from either paddleocr or tesseract"""

     ocr_layout = ocr_agent.get_layout_from_image(
@@ -417,7 +420,7 @@ def supplement_layout_with_ocr_elements(
         build_layout_elements_from_ocr_regions,
     )

-    ocr_regions_to_remove = []
+    ocr_regions_to_remove: list[TextRegion] = []
     for ocr_region in ocr_layout:
         for el in layout:
             ocr_region_is_subregion_of_out_el = ocr_region.bbox.is_almost_subregion_of(

View File

@@ -1,12 +1,14 @@
+from __future__ import annotations
+
 import functools
 import importlib
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List
+from typing import TYPE_CHECKING

 from unstructured.partition.utils.constants import OCR_AGENT_MODULES_WHITELIST

 if TYPE_CHECKING:
-    from PIL import PILImage
+    from PIL import Image as PILImage
     from unstructured_inference.inference.elements import TextRegion
     from unstructured_inference.inference.layoutelement import (
         LayoutElement,
@@ -14,31 +16,26 @@ if TYPE_CHECKING:
 class OCRAgent(ABC):
-    def __init__(self):
-        self.agent = self.load_agent()
-
-    @abstractmethod
-    def load_agent(self, language: str) -> Any:
-        pass
+    """Defines the interface for an Optical Character Recognition (OCR) service."""

     @abstractmethod
     def is_text_sorted(self) -> bool:
         pass

     @abstractmethod
-    def get_text_from_image(self, image: "PILImage", ocr_languages: str = "eng") -> str:
+    def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
         pass

     @abstractmethod
     def get_layout_from_image(
-        self, image: "PILImage", ocr_languages: str = "eng"
-    ) -> List["TextRegion"]:
+        self, image: PILImage.Image, ocr_languages: str = "eng"
+    ) -> list[TextRegion]:
         pass

     @abstractmethod
     def get_layout_elements_from_image(
-        self, image: "PILImage", ocr_languages: str = "eng"
-    ) -> List["LayoutElement"]:
+        self, image: PILImage.Image, ocr_languages: str = "eng"
+    ) -> list[LayoutElement]:
         pass

     @staticmethod
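
The slimmed-down ABC now carries only the recognition methods; `__init__` and `load_agent` are left to each implementation. A minimal sketch of a conforming subclass, assuming the names imported in the file above (`EchoOCRAgent` and its trivial behavior are hypothetical, for illustration only):

```python
class EchoOCRAgent(OCRAgent):
    """Hypothetical OCRAgent implementation showing the required interface."""

    def is_text_sorted(self) -> bool:
        return True  # regions are reported in reading order

    def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
        return ""  # a real agent would run OCR on `image` here

    def get_layout_from_image(
        self, image: PILImage.Image, ocr_languages: str = "eng"
    ) -> list[TextRegion]:
        return []

    def get_layout_elements_from_image(
        self, image: PILImage.Image, ocr_languages: str = "eng"
    ) -> list[LayoutElement]:
        return []
```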
@@ -51,6 +48,6 @@ class OCRAgent(ABC):
             return loaded_class()
         else:
             raise ValueError(
-                f"Environment variable OCR_AGENT module name {module_name}",
-                f" must be set to a whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}.",
+                f"Environment variable OCR_AGENT module name {module_name}, must be set to a"
+                f" whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}.",
            )
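
The old call passed two f-strings to `ValueError` as separate arguments (note the comma), producing a tuple-like message; the new code concatenates them into one string. For context, `get_instance` resolves the agent class from the `OCR_AGENT` environment variable and rejects modules outside the whitelist. A hedged usage sketch (the module path shown is illustrative; the accepted values are whatever `OCR_AGENT_MODULES_WHITELIST` defines):

```python
import os

# Illustrative module path; a value outside OCR_AGENT_MODULES_WHITELIST
# raises the ValueError shown in the hunk above.
os.environ["OCR_AGENT"] = (
    "unstructured.partition.utils.ocr_models.tesseract_ocr.OCRAgentTesseract"
)
```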

View File

@@ -1,14 +1,13 @@
-from typing import TYPE_CHECKING, List
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any

 import numpy as np
 from PIL import Image as PILImage

 from unstructured.documents.elements import ElementType
 from unstructured.logger import logger
-from unstructured.partition.utils.constants import (
-    DEFAULT_PADDLE_LANG,
-    Source,
-)
+from unstructured.partition.utils.constants import DEFAULT_PADDLE_LANG, Source
 from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
 from unstructured.utils import requires_dependencies
@@ -18,12 +17,17 @@ if TYPE_CHECKING:
 class OCRAgentPaddle(OCRAgent):
+    """OCR service implementation for PaddleOCR."""
+
+    def __init__(self):
+        self.agent = self.load_agent()
+
     def load_agent(self, language: str = DEFAULT_PADDLE_LANG):
+        """Loads the PaddleOCR agent as a global variable to ensure that we only load it once."""
         import paddle
         from unstructured_paddleocr import PaddleOCR

-        """Loads the PaddleOCR agent as a global variable to ensure that we only load it once."""
         # Disable signal handlers at C++ level upon failing
         # ref: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/
         # disable_signal_handler_en.html#disable-signal-handler
@@ -55,7 +59,7 @@ class OCRAgentPaddle(OCRAgent):
         )
         return paddle_ocr

-    def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> str:
+    def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
         ocr_regions = self.get_layout_from_image(image)
         return "\n\n".join([r.text for r in ocr_regions])
@@ -63,8 +67,8 @@ class OCRAgentPaddle(OCRAgent):
         return False

     def get_layout_from_image(
-        self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List["TextRegion"]:
+        self, image: PILImage.Image, ocr_languages: str = "eng"
+    ) -> list[TextRegion]:
         """Get the OCR regions from image as a list of text regions with paddle."""

         logger.info("Processing entire page OCR with paddle...")
@@ -79,8 +83,8 @@ class OCRAgentPaddle(OCRAgent):
     @requires_dependencies("unstructured_inference")
     def get_layout_elements_from_image(
-        self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List["LayoutElement"]:
+        self, image: PILImage.Image, ocr_languages: str = "eng"
+    ) -> list[LayoutElement]:
         from unstructured.partition.pdf_image.inference_utils import build_layout_element

         ocr_regions = self.get_layout_from_image(
@@ -102,10 +106,8 @@ class OCRAgentPaddle(OCRAgent):
         ]

     @requires_dependencies("unstructured_inference")
-    def parse_data(self, ocr_data: list) -> List["TextRegion"]:
-        """
-        Parse the OCR result data to extract a list of TextRegion objects from
-        paddle.
+    def parse_data(self, ocr_data: list[Any]) -> list[TextRegion]:
+        """Parse the OCR result data to extract a list of TextRegion objects from paddle.

         The function processes the OCR result dictionary, looking for bounding
         box information and associated text to create instances of the TextRegion
@@ -115,7 +117,7 @@ class OCRAgentPaddle(OCRAgent):
         - ocr_data (list): A list containing the OCR result data

         Returns:
-        - List[TextRegion]: A list of TextRegion objects, each representing a
+        - list[TextRegion]: A list of TextRegion objects, each representing a
           detected text region within the OCR-ed image.

         Note:
@@ -125,7 +127,7 @@ class OCRAgentPaddle(OCRAgent):
         from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords

-        text_regions = []
+        text_regions: list[TextRegion] = []
         for idx in range(len(ocr_data)):
             res = ocr_data[idx]
             if not res:
@@ -142,12 +144,7 @@ class OCRAgentPaddle(OCRAgent):
                 cleaned_text = text.strip()
                 if cleaned_text:
                     text_region = build_text_region_from_coords(
-                        x1,
-                        y1,
-                        x2,
-                        y2,
-                        text=cleaned_text,
-                        source=Source.OCR_PADDLE,
+                        x1, y1, x2, y2, text=cleaned_text, source=Source.OCR_PADDLE
                     )
                     text_regions.append(text_region)
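
For context on the loosely typed `ocr_data: list[Any]`, PaddleOCR returns nested lists rather than a data frame. A rough sketch of the kind of structure `parse_data` consumes (the sample values are illustrative, and the exact shape can vary by PaddleOCR version):

```python
# Illustrative PaddleOCR-style result: one entry per page, each a list of
# [quad_points, (text, confidence)] pairs.
ocr_data = [
    [
        [[[10.0, 12.0], [180.0, 12.0], [180.0, 40.0], [10.0, 40.0]], ("Invoice", 0.98)],
        [[[10.0, 50.0], [220.0, 50.0], [220.0, 78.0], [10.0, 78.0]], ("Total: $42.00", 0.95)],
    ]
]

for page in ocr_data:
    for quad, (text, confidence) in page:
        xs = [point[0] for point in quad]
        ys = [point[1] for point in quad]
        # Axis-aligned bbox corners, as build_text_region_from_coords expects
        x1, y1, x2, y2 = min(xs), min(ys), max(xs), max(ys)
        print(x1, y1, x2, y2, text.strip())
```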

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING, List

 import cv2
@@ -26,21 +28,17 @@ if TYPE_CHECKING:
 class OCRAgentTesseract(OCRAgent):
-    def load_agent(self):
-        pass
+    """OCR service implementation for Tesseract."""

     def is_text_sorted(self):
         return True

-    def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> str:
-        return unstructured_pytesseract.image_to_string(
-            np.array(image),
-            lang=ocr_languages,
-        )
+    def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
+        return unstructured_pytesseract.image_to_string(np.array(image), lang=ocr_languages)

     def get_layout_from_image(
-        self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List["TextRegion"]:
+        self, image: PILImage.Image, ocr_languages: str = "eng"
+    ) -> List[TextRegion]:
         """Get the OCR regions from image as a list of text regions with tesseract."""

         logger.info("Processing entire page OCR with tesseract...")
@@ -58,7 +56,7 @@ class OCRAgentTesseract(OCRAgent):
         # depend on type of characters (font, language, etc); be careful about this
         # functionality
         text_height = ocr_df[TESSERACT_TEXT_HEIGHT].quantile(
-            env_config.TESSERACT_TEXT_HEIGHT_QUANTILE,
+            env_config.TESSERACT_TEXT_HEIGHT_QUANTILE
         )
         if (
             text_height < env_config.TESSERACT_MIN_TEXT_HEIGHT
@@ -87,7 +85,7 @@ class OCRAgentTesseract(OCRAgent):
     @requires_dependencies("unstructured_inference")
     def get_layout_elements_from_image(
-        self, image: PILImage, ocr_languages: str = "eng"
+        self, image: PILImage.Image, ocr_languages: str = "eng"
     ) -> List["LayoutElement"]:
         from unstructured.partition.pdf_image.inference_utils import (
             build_layout_elements_from_ocr_regions,
@@ -118,9 +116,7 @@ class OCRAgentTesseract(OCRAgent):
     @requires_dependencies("unstructured_inference")
     def parse_data(self, ocr_data: pd.DataFrame, zoom: float = 1) -> List["TextRegion"]:
-        """
-        Parse the OCR result data to extract a list of TextRegion objects from
-        tesseract.
+        """Parse the OCR result data to extract a list of TextRegion objects from tesseract.

         The function processes the OCR result data frame, looking for bounding
         box information and associated text to create instances of the TextRegion
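
For reference, the data frame handed to `parse_data` holds Tesseract's tabular output. A hedged sketch of the columns the parsing loop below relies on (sample values are illustrative, not from the diff):

```python
import pandas as pd

# Illustrative rows shaped like Tesseract's tabular output; only the
# columns the parsing loop uses are shown.
ocr_df = pd.DataFrame(
    {
        "left": [10, 120],
        "top": [15, 15],
        "width": [100, 80],
        "height": [20, 20],
        "text": ["Hello", "world"],
    }
)

zoom = 2.0  # coordinates are scaled back down by the zoom factor
for row in ocr_df.itertuples():
    x1, y1 = row.left / zoom, row.top / zoom
    x2, y2 = (row.left + row.width) / zoom, (row.top + row.height) / zoom
    print(x1, y1, x2, y2, row.text.strip())
```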
@@ -150,7 +146,7 @@ class OCRAgentTesseract(OCRAgent):
         if zoom <= 0:
             zoom = 1

-        text_regions = []
+        text_regions: list[TextRegion] = []
         for idtx in ocr_data.itertuples():
             text = idtx.text
             if not text:
@@ -164,19 +160,14 @@ class OCRAgentTesseract(OCRAgent):
             x2 = (idtx.left + idtx.width) / zoom
             y2 = (idtx.top + idtx.height) / zoom
             text_region = build_text_region_from_coords(
-                x1,
-                y1,
-                x2,
-                y2,
-                text=cleaned_text,
-                source=Source.OCR_TESSERACT,
+                x1, y1, x2, y2, text=cleaned_text, source=Source.OCR_TESSERACT
             )
             text_regions.append(text_region)

         return text_regions

-def zoom_image(image: PILImage, zoom: float = 1) -> PILImage:
+def zoom_image(image: PILImage.Image, zoom: float = 1) -> PILImage.Image:
     """scale an image based on the zoom factor using cv2; the scaled image is post processed by
     dilation then erosion to improve edge sharpness for OCR tasks"""
     if zoom <= 0:
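
The body of `zoom_image` is truncated in this hunk; a minimal sketch of the resize-dilate-erode approach its docstring describes (kernel size and interpolation mode are assumptions, not taken from this diff):

```python
import cv2
import numpy as np
from PIL import Image as PILImage


def zoom_image_sketch(image: PILImage.Image, zoom: float = 1) -> PILImage.Image:
    """Scale an image by `zoom`, then dilate and erode to sharpen edges for OCR."""
    if zoom <= 0:
        zoom = 1  # guard against non-positive zoom, as in the original
    arr = cv2.resize(
        np.array(image), None, fx=zoom, fy=zoom, interpolation=cv2.INTER_CUBIC
    )
    kernel = np.ones((1, 1), np.uint8)  # assumed kernel size
    arr = cv2.dilate(arr, kernel, iterations=1)
    arr = cv2.erode(arr, kernel, iterations=1)
    return PILImage.fromarray(arr)
```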