Mirror of https://github.com/Unstructured-IO/unstructured.git
Synced 2025-08-01 13:29:45 +00:00
rfctr: improve typing in OCR modules (#2893)
**Summary** In preparation for using OCR for partitioners other than PDF, clean up typing in the OCR modules.
This commit is contained in:
parent cb1e91058e
commit f752849c41
CHANGELOG.md

@@ -1,4 +1,4 @@
-## 0.13.3-dev5
+## 0.13.3-dev6
 
 ### Enhancements
 
unstructured/__version__.py

@@ -1 +1 @@
-__version__ = "0.13.3-dev5"  # pragma: no cover
+__version__ = "0.13.3-dev6"  # pragma: no cover
unstructured/partition/pdf_image/ocr.py

@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import os
 import tempfile
-from typing import TYPE_CHECKING, BinaryIO, Dict, List, Optional, Union, cast
+from typing import IO, TYPE_CHECKING, Any, List, Optional, cast
 
 import pdf2image
 
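Note: the new `from __future__ import annotations` line (PEP 563) is what lets the modernized annotation spellings used throughout this commit, `bytes | IO[bytes]` (PEP 604) and `list[TextRegion]` (PEP 585), run on Python versions that predate them: annotations are kept as strings instead of being evaluated at import time. A minimal standalone illustration, not taken from this commit:

```python
from __future__ import annotations


def decode(x: bytes | None) -> list[str]:  # accepted at runtime even on Python 3.8
    return [] if x is None else [x.decode()]


# With the future import, annotations stay as plain strings:
print(decode.__annotations__)  # {'x': 'bytes | None', 'return': 'list[str]'}
```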
@@ -39,7 +41,7 @@ if "OMP_THREAD_LIMIT" not in os.environ:
 
 
 def process_data_with_ocr(
-    data: Union[bytes, BinaryIO],
+    data: bytes | IO[bytes],
     out_layout: "DocumentLayout",
     extracted_layout: List[List["TextRegion"]],
     is_image: bool = False,
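Note: with the widened `data` parameter, callers can pass either raw bytes or an open binary handle. A hedged sketch; `out_layout` and `extracted_layout` are assumed to come from a prior layout-model pass, and the file path is illustrative:

```python
from unstructured.partition.pdf_image.ocr import process_data_with_ocr

# From an open binary handle (IO[bytes])...
with open("example.pdf", "rb") as f:
    merged = process_data_with_ocr(f, out_layout=out_layout, extracted_layout=extracted_layout)

# ...or from bytes already in memory.
pdf_bytes = open("example.pdf", "rb").read()
merged = process_data_with_ocr(
    pdf_bytes, out_layout=out_layout, extracted_layout=extracted_layout
)
```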
@@ -76,7 +78,8 @@ def process_data_with_ocr(
         DocumentLayout: The merged layout information obtained after OCR processing.
     """
     with tempfile.NamedTemporaryFile() as tmp_file:
-        tmp_file.write(data.read() if hasattr(data, "read") else data)
+        data_bytes = data if isinstance(data, bytes) else data.read()
+        tmp_file.write(data_bytes)
         tmp_file.flush()
         merged_layouts = process_file_with_ocr(
             filename=tmp_file.name,
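Note: the switch from `hasattr` to `isinstance` is what makes the new annotation check cleanly; type checkers narrow a union on `isinstance`, but `hasattr(data, "read")` gives them nothing to narrow on. A standalone sketch of the pattern:

```python
from __future__ import annotations

from typing import IO


def to_bytes(data: bytes | IO[bytes]) -> bytes:
    if isinstance(data, bytes):
        return data  # narrowed to bytes here
    return data.read()  # narrowed to IO[bytes] here, so .read() type-checks
```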
@@ -131,7 +134,7 @@ def process_file_with_ocr(
 
     from unstructured_inference.inference.layout import DocumentLayout
 
-    merged_page_layouts = []
+    merged_page_layouts: list[PageLayout] = []
     try:
         if is_image:
             with PILImage.open(filename) as images:
@@ -182,7 +185,7 @@
 @requires_dependencies("unstructured_inference")
 def supplement_page_layout_with_ocr(
     page_layout: "PageLayout",
-    image: PILImage,
+    image: PILImage.Image,
     infer_table_structure: bool = False,
     ocr_languages: str = "eng",
     ocr_mode: str = OCRMode.FULL_PAGE.value,
@@ -254,7 +257,7 @@
 
 def supplement_element_with_table_extraction(
     elements: List["LayoutElement"],
-    image: PILImage,
+    image: PILImage.Image,
     tables_agent: "UnstructuredTableTransformerModel",
     ocr_languages: str = "eng",
     ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
@@ -289,12 +292,12 @@ def supplement_element_with_table_extraction(
 
 
 def get_table_tokens(
-    table_element_image: PILImage,
+    table_element_image: PILImage.Image,
     ocr_languages: str = "eng",
     ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
     extracted_regions: Optional[List["TextRegion"]] = None,
     table_element: Optional["LayoutElement"] = None,
-) -> List[Dict]:
+) -> List[dict[str, Any]]:
     """Get OCR tokens from either paddleocr or tesseract"""
 
     ocr_layout = ocr_agent.get_layout_from_image(
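Note: a hedged usage sketch for the sharpened return type; `table_img` is assumed to be a PIL crop of a table region produced elsewhere in the pipeline:

```python
from unstructured.partition.pdf_image.ocr import get_table_tokens

tokens = get_table_tokens(table_element_image=table_img)
for token in tokens:
    # Under the new annotation each token is a dict[str, Any], rather than an
    # unparameterized Dict (i.e. Dict[Any, Any]), so string-key access type-checks.
    print(sorted(token))
```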
@@ -417,7 +420,7 @@ def supplement_layout_with_ocr_elements(
         build_layout_elements_from_ocr_regions,
     )
 
-    ocr_regions_to_remove = []
+    ocr_regions_to_remove: list[TextRegion] = []
     for ocr_region in ocr_layout:
         for el in layout:
             ocr_region_is_subregion_of_out_el = ocr_region.bbox.is_almost_subregion_of(
unstructured/partition/utils/ocr_models/ocr_interface.py

@@ -1,12 +1,14 @@
+from __future__ import annotations
+
 import functools
 import importlib
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List
+from typing import TYPE_CHECKING
 
 from unstructured.partition.utils.constants import OCR_AGENT_MODULES_WHITELIST
 
 if TYPE_CHECKING:
-    from PIL import PILImage
+    from PIL import Image as PILImage
     from unstructured_inference.inference.elements import TextRegion
     from unstructured_inference.inference.layoutelement import (
         LayoutElement,
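Note: the old `from PIL import PILImage` could never resolve, since Pillow exposes `PIL.Image` rather than a `PIL.PILImage` member; it went unnoticed only because the import sat behind `TYPE_CHECKING` and was never executed. With the corrected alias, the annotations point at a real class:

```python
from PIL import Image as PILImage

# PILImage.Image is Pillow's actual image class; "PIL.PILImage" does not exist.
img: PILImage.Image = PILImage.new("RGB", (8, 8))
```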
@@ -14,31 +16,26 @@ if TYPE_CHECKING:
 
 
 class OCRAgent(ABC):
-    def __init__(self):
-        self.agent = self.load_agent()
-
-    @abstractmethod
-    def load_agent(self, language: str) -> Any:
-        pass
+    """Defines the interface for an Optical Character Recognition (OCR) service."""
 
     @abstractmethod
     def is_text_sorted(self) -> bool:
         pass
 
     @abstractmethod
-    def get_text_from_image(self, image: "PILImage", ocr_languages: str = "eng") -> str:
+    def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
         pass
 
     @abstractmethod
     def get_layout_from_image(
-        self, image: "PILImage", ocr_languages: str = "eng"
-    ) -> List["TextRegion"]:
+        self, image: PILImage.Image, ocr_languages: str = "eng"
+    ) -> list[TextRegion]:
         pass
 
     @abstractmethod
     def get_layout_elements_from_image(
-        self, image: "PILImage", ocr_languages: str = "eng"
-    ) -> List["LayoutElement"]:
+        self, image: PILImage.Image, ocr_languages: str = "eng"
+    ) -> list[LayoutElement]:
         pass
 
     @staticmethod
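Note: with `load_agent` dropped from the ABC, a concrete agent implements only the four query methods and manages its own engine (as `OCRAgentPaddle` does in its `__init__` further below). A minimal sketch of a subclass against the refactored interface; `EchoOCR` and its no-op behavior are hypothetical, not part of the library:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent

if TYPE_CHECKING:
    from PIL import Image as PILImage
    from unstructured_inference.inference.elements import TextRegion
    from unstructured_inference.inference.layoutelement import LayoutElement


class EchoOCR(OCRAgent):
    """Hypothetical agent that recognizes nothing; shows the required surface."""

    def is_text_sorted(self) -> bool:
        return True

    def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
        return ""  # a real agent would run its OCR engine here

    def get_layout_from_image(
        self, image: PILImage.Image, ocr_languages: str = "eng"
    ) -> list[TextRegion]:
        return []

    def get_layout_elements_from_image(
        self, image: PILImage.Image, ocr_languages: str = "eng"
    ) -> list[LayoutElement]:
        return []
```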
@@ -51,6 +48,6 @@ class OCRAgent(ABC):
             return loaded_class()
         else:
             raise ValueError(
-                f"Environment variable OCR_AGENT module name {module_name}",
-                f" must be set to a whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}.",
+                f"Environment variable OCR_AGENT module name {module_name}, must be set to a"
+                f" whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}.",
             )
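Note: besides rewording, this hunk fixes a real defect; the old code passed `ValueError` two comma-separated f-strings, i.e. a tuple of arguments, so the message rendered as a tuple rather than one string. The `get_instance` path it guards is the same one the defaults in ocr.py use (`OCRAgent.get_instance(OCR_AGENT_TESSERACT)`). A short sketch, assuming `OCR_AGENT_TESSERACT` is the whitelisted tesseract module path exported by `unstructured.partition.utils.constants`:

```python
from unstructured.partition.utils.constants import OCR_AGENT_TESSERACT
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent

agent = OCRAgent.get_instance(OCR_AGENT_TESSERACT)  # raises ValueError if not whitelisted
text = agent.get_text_from_image(image)  # `image` assumed to be a PIL image
```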
unstructured/partition/utils/ocr_models/paddle_ocr.py

@@ -1,14 +1,13 @@
-from typing import TYPE_CHECKING, List
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
 
 import numpy as np
 from PIL import Image as PILImage
 
 from unstructured.documents.elements import ElementType
 from unstructured.logger import logger
-from unstructured.partition.utils.constants import (
-    DEFAULT_PADDLE_LANG,
-    Source,
-)
+from unstructured.partition.utils.constants import DEFAULT_PADDLE_LANG, Source
 from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
 from unstructured.utils import requires_dependencies
 
@@ -18,12 +17,17 @@ if TYPE_CHECKING:
 
 
 class OCRAgentPaddle(OCRAgent):
+    """OCR service implementation for PaddleOCR."""
+
+    def __init__(self):
+        self.agent = self.load_agent()
+
     def load_agent(self, language: str = DEFAULT_PADDLE_LANG):
-        """Loads the PaddleOCR agent as a global variable to ensure that we only load it once."""
-
         import paddle
         from unstructured_paddleocr import PaddleOCR
 
+        """Loads the PaddleOCR agent as a global variable to ensure that we only load it once."""
+
         # Disable signal handlers at C++ level upon failing
         # ref: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/
         # disable_signal_handler_en.html#disable-signal-handler
@@ -55,7 +59,7 @@ class OCRAgentPaddle(OCRAgent):
         )
         return paddle_ocr
 
-    def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> str:
+    def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
         ocr_regions = self.get_layout_from_image(image)
         return "\n\n".join([r.text for r in ocr_regions])
 
@@ -63,8 +67,8 @@ class OCRAgentPaddle(OCRAgent):
         return False
 
     def get_layout_from_image(
-        self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List["TextRegion"]:
+        self, image: PILImage.Image, ocr_languages: str = "eng"
+    ) -> list[TextRegion]:
         """Get the OCR regions from image as a list of text regions with paddle."""
 
         logger.info("Processing entire page OCR with paddle...")
@@ -79,8 +83,8 @@ class OCRAgentPaddle(OCRAgent):
 
     @requires_dependencies("unstructured_inference")
     def get_layout_elements_from_image(
-        self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List["LayoutElement"]:
+        self, image: PILImage.Image, ocr_languages: str = "eng"
+    ) -> list[LayoutElement]:
         from unstructured.partition.pdf_image.inference_utils import build_layout_element
 
         ocr_regions = self.get_layout_from_image(
@@ -102,10 +106,8 @@ class OCRAgentPaddle(OCRAgent):
         ]
 
     @requires_dependencies("unstructured_inference")
-    def parse_data(self, ocr_data: list) -> List["TextRegion"]:
-        """
-        Parse the OCR result data to extract a list of TextRegion objects from
-        paddle.
+    def parse_data(self, ocr_data: list[Any]) -> list[TextRegion]:
+        """Parse the OCR result data to extract a list of TextRegion objects from paddle.
 
         The function processes the OCR result dictionary, looking for bounding
         box information and associated text to create instances of the TextRegion
@@ -115,7 +117,7 @@ class OCRAgentPaddle(OCRAgent):
         - ocr_data (list): A list containing the OCR result data
 
         Returns:
-        - List[TextRegion]: A list of TextRegion objects, each representing a
+        - list[TextRegion]: A list of TextRegion objects, each representing a
         detected text region within the OCR-ed image.
 
         Note:
@@ -125,7 +127,7 @@ class OCRAgentPaddle(OCRAgent):
 
         from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords
 
-        text_regions = []
+        text_regions: list[TextRegion] = []
         for idx in range(len(ocr_data)):
             res = ocr_data[idx]
             if not res:
@@ -142,12 +144,7 @@ class OCRAgentPaddle(OCRAgent):
                 cleaned_text = text.strip()
                 if cleaned_text:
                     text_region = build_text_region_from_coords(
-                        x1,
-                        y1,
-                        x2,
-                        y2,
-                        text=cleaned_text,
-                        source=Source.OCR_PADDLE,
+                        x1, y1, x2, y2, text=cleaned_text, source=Source.OCR_PADDLE
                     )
                     text_regions.append(text_region)
 
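Note: for orientation when reading `parse_data`, a hedged sketch of the nested result shape PaddleOCR typically returns: one list per page, each entry pairing four corner points with a (text, confidence) tuple. Treat this as illustrative, not a guaranteed contract:

```python
# Illustrative PaddleOCR output of the kind parse_data walks; values are made up.
ocr_data = [
    [  # page 0
        [
            [[10.0, 12.0], [120.0, 12.0], [120.0, 30.0], [10.0, 30.0]],  # corner points
            ("Hello world", 0.98),  # (recognized text, confidence)
        ],
    ],
]
```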
unstructured/partition/utils/ocr_models/tesseract_ocr.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING, List
 
 import cv2
@@ -26,21 +28,17 @@ if TYPE_CHECKING:
 
 
 class OCRAgentTesseract(OCRAgent):
-    def load_agent(self):
-        pass
+    """OCR service implementation for Tesseract."""
 
     def is_text_sorted(self):
         return True
 
-    def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> str:
-        return unstructured_pytesseract.image_to_string(
-            np.array(image),
-            lang=ocr_languages,
-        )
+    def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
+        return unstructured_pytesseract.image_to_string(np.array(image), lang=ocr_languages)
 
     def get_layout_from_image(
-        self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List["TextRegion"]:
+        self, image: PILImage.Image, ocr_languages: str = "eng"
+    ) -> List[TextRegion]:
         """Get the OCR regions from image as a list of text regions with tesseract."""
 
         logger.info("Processing entire page OCR with tesseract...")
@@ -58,7 +56,7 @@ class OCRAgentTesseract(OCRAgent):
         # depend on type of characters (font, language, etc); be careful about this
         # functionality
         text_height = ocr_df[TESSERACT_TEXT_HEIGHT].quantile(
-            env_config.TESSERACT_TEXT_HEIGHT_QUANTILE,
+            env_config.TESSERACT_TEXT_HEIGHT_QUANTILE
         )
         if (
             text_height < env_config.TESSERACT_MIN_TEXT_HEIGHT
@@ -87,7 +85,7 @@ class OCRAgentTesseract(OCRAgent):
 
     @requires_dependencies("unstructured_inference")
     def get_layout_elements_from_image(
-        self, image: PILImage, ocr_languages: str = "eng"
+        self, image: PILImage.Image, ocr_languages: str = "eng"
     ) -> List["LayoutElement"]:
         from unstructured.partition.pdf_image.inference_utils import (
             build_layout_elements_from_ocr_regions,
@@ -118,9 +116,7 @@ class OCRAgentTesseract(OCRAgent):
 
     @requires_dependencies("unstructured_inference")
     def parse_data(self, ocr_data: pd.DataFrame, zoom: float = 1) -> List["TextRegion"]:
-        """
-        Parse the OCR result data to extract a list of TextRegion objects from
-        tesseract.
+        """Parse the OCR result data to extract a list of TextRegion objects from tesseract.
 
         The function processes the OCR result data frame, looking for bounding
         box information and associated text to create instances of the TextRegion
@@ -150,7 +146,7 @@ class OCRAgentTesseract(OCRAgent):
         if zoom <= 0:
             zoom = 1
 
-        text_regions = []
+        text_regions: list[TextRegion] = []
         for idtx in ocr_data.itertuples():
             text = idtx.text
             if not text:
@@ -164,19 +160,14 @@ class OCRAgentTesseract(OCRAgent):
             x2 = (idtx.left + idtx.width) / zoom
             y2 = (idtx.top + idtx.height) / zoom
             text_region = build_text_region_from_coords(
-                x1,
-                y1,
-                x2,
-                y2,
-                text=cleaned_text,
-                source=Source.OCR_TESSERACT,
+                x1, y1, x2, y2, text=cleaned_text, source=Source.OCR_TESSERACT
             )
             text_regions.append(text_region)
 
         return text_regions
 
 
-def zoom_image(image: PILImage, zoom: float = 1) -> PILImage:
+def zoom_image(image: PILImage.Image, zoom: float = 1) -> PILImage.Image:
     """scale an image based on the zoom factor using cv2; the scaled image is post processed by
     dilation then erosion to improve edge sharpness for OCR tasks"""
     if zoom <= 0:
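Note: a hedged usage sketch for `zoom_image`; the module path matches the file shown above, the image path is illustrative, and the 2x factor is just a plausible choice for small text:

```python
from PIL import Image

from unstructured.partition.utils.ocr_models.tesseract_ocr import zoom_image

page = Image.open("page.png")  # illustrative path
zoomed = zoom_image(page, zoom=2.0)  # cv2 resize, then dilation/erosion per the docstring
```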