mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-08-02 05:49:24 +00:00
rfctr: improve typing in OCR modules (#2893)
**Summary** In preparation for using OCR for partitioners other than PDF, clean up typing in the OCR module.
This commit is contained in:
parent
cb1e91058e
commit
f752849c41
@ -1,4 +1,4 @@
|
|||||||
## 0.13.3-dev5
|
## 0.13.3-dev6
|
||||||
|
|
||||||
### Enhancements
|
### Enhancements
|
||||||
|
|
||||||
|
@ -1 +1 @@
|
|||||||
__version__ = "0.13.3-dev5" # pragma: no cover
|
__version__ = "0.13.3-dev6" # pragma: no cover
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import TYPE_CHECKING, BinaryIO, Dict, List, Optional, Union, cast
|
from typing import IO, TYPE_CHECKING, Any, List, Optional, cast
|
||||||
|
|
||||||
import pdf2image
|
import pdf2image
|
||||||
|
|
||||||
@ -39,7 +41,7 @@ if "OMP_THREAD_LIMIT" not in os.environ:
|
|||||||
|
|
||||||
|
|
||||||
def process_data_with_ocr(
|
def process_data_with_ocr(
|
||||||
data: Union[bytes, BinaryIO],
|
data: bytes | IO[bytes],
|
||||||
out_layout: "DocumentLayout",
|
out_layout: "DocumentLayout",
|
||||||
extracted_layout: List[List["TextRegion"]],
|
extracted_layout: List[List["TextRegion"]],
|
||||||
is_image: bool = False,
|
is_image: bool = False,
|
||||||
@ -76,7 +78,8 @@ def process_data_with_ocr(
|
|||||||
DocumentLayout: The merged layout information obtained after OCR processing.
|
DocumentLayout: The merged layout information obtained after OCR processing.
|
||||||
"""
|
"""
|
||||||
with tempfile.NamedTemporaryFile() as tmp_file:
|
with tempfile.NamedTemporaryFile() as tmp_file:
|
||||||
tmp_file.write(data.read() if hasattr(data, "read") else data)
|
data_bytes = data if isinstance(data, bytes) else data.read()
|
||||||
|
tmp_file.write(data_bytes)
|
||||||
tmp_file.flush()
|
tmp_file.flush()
|
||||||
merged_layouts = process_file_with_ocr(
|
merged_layouts = process_file_with_ocr(
|
||||||
filename=tmp_file.name,
|
filename=tmp_file.name,
|
||||||
@ -131,7 +134,7 @@ def process_file_with_ocr(
|
|||||||
|
|
||||||
from unstructured_inference.inference.layout import DocumentLayout
|
from unstructured_inference.inference.layout import DocumentLayout
|
||||||
|
|
||||||
merged_page_layouts = []
|
merged_page_layouts: list[PageLayout] = []
|
||||||
try:
|
try:
|
||||||
if is_image:
|
if is_image:
|
||||||
with PILImage.open(filename) as images:
|
with PILImage.open(filename) as images:
|
||||||
@ -182,7 +185,7 @@ def process_file_with_ocr(
|
|||||||
@requires_dependencies("unstructured_inference")
|
@requires_dependencies("unstructured_inference")
|
||||||
def supplement_page_layout_with_ocr(
|
def supplement_page_layout_with_ocr(
|
||||||
page_layout: "PageLayout",
|
page_layout: "PageLayout",
|
||||||
image: PILImage,
|
image: PILImage.Image,
|
||||||
infer_table_structure: bool = False,
|
infer_table_structure: bool = False,
|
||||||
ocr_languages: str = "eng",
|
ocr_languages: str = "eng",
|
||||||
ocr_mode: str = OCRMode.FULL_PAGE.value,
|
ocr_mode: str = OCRMode.FULL_PAGE.value,
|
||||||
@ -254,7 +257,7 @@ def supplement_page_layout_with_ocr(
|
|||||||
|
|
||||||
def supplement_element_with_table_extraction(
|
def supplement_element_with_table_extraction(
|
||||||
elements: List["LayoutElement"],
|
elements: List["LayoutElement"],
|
||||||
image: PILImage,
|
image: PILImage.Image,
|
||||||
tables_agent: "UnstructuredTableTransformerModel",
|
tables_agent: "UnstructuredTableTransformerModel",
|
||||||
ocr_languages: str = "eng",
|
ocr_languages: str = "eng",
|
||||||
ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
|
ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
|
||||||
@ -289,12 +292,12 @@ def supplement_element_with_table_extraction(
|
|||||||
|
|
||||||
|
|
||||||
def get_table_tokens(
|
def get_table_tokens(
|
||||||
table_element_image: PILImage,
|
table_element_image: PILImage.Image,
|
||||||
ocr_languages: str = "eng",
|
ocr_languages: str = "eng",
|
||||||
ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
|
ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
|
||||||
extracted_regions: Optional[List["TextRegion"]] = None,
|
extracted_regions: Optional[List["TextRegion"]] = None,
|
||||||
table_element: Optional["LayoutElement"] = None,
|
table_element: Optional["LayoutElement"] = None,
|
||||||
) -> List[Dict]:
|
) -> List[dict[str, Any]]:
|
||||||
"""Get OCR tokens from either paddleocr or tesseract"""
|
"""Get OCR tokens from either paddleocr or tesseract"""
|
||||||
|
|
||||||
ocr_layout = ocr_agent.get_layout_from_image(
|
ocr_layout = ocr_agent.get_layout_from_image(
|
||||||
@ -417,7 +420,7 @@ def supplement_layout_with_ocr_elements(
|
|||||||
build_layout_elements_from_ocr_regions,
|
build_layout_elements_from_ocr_regions,
|
||||||
)
|
)
|
||||||
|
|
||||||
ocr_regions_to_remove = []
|
ocr_regions_to_remove: list[TextRegion] = []
|
||||||
for ocr_region in ocr_layout:
|
for ocr_region in ocr_layout:
|
||||||
for el in layout:
|
for el in layout:
|
||||||
ocr_region_is_subregion_of_out_el = ocr_region.bbox.is_almost_subregion_of(
|
ocr_region_is_subregion_of_out_el = ocr_region.bbox.is_almost_subregion_of(
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import importlib
|
import importlib
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, List
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from unstructured.partition.utils.constants import OCR_AGENT_MODULES_WHITELIST
|
from unstructured.partition.utils.constants import OCR_AGENT_MODULES_WHITELIST
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from PIL import PILImage
|
from PIL import Image as PILImage
|
||||||
from unstructured_inference.inference.elements import TextRegion
|
from unstructured_inference.inference.elements import TextRegion
|
||||||
from unstructured_inference.inference.layoutelement import (
|
from unstructured_inference.inference.layoutelement import (
|
||||||
LayoutElement,
|
LayoutElement,
|
||||||
@ -14,31 +16,26 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class OCRAgent(ABC):
|
class OCRAgent(ABC):
|
||||||
def __init__(self):
|
"""Defines the interface for an Optical Character Recognition (OCR) service."""
|
||||||
self.agent = self.load_agent()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_agent(self, language: str) -> Any:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def is_text_sorted(self) -> bool:
|
def is_text_sorted(self) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_text_from_image(self, image: "PILImage", ocr_languages: str = "eng") -> str:
|
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_layout_from_image(
|
def get_layout_from_image(
|
||||||
self, image: "PILImage", ocr_languages: str = "eng"
|
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||||
) -> List["TextRegion"]:
|
) -> list[TextRegion]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_layout_elements_from_image(
|
def get_layout_elements_from_image(
|
||||||
self, image: "PILImage", ocr_languages: str = "eng"
|
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||||
) -> List["LayoutElement"]:
|
) -> list[LayoutElement]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -51,6 +48,6 @@ class OCRAgent(ABC):
|
|||||||
return loaded_class()
|
return loaded_class()
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Environment variable OCR_AGENT module name {module_name}",
|
f"Environment variable OCR_AGENT module name {module_name}, must be set to a"
|
||||||
f" must be set to a whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}.",
|
f" whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}.",
|
||||||
)
|
)
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
from typing import TYPE_CHECKING, List
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
from unstructured.documents.elements import ElementType
|
from unstructured.documents.elements import ElementType
|
||||||
from unstructured.logger import logger
|
from unstructured.logger import logger
|
||||||
from unstructured.partition.utils.constants import (
|
from unstructured.partition.utils.constants import DEFAULT_PADDLE_LANG, Source
|
||||||
DEFAULT_PADDLE_LANG,
|
|
||||||
Source,
|
|
||||||
)
|
|
||||||
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
|
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
|
||||||
from unstructured.utils import requires_dependencies
|
from unstructured.utils import requires_dependencies
|
||||||
|
|
||||||
@ -18,12 +17,17 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class OCRAgentPaddle(OCRAgent):
|
class OCRAgentPaddle(OCRAgent):
|
||||||
|
"""OCR service implementation for PaddleOCR."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.agent = self.load_agent()
|
||||||
|
|
||||||
def load_agent(self, language: str = DEFAULT_PADDLE_LANG):
|
def load_agent(self, language: str = DEFAULT_PADDLE_LANG):
|
||||||
|
"""Loads the PaddleOCR agent as a global variable to ensure that we only load it once."""
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from unstructured_paddleocr import PaddleOCR
|
from unstructured_paddleocr import PaddleOCR
|
||||||
|
|
||||||
"""Loads the PaddleOCR agent as a global variable to ensure that we only load it once."""
|
|
||||||
|
|
||||||
# Disable signal handlers at C++ level upon failing
|
# Disable signal handlers at C++ level upon failing
|
||||||
# ref: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/
|
# ref: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/
|
||||||
# disable_signal_handler_en.html#disable-signal-handler
|
# disable_signal_handler_en.html#disable-signal-handler
|
||||||
@ -55,7 +59,7 @@ class OCRAgentPaddle(OCRAgent):
|
|||||||
)
|
)
|
||||||
return paddle_ocr
|
return paddle_ocr
|
||||||
|
|
||||||
def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> str:
|
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
|
||||||
ocr_regions = self.get_layout_from_image(image)
|
ocr_regions = self.get_layout_from_image(image)
|
||||||
return "\n\n".join([r.text for r in ocr_regions])
|
return "\n\n".join([r.text for r in ocr_regions])
|
||||||
|
|
||||||
@ -63,8 +67,8 @@ class OCRAgentPaddle(OCRAgent):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def get_layout_from_image(
|
def get_layout_from_image(
|
||||||
self, image: PILImage, ocr_languages: str = "eng"
|
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||||
) -> List["TextRegion"]:
|
) -> list[TextRegion]:
|
||||||
"""Get the OCR regions from image as a list of text regions with paddle."""
|
"""Get the OCR regions from image as a list of text regions with paddle."""
|
||||||
|
|
||||||
logger.info("Processing entire page OCR with paddle...")
|
logger.info("Processing entire page OCR with paddle...")
|
||||||
@ -79,8 +83,8 @@ class OCRAgentPaddle(OCRAgent):
|
|||||||
|
|
||||||
@requires_dependencies("unstructured_inference")
|
@requires_dependencies("unstructured_inference")
|
||||||
def get_layout_elements_from_image(
|
def get_layout_elements_from_image(
|
||||||
self, image: PILImage, ocr_languages: str = "eng"
|
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||||
) -> List["LayoutElement"]:
|
) -> list[LayoutElement]:
|
||||||
from unstructured.partition.pdf_image.inference_utils import build_layout_element
|
from unstructured.partition.pdf_image.inference_utils import build_layout_element
|
||||||
|
|
||||||
ocr_regions = self.get_layout_from_image(
|
ocr_regions = self.get_layout_from_image(
|
||||||
@ -102,10 +106,8 @@ class OCRAgentPaddle(OCRAgent):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@requires_dependencies("unstructured_inference")
|
@requires_dependencies("unstructured_inference")
|
||||||
def parse_data(self, ocr_data: list) -> List["TextRegion"]:
|
def parse_data(self, ocr_data: list[Any]) -> list[TextRegion]:
|
||||||
"""
|
"""Parse the OCR result data to extract a list of TextRegion objects from paddle.
|
||||||
Parse the OCR result data to extract a list of TextRegion objects from
|
|
||||||
paddle.
|
|
||||||
|
|
||||||
The function processes the OCR result dictionary, looking for bounding
|
The function processes the OCR result dictionary, looking for bounding
|
||||||
box information and associated text to create instances of the TextRegion
|
box information and associated text to create instances of the TextRegion
|
||||||
@ -115,7 +117,7 @@ class OCRAgentPaddle(OCRAgent):
|
|||||||
- ocr_data (list): A list containing the OCR result data
|
- ocr_data (list): A list containing the OCR result data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- List[TextRegion]: A list of TextRegion objects, each representing a
|
- list[TextRegion]: A list of TextRegion objects, each representing a
|
||||||
detected text region within the OCR-ed image.
|
detected text region within the OCR-ed image.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
@ -125,7 +127,7 @@ class OCRAgentPaddle(OCRAgent):
|
|||||||
|
|
||||||
from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords
|
from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords
|
||||||
|
|
||||||
text_regions = []
|
text_regions: list[TextRegion] = []
|
||||||
for idx in range(len(ocr_data)):
|
for idx in range(len(ocr_data)):
|
||||||
res = ocr_data[idx]
|
res = ocr_data[idx]
|
||||||
if not res:
|
if not res:
|
||||||
@ -142,12 +144,7 @@ class OCRAgentPaddle(OCRAgent):
|
|||||||
cleaned_text = text.strip()
|
cleaned_text = text.strip()
|
||||||
if cleaned_text:
|
if cleaned_text:
|
||||||
text_region = build_text_region_from_coords(
|
text_region = build_text_region_from_coords(
|
||||||
x1,
|
x1, y1, x2, y2, text=cleaned_text, source=Source.OCR_PADDLE
|
||||||
y1,
|
|
||||||
x2,
|
|
||||||
y2,
|
|
||||||
text=cleaned_text,
|
|
||||||
source=Source.OCR_PADDLE,
|
|
||||||
)
|
)
|
||||||
text_regions.append(text_region)
|
text_regions.append(text_region)
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -26,21 +28,17 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class OCRAgentTesseract(OCRAgent):
|
class OCRAgentTesseract(OCRAgent):
|
||||||
def load_agent(self):
|
"""OCR service implementation for Tesseract."""
|
||||||
pass
|
|
||||||
|
|
||||||
def is_text_sorted(self):
|
def is_text_sorted(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> str:
|
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
|
||||||
return unstructured_pytesseract.image_to_string(
|
return unstructured_pytesseract.image_to_string(np.array(image), lang=ocr_languages)
|
||||||
np.array(image),
|
|
||||||
lang=ocr_languages,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_layout_from_image(
|
def get_layout_from_image(
|
||||||
self, image: PILImage, ocr_languages: str = "eng"
|
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||||
) -> List["TextRegion"]:
|
) -> List[TextRegion]:
|
||||||
"""Get the OCR regions from image as a list of text regions with tesseract."""
|
"""Get the OCR regions from image as a list of text regions with tesseract."""
|
||||||
|
|
||||||
logger.info("Processing entire page OCR with tesseract...")
|
logger.info("Processing entire page OCR with tesseract...")
|
||||||
@ -58,7 +56,7 @@ class OCRAgentTesseract(OCRAgent):
|
|||||||
# depend on type of characters (font, language, etc); be careful about this
|
# depend on type of characters (font, language, etc); be careful about this
|
||||||
# functionality
|
# functionality
|
||||||
text_height = ocr_df[TESSERACT_TEXT_HEIGHT].quantile(
|
text_height = ocr_df[TESSERACT_TEXT_HEIGHT].quantile(
|
||||||
env_config.TESSERACT_TEXT_HEIGHT_QUANTILE,
|
env_config.TESSERACT_TEXT_HEIGHT_QUANTILE
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
text_height < env_config.TESSERACT_MIN_TEXT_HEIGHT
|
text_height < env_config.TESSERACT_MIN_TEXT_HEIGHT
|
||||||
@ -87,7 +85,7 @@ class OCRAgentTesseract(OCRAgent):
|
|||||||
|
|
||||||
@requires_dependencies("unstructured_inference")
|
@requires_dependencies("unstructured_inference")
|
||||||
def get_layout_elements_from_image(
|
def get_layout_elements_from_image(
|
||||||
self, image: PILImage, ocr_languages: str = "eng"
|
self, image: PILImage.Image, ocr_languages: str = "eng"
|
||||||
) -> List["LayoutElement"]:
|
) -> List["LayoutElement"]:
|
||||||
from unstructured.partition.pdf_image.inference_utils import (
|
from unstructured.partition.pdf_image.inference_utils import (
|
||||||
build_layout_elements_from_ocr_regions,
|
build_layout_elements_from_ocr_regions,
|
||||||
@ -118,9 +116,7 @@ class OCRAgentTesseract(OCRAgent):
|
|||||||
|
|
||||||
@requires_dependencies("unstructured_inference")
|
@requires_dependencies("unstructured_inference")
|
||||||
def parse_data(self, ocr_data: pd.DataFrame, zoom: float = 1) -> List["TextRegion"]:
|
def parse_data(self, ocr_data: pd.DataFrame, zoom: float = 1) -> List["TextRegion"]:
|
||||||
"""
|
"""Parse the OCR result data to extract a list of TextRegion objects from tesseract.
|
||||||
Parse the OCR result data to extract a list of TextRegion objects from
|
|
||||||
tesseract.
|
|
||||||
|
|
||||||
The function processes the OCR result data frame, looking for bounding
|
The function processes the OCR result data frame, looking for bounding
|
||||||
box information and associated text to create instances of the TextRegion
|
box information and associated text to create instances of the TextRegion
|
||||||
@ -150,7 +146,7 @@ class OCRAgentTesseract(OCRAgent):
|
|||||||
if zoom <= 0:
|
if zoom <= 0:
|
||||||
zoom = 1
|
zoom = 1
|
||||||
|
|
||||||
text_regions = []
|
text_regions: list[TextRegion] = []
|
||||||
for idtx in ocr_data.itertuples():
|
for idtx in ocr_data.itertuples():
|
||||||
text = idtx.text
|
text = idtx.text
|
||||||
if not text:
|
if not text:
|
||||||
@ -164,19 +160,14 @@ class OCRAgentTesseract(OCRAgent):
|
|||||||
x2 = (idtx.left + idtx.width) / zoom
|
x2 = (idtx.left + idtx.width) / zoom
|
||||||
y2 = (idtx.top + idtx.height) / zoom
|
y2 = (idtx.top + idtx.height) / zoom
|
||||||
text_region = build_text_region_from_coords(
|
text_region = build_text_region_from_coords(
|
||||||
x1,
|
x1, y1, x2, y2, text=cleaned_text, source=Source.OCR_TESSERACT
|
||||||
y1,
|
|
||||||
x2,
|
|
||||||
y2,
|
|
||||||
text=cleaned_text,
|
|
||||||
source=Source.OCR_TESSERACT,
|
|
||||||
)
|
)
|
||||||
text_regions.append(text_region)
|
text_regions.append(text_region)
|
||||||
|
|
||||||
return text_regions
|
return text_regions
|
||||||
|
|
||||||
|
|
||||||
def zoom_image(image: PILImage, zoom: float = 1) -> PILImage:
|
def zoom_image(image: PILImage.Image, zoom: float = 1) -> PILImage.Image:
|
||||||
"""scale an image based on the zoom factor using cv2; the scaled image is post processed by
|
"""scale an image based on the zoom factor using cv2; the scaled image is post processed by
|
||||||
dilation then erosion to improve edge sharpness for OCR tasks"""
|
dilation then erosion to improve edge sharpness for OCR tasks"""
|
||||||
if zoom <= 0:
|
if zoom <= 0:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user