refactor: ocr modules (#2492)

The purpose of this PR is to refactor the OCR-related modules to reduce
unnecessary module imports and avoid potential issues (most likely a
"circular import").

### Summary
- add an `inference_utils` module
(`unstructured/partition/pdf_image/inference_utils.py`) that defines
utility functions related to the `unstructured-inference` library and
reduces direct imports of `unstructured-inference` functions in other
files (see the usage sketch after this list)
- add a `conftest.py` in the `test_unstructured/partition/pdf_image/`
directory to define fixtures that are available to all tests in the same
directory and its subdirectories

### Testing
CI should pass.
Christine Straub 2024-02-06 09:11:55 -08:00 committed by GitHub
parent 0f0b58dfe7
commit 29b9ea7ba6
12 changed files with 328 additions and 245 deletions

File: CHANGELOG.md

@@ -1,4 +1,4 @@
-## 0.12.4-dev5
+## 0.12.4-dev6
 
 ### Enhancements

File: test_unstructured/partition/pdf_image/conftest.py

@@ -0,0 +1,78 @@
import pytest
from unstructured_inference.inference.elements import EmbeddedTextRegion


@pytest.fixture()
def mock_embedded_text_regions():
    return [
        EmbeddedTextRegion.from_coords(
            x1=453.00277777777774,
            y1=317.319341111111,
            x2=711.5338541666665,
            y2=358.28571222222206,
            text="LayoutParser:",
        ),
        EmbeddedTextRegion.from_coords(
            x1=726.4778125,
            y1=317.319341111111,
            x2=760.3308594444444,
            y2=357.1698966666667,
            text="A",
        ),
        EmbeddedTextRegion.from_coords(
            x1=775.2748177777777,
            y1=317.319341111111,
            x2=917.3579885555555,
            y2=357.1698966666667,
            text="Unified",
        ),
        EmbeddedTextRegion.from_coords(
            x1=932.3019468888888,
            y1=317.319341111111,
            x2=1071.8426522222221,
            y2=357.1698966666667,
            text="Toolkit",
        ),
        EmbeddedTextRegion.from_coords(
            x1=1086.7866105555556,
            y1=317.319341111111,
            x2=1141.2105142777777,
            y2=357.1698966666667,
            text="for",
        ),
        EmbeddedTextRegion.from_coords(
            x1=1156.154472611111,
            y1=317.319341111111,
            x2=1256.334784222222,
            y2=357.1698966666667,
            text="Deep",
        ),
        EmbeddedTextRegion.from_coords(
            x1=437.83888888888885,
            y1=367.13322999999986,
            x2=610.0171992222222,
            y2=406.9837855555556,
            text="Learning",
        ),
        EmbeddedTextRegion.from_coords(
            x1=624.9611575555555,
            y1=367.13322999999986,
            x2=741.6754646666665,
            y2=406.9837855555556,
            text="Based",
        ),
        EmbeddedTextRegion.from_coords(
            x1=756.619423,
            y1=367.13322999999986,
            x2=958.3867708333332,
            y2=406.9837855555556,
            text="Document",
        ),
        EmbeddedTextRegion.from_coords(
            x1=973.3307291666665,
            y1=367.13322999999986,
            x2=1092.0535042777776,
            y2=406.9837855555556,
            text="Image",
        ),
    ]

File: test_unstructured/partition/pdf_image/test_inference_utils.py

@@ -0,0 +1,37 @@
from unstructured_inference.inference.elements import TextRegion
from unstructured_inference.inference.layoutelement import LayoutElement

from unstructured.documents.elements import ElementType
from unstructured.partition.pdf_image.inference_utils import (
    build_layout_elements_from_ocr_regions,
    merge_text_regions,
)


def test_merge_text_regions(mock_embedded_text_regions):
    expected = TextRegion.from_coords(
        x1=437.83888888888885,
        y1=317.319341111111,
        x2=1256.334784222222,
        y2=406.9837855555556,
        text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image",
    )

    merged_text_region = merge_text_regions(mock_embedded_text_regions)
    assert merged_text_region == expected


def test_build_layout_elements_from_ocr_regions(mock_embedded_text_regions):
    expected = [
        LayoutElement.from_coords(
            x1=437.83888888888885,
            y1=317.319341111111,
            x2=1256.334784222222,
            y2=406.9837855555556,
            text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image",
            type=ElementType.UNCATEGORIZED_TEXT,
        ),
    ]

    elements = build_layout_elements_from_ocr_regions(mock_embedded_text_regions)
    assert elements == expected

File: test_unstructured/partition/pdf_image/test_ocr.py

@@ -18,10 +18,6 @@ from unstructured.partition.pdf_image.ocr import pad_element_bboxes
 from unstructured.partition.utils.constants import (
     Source,
 )
-from unstructured.partition.utils.ocr_models.ocr_interface import (
-    get_elements_from_ocr_regions,
-    merge_text_regions,
-)
 from unstructured.partition.utils.ocr_models.paddle_ocr import OCRAgentPaddle
 from unstructured.partition.utils.ocr_models.tesseract_ocr import (
     OCRAgentTesseract,
@@ -231,35 +227,6 @@ def test_aggregate_ocr_text_by_block()
     assert text == expected
-def test_merge_text_regions(mock_embedded_text_regions):
-    expected = TextRegion.from_coords(
-        x1=437.83888888888885,
-        y1=317.319341111111,
-        x2=1256.334784222222,
-        y2=406.9837855555556,
-        text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image",
-    )
-    merged_text_region = merge_text_regions(mock_embedded_text_regions)
-    assert merged_text_region == expected
-def test_get_elements_from_ocr_regions(mock_embedded_text_regions):
-    expected = [
-        LayoutElement.from_coords(
-            x1=437.83888888888885,
-            y1=317.319341111111,
-            x2=1256.334784222222,
-            y2=406.9837855555556,
-            text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image",
-            type=ElementType.UNCATEGORIZED_TEXT,
-        ),
-    ]
-    elements = get_elements_from_ocr_regions(mock_embedded_text_regions)
-    assert elements == expected
 @pytest.mark.parametrize("zoom", [1, 0.1, 5, -1, 0])
 def test_zoom_image(zoom):
     image = Image.new("RGB", (100, 100))
@@ -280,82 +247,6 @@ def mock_layout(mock_embedded_text_regions):
     ]
-@pytest.fixture()
-def mock_embedded_text_regions():
-    return [
-        EmbeddedTextRegion.from_coords(
-            x1=453.00277777777774,
-            y1=317.319341111111,
-            x2=711.5338541666665,
-            y2=358.28571222222206,
-            text="LayoutParser:",
-        ),
-        EmbeddedTextRegion.from_coords(
-            x1=726.4778125,
-            y1=317.319341111111,
-            x2=760.3308594444444,
-            y2=357.1698966666667,
-            text="A",
-        ),
-        EmbeddedTextRegion.from_coords(
-            x1=775.2748177777777,
-            y1=317.319341111111,
-            x2=917.3579885555555,
-            y2=357.1698966666667,
-            text="Unified",
-        ),
-        EmbeddedTextRegion.from_coords(
-            x1=932.3019468888888,
-            y1=317.319341111111,
-            x2=1071.8426522222221,
-            y2=357.1698966666667,
-            text="Toolkit",
-        ),
-        EmbeddedTextRegion.from_coords(
-            x1=1086.7866105555556,
-            y1=317.319341111111,
-            x2=1141.2105142777777,
-            y2=357.1698966666667,
-            text="for",
-        ),
-        EmbeddedTextRegion.from_coords(
-            x1=1156.154472611111,
-            y1=317.319341111111,
-            x2=1256.334784222222,
-            y2=357.1698966666667,
-            text="Deep",
-        ),
-        EmbeddedTextRegion.from_coords(
-            x1=437.83888888888885,
-            y1=367.13322999999986,
-            x2=610.0171992222222,
-            y2=406.9837855555556,
-            text="Learning",
-        ),
-        EmbeddedTextRegion.from_coords(
-            x1=624.9611575555555,
-            y1=367.13322999999986,
-            x2=741.6754646666665,
-            y2=406.9837855555556,
-            text="Based",
-        ),
-        EmbeddedTextRegion.from_coords(
-            x1=756.619423,
-            y1=367.13322999999986,
-            x2=958.3867708333332,
-            y2=406.9837855555556,
-            text="Document",
-        ),
-        EmbeddedTextRegion.from_coords(
-            x1=973.3307291666665,
-            y1=367.13322999999986,
-            x2=1092.0535042777776,
-            y2=406.9837855555556,
-            text="Image",
-        ),
-    ]
 def test_supplement_layout_with_ocr_elements(mock_layout, mock_ocr_regions):
     ocr_elements = [
         LayoutElement(text=r.text, source=None, type=ElementType.UNCATEGORIZED_TEXT, bbox=r.bbox)

File: unstructured/__version__.py

@@ -1 +1 @@
-__version__ = "0.12.4-dev5"  # pragma: no cover
+__version__ = "0.12.4-dev6"  # pragma: no cover

File: unstructured/partition/pdf.py

@@ -872,7 +872,6 @@ def convert_pdf_to_images(
         yield image
-@requires_dependencies("unstructured_pytesseract", "unstructured_inference")
 def _partition_pdf_or_image_with_ocr(
     filename: str = "",
     file: Optional[Union[bytes, IO[bytes]]] = None,
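
The decorator removed here reappears on the OCR functions in
`unstructured/partition/pdf_image/ocr.py` below, which now own the
`unstructured_inference` dependency. As a rough mental model, a guard like
`@requires_dependencies(...)` verifies that optional packages are importable
before the wrapped function runs. Here is a simplified sketch of the idea,
not the actual `unstructured.utils` implementation:

```python
import functools
import importlib.util
from typing import Callable, List, Union


def requires_dependencies(deps: Union[str, List[str]]) -> Callable:
    # Simplified sketch; the real unstructured.utils.requires_dependencies
    # may differ in details such as an `extras` hint for pip install.
    dep_list = [deps] if isinstance(deps, str) else deps

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Fail fast with a clear error if any optional package is missing.
            missing = [d for d in dep_list if importlib.util.find_spec(d) is None]
            if missing:
                raise ImportError(f"{func.__name__} requires missing packages: {missing}")
            return func(*args, **kwargs)

        return wrapper

    return decorator
```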

File: unstructured/partition/pdf_image/inference_utils.py

@@ -0,0 +1,114 @@
from typing import TYPE_CHECKING, List, Optional, Union, cast

from unstructured_inference.constants import Source
from unstructured_inference.inference.elements import TextRegion
from unstructured_inference.inference.layoutelement import (
    LayoutElement,
    partition_groups_from_regions,
)

from unstructured.documents.elements import ElementType

if TYPE_CHECKING:
    from unstructured_inference.inference.elements import Rectangle


def build_text_region_from_coords(
    x1: Union[int, float],
    y1: Union[int, float],
    x2: Union[int, float],
    y2: Union[int, float],
    text: Optional[str] = None,
    source: Optional[Source] = None,
) -> TextRegion:
    """Build a TextRegion from coordinates plus optional text and source."""
    return TextRegion.from_coords(
        x1,
        y1,
        x2,
        y2,
        text=text,
        source=source,
    )


def build_layout_element(
    bbox: "Rectangle",
    text: Optional[str] = None,
    source: Optional[Source] = None,
    element_type: Optional[str] = None,
) -> LayoutElement:
    """Build a LayoutElement from a bounding box plus optional text, source, and type."""
    return LayoutElement(bbox=bbox, text=text, source=source, type=element_type)


def build_layout_elements_from_ocr_regions(
    ocr_regions: List[TextRegion],
    ocr_text: Optional[str] = None,
    group_by_ocr_text: bool = False,
) -> List[LayoutElement]:
    """
    Get layout elements from OCR regions
    """
    if group_by_ocr_text:
        text_sections = ocr_text.split("\n\n")
        grouped_regions = []
        for text_section in text_sections:
            regions = []
            words = text_section.replace("\n", " ").split()
            for ocr_region in ocr_regions:
                if not words:
                    break
                if ocr_region.text in words:
                    regions.append(ocr_region)
                    words.remove(ocr_region.text)
            if not regions:
                continue
            for r in regions:
                ocr_regions.remove(r)
            grouped_regions.append(regions)
    else:
        grouped_regions = cast(
            List[List[TextRegion]],
            partition_groups_from_regions(ocr_regions),
        )

    merged_regions = [merge_text_regions(group) for group in grouped_regions]
    return [
        build_layout_element(
            bbox=r.bbox, text=r.text, source=r.source, element_type=ElementType.UNCATEGORIZED_TEXT
        )
        for r in merged_regions
    ]


def merge_text_regions(regions: List[TextRegion]) -> TextRegion:
    """
    Merge a list of TextRegion objects into a single TextRegion.

    Parameters:
    - regions (List[TextRegion]): A list of TextRegion objects to be merged.

    Returns:
    - TextRegion: A single merged TextRegion object.
    """
    if not regions:
        raise ValueError("The text regions to be merged must be provided.")

    min_x1 = min([tr.bbox.x1 for tr in regions])
    min_y1 = min([tr.bbox.y1 for tr in regions])
    max_x2 = max([tr.bbox.x2 for tr in regions])
    max_y2 = max([tr.bbox.y2 for tr in regions])

    merged_text = " ".join([tr.text for tr in regions if tr.text])
    sources = [tr.source for tr in regions]
    source = sources[0] if all(s == sources[0] for s in sources) else None

    return TextRegion.from_coords(min_x1, min_y1, max_x2, max_y2, merged_text, source)
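
To make the `group_by_ocr_text` branch above concrete, here is a small worked
example (inputs invented; `unstructured-inference` must be installed). The
blank-line-separated sections of `ocr_text` drive the grouping, and each group
is merged into a single layout element:

```python
from unstructured.partition.pdf_image.inference_utils import (
    build_layout_elements_from_ocr_regions,
    build_text_region_from_coords,
)

regions = [
    build_text_region_from_coords(0, 0, 10, 2, text="Deep"),
    build_text_region_from_coords(11, 0, 30, 2, text="Learning"),
    build_text_region_from_coords(0, 3, 15, 5, text="Toolkit"),
]

# "\n\n" splits ocr_text into two sections, so the regions fall into two
# groups, ["Deep", "Learning"] and ["Toolkit"], each merged into one element.
elements = build_layout_elements_from_ocr_regions(
    regions, ocr_text="Deep Learning\n\nToolkit", group_by_ocr_text=True
)
assert [el.text for el in elements] == ["Deep Learning", "Toolkit"]
```

Note that in this branch the helper consumes matched entries from the input
list (`ocr_regions.remove(r)`), so callers should not reuse the list afterward.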

File: unstructured/partition/pdf_image/ocr.py

@@ -1,6 +1,6 @@
 import os
 import tempfile
-from typing import BinaryIO, Dict, List, Optional, Union, cast
+from typing import TYPE_CHECKING, BinaryIO, Dict, List, Optional, Union, cast
 
 import pdf2image
@@ -8,12 +8,6 @@ import pdf2image
 # unstructured.documents.elements.Image
 from PIL import Image as PILImage
 from PIL import ImageSequence
-from unstructured_inference.inference.elements import TextRegion
-from unstructured_inference.inference.layout import DocumentLayout, PageLayout
-from unstructured_inference.inference.layoutelement import (
-    LayoutElement,
-)
-from unstructured_inference.models.tables import UnstructuredTableTransformerModel
 
 from unstructured.documents.elements import ElementType
 from unstructured.logger import logger
@@ -29,8 +23,15 @@ from unstructured.partition.utils.constants import (
 )
 from unstructured.partition.utils.ocr_models.ocr_interface import (
     OCRAgent,
-    get_elements_from_ocr_regions,
 )
+from unstructured.utils import requires_dependencies
+
+if TYPE_CHECKING:
+    from unstructured_inference.inference.elements import TextRegion
+    from unstructured_inference.inference.layout import DocumentLayout, PageLayout
+    from unstructured_inference.inference.layoutelement import LayoutElement
+    from unstructured_inference.models.tables import UnstructuredTableTransformerModel
+
 # Force tesseract to be single threaded,
 # otherwise we see major performance problems
@@ -91,6 +92,7 @@ def process_data_with_ocr(
     return merged_layouts
 
+@requires_dependencies("unstructured_inference")
 def process_file_with_ocr(
     filename: str,
     out_layout: "DocumentLayout",
@@ -127,6 +129,9 @@ def process_file_with_ocr(
     Returns:
         DocumentLayout: The merged layout information obtained after OCR processing.
     """
+    from unstructured_inference.inference.layout import DocumentLayout
+
     merged_page_layouts = []
     try:
         if is_image:
@@ -175,6 +180,7 @@
         raise FileNotFoundError(f'File "{filename}" not found!') from e
 
+@requires_dependencies("unstructured_inference")
 def supplement_page_layout_with_ocr(
     page_layout: "PageLayout",
     image: PILImage,
@@ -198,7 +204,7 @@ def supplement_page_layout_with_ocr(
             ocr_languages=ocr_languages,
         )
         page_layout.elements[:] = merge_out_layout_with_ocr_layout(
-            out_layout=cast(List[LayoutElement], page_layout.elements),
+            out_layout=cast(List["LayoutElement"], page_layout.elements),
             ocr_layout=ocr_layout,
         )
     elif ocr_mode == OCRMode.INDIVIDUAL_BLOCKS.value:
@@ -236,7 +242,7 @@
             raise RuntimeError("Unable to load table extraction agent.")
         page_layout.elements[:] = supplement_element_with_table_extraction(
-            elements=cast(List[LayoutElement], page_layout.elements),
+            elements=cast(List["LayoutElement"], page_layout.elements),
             image=image,
             tables_agent=tables.tables_agent,
             ocr_languages=ocr_languages,
@@ -248,13 +254,13 @@
 def supplement_element_with_table_extraction(
-    elements: List[LayoutElement],
+    elements: List["LayoutElement"],
     image: PILImage,
     tables_agent: "UnstructuredTableTransformerModel",
     ocr_languages: str = "eng",
     ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
     extracted_regions: Optional[List["TextRegion"]] = None,
-) -> List[LayoutElement]:
+) -> List["LayoutElement"]:
     """Supplement the existing layout with table extraction. Any Table elements
     that are extracted will have a metadata field "text_as_html" where
     the table's text content is rendered into an html string.
@@ -324,10 +330,10 @@ def get_table_tokens(
 def merge_out_layout_with_ocr_layout(
-    out_layout: List[LayoutElement],
-    ocr_layout: List[TextRegion],
+    out_layout: List["LayoutElement"],
+    ocr_layout: List["TextRegion"],
     supplement_with_ocr_elements: bool = True,
-) -> List[LayoutElement]:
+) -> List["LayoutElement"]:
     """
     Merge the out layout with the OCR-detected text regions on page level.
@@ -356,8 +362,8 @@ def merge_out_layout_with_ocr_layout(
 def aggregate_ocr_text_by_block(
-    ocr_layout: List[TextRegion],
-    region: TextRegion,
+    ocr_layout: List["TextRegion"],
+    region: "TextRegion",
     subregion_threshold: float,
 ) -> Optional[str]:
     """Extracts the text aggregated from the regions of the ocr layout that lie within the given
@@ -376,10 +382,11 @@ def aggregate_ocr_text_by_block(
     return " ".join(extracted_texts) if extracted_texts else ""
 
+@requires_dependencies("unstructured_inference")
 def supplement_layout_with_ocr_elements(
-    layout: List[LayoutElement],
-    ocr_layout: List[TextRegion],
-) -> List[LayoutElement]:
+    layout: List["LayoutElement"],
+    ocr_layout: List["TextRegion"],
+) -> List["LayoutElement"]:
     """
     Supplement the existing layout with additional OCR-derived elements.
@@ -401,11 +408,16 @@ def supplement_layout_with_ocr_elements(
     Note:
     - The function relies on `is_almost_subregion_of()` method to determine if an OCR region
     is a subregion of an existing layout element.
-    - It also relies on `get_elements_from_ocr_regions()` to convert OCR regions to layout elements.
+    - It also relies on `build_layout_elements_from_ocr_regions()` to convert OCR regions to
+    layout elements.
     - The `SUBREGION_THRESHOLD_FOR_OCR` constant is used to specify the subregion matching
     threshold.
     """
+    from unstructured.partition.pdf_image.inference_utils import (
+        build_layout_elements_from_ocr_regions,
+    )
+
     ocr_regions_to_remove = []
     for ocr_region in ocr_layout:
         for el in layout:
@@ -419,7 +431,7 @@ def supplement_layout_with_ocr_elements(
     ocr_regions_to_add = [region for region in ocr_layout if region not in ocr_regions_to_remove]
     if ocr_regions_to_add:
-        ocr_elements_to_add = get_elements_from_ocr_regions(ocr_regions_to_add)
+        ocr_elements_to_add = build_layout_elements_from_ocr_regions(ocr_regions_to_add)
         final_layout = layout + ocr_elements_to_add
     else:
         final_layout = layout
@@ -427,7 +439,7 @@ def supplement_layout_with_ocr_elements(
     return final_layout
 
-def get_ocr_agent() -> str:
+def get_ocr_agent() -> OCRAgent:
     ocr_agent_module = env_config.OCR_AGENT
     message = (
         "OCR agent name %s is outdated and will be deprecated in a future release; please use %s "

File: unstructured/partition/utils/ocr_models/ocr_interface.py

@@ -1,18 +1,17 @@
 import functools
 import importlib
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, cast
-
-from PIL import Image as PILImage
-from unstructured_inference.inference.elements import TextRegion
-from unstructured_inference.inference.layoutelement import (
-    LayoutElement,
-    partition_groups_from_regions,
-)
-
-from unstructured.documents.elements import ElementType
-from unstructured.partition.utils.constants import OCR_AGENT_MODULES_WHITELIST
+from typing import TYPE_CHECKING, Any, List
+
+from unstructured.partition.utils.constants import OCR_AGENT_MODULES_WHITELIST
+
+if TYPE_CHECKING:
+    from PIL import PILImage
+    from unstructured_inference.inference.elements import TextRegion
+    from unstructured_inference.inference.layoutelement import (
+        LayoutElement,
+    )
 
 class OCRAgent(ABC):
     def __init__(self):
@@ -27,19 +26,19 @@ class OCRAgent(ABC):
         pass
 
     @abstractmethod
-    def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> str:
+    def get_text_from_image(self, image: "PILImage", ocr_languages: str = "eng") -> str:
         pass
 
     @abstractmethod
     def get_layout_from_image(
-        self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List[TextRegion]:
+        self, image: "PILImage", ocr_languages: str = "eng"
+    ) -> List["TextRegion"]:
         pass
 
    @abstractmethod
     def get_layout_elements_from_image(
-        self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List[LayoutElement]:
+        self, image: "PILImage", ocr_languages: str = "eng"
+    ) -> List["LayoutElement"]:
         pass
 
     @staticmethod
@@ -55,73 +54,3 @@ class OCRAgent(ABC):
                 f"Environment variable OCR_AGENT module name {module_name}",
                 f" must be set to a whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}.",
             )
-def get_elements_from_ocr_regions(
-    ocr_regions: List[TextRegion],
-    ocr_text: Optional[str] = None,
-    group_by_ocr_text: bool = False,
-) -> List[LayoutElement]:
-    """
-    Get layout elements from OCR regions
-    """
-    if group_by_ocr_text:
-        text_sections = ocr_text.split("\n\n")
-        grouped_regions = []
-        for text_section in text_sections:
-            regions = []
-            words = text_section.replace("\n", " ").split()
-            for ocr_region in ocr_regions:
-                if not words:
-                    break
-                if ocr_region.text in words:
-                    regions.append(ocr_region)
-                    words.remove(ocr_region.text)
-            if not regions:
-                continue
-            for r in regions:
-                ocr_regions.remove(r)
-            grouped_regions.append(regions)
-    else:
-        grouped_regions = cast(
-            List[List[TextRegion]],
-            partition_groups_from_regions(ocr_regions),
-        )
-    merged_regions = [merge_text_regions(group) for group in grouped_regions]
-    return [
-        LayoutElement(
-            text=r.text, source=r.source, type=ElementType.UNCATEGORIZED_TEXT, bbox=r.bbox
-        )
-        for r in merged_regions
-    ]
-def merge_text_regions(regions: List[TextRegion]) -> TextRegion:
-    """
-    Merge a list of TextRegion objects into a single TextRegion.
-    Parameters:
-    - group (List[TextRegion]): A list of TextRegion objects to be merged.
-    Returns:
-    - TextRegion: A single merged TextRegion object.
-    """
-    if not regions:
-        raise ValueError("The text regions to be merged must be provided.")
-    min_x1 = min([tr.bbox.x1 for tr in regions])
-    min_y1 = min([tr.bbox.y1 for tr in regions])
-    max_x2 = max([tr.bbox.x2 for tr in regions])
-    max_y2 = max([tr.bbox.y2 for tr in regions])
-    merged_text = " ".join([tr.text for tr in regions if tr.text])
-    sources = [tr.source for tr in regions]
-    source = sources[0] if all(s == sources[0] for s in sources) else None
-    return TextRegion.from_coords(min_x1, min_y1, max_x2, max_y2, merged_text, source)

File: unstructured/partition/utils/ocr_models/paddle_ocr.py

@@ -1,11 +1,7 @@
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 from PIL import Image as PILImage
-from unstructured_inference.inference.elements import TextRegion
-from unstructured_inference.inference.layoutelement import (
-    LayoutElement,
-)
 
 from unstructured.documents.elements import ElementType
 from unstructured.logger import logger
@@ -14,6 +10,11 @@ from unstructured.partition.utils.constants import (
     Source,
 )
 from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
+from unstructured.utils import requires_dependencies
+
+if TYPE_CHECKING:
+    from unstructured_inference.inference.elements import TextRegion
+    from unstructured_inference.inference.layoutelement import LayoutElement
 
 class OCRAgentPaddle(OCRAgent):
@@ -63,7 +64,7 @@
     def get_layout_from_image(
         self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List[TextRegion]:
+    ) -> List["TextRegion"]:
         """Get the OCR regions from image as a list of text regions with paddle."""
         logger.info("Processing entire page OCR with paddle...")
@@ -76,9 +77,12 @@
         return ocr_regions
 
+    @requires_dependencies("unstructured_inference")
     def get_layout_elements_from_image(
         self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List[LayoutElement]:
+    ) -> List["LayoutElement"]:
+        from unstructured.partition.pdf_image.inference_utils import build_layout_element
+
         ocr_regions = self.get_layout_from_image(
             image,
             ocr_languages=ocr_languages,
@@ -88,13 +92,17 @@
         # terms of grouping because we get ocr_text from `ocr_layout, so the first two grouping
         # and merging steps are not necessary.
         return [
-            LayoutElement(
-                bbox=r.bbox, text=r.text, source=r.source, type=ElementType.UNCATEGORIZED_TEXT
+            build_layout_element(
+                bbox=r.bbox,
+                text=r.text,
+                source=r.source,
+                element_type=ElementType.UNCATEGORIZED_TEXT,
             )
             for r in ocr_regions
         ]
 
-    def parse_data(self, ocr_data: list) -> List[TextRegion]:
+    @requires_dependencies("unstructured_inference")
+    def parse_data(self, ocr_data: list) -> List["TextRegion"]:
         """
         Parse the OCR result data to extract a list of TextRegion objects from
         paddle.
@@ -114,6 +122,9 @@ class OCRAgentPaddle(OCRAgent):
         - An empty string or a None value for the 'text' key in the input
         dictionary will result in its associated bounding box being ignored.
         """
+        from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords
+
         text_regions = []
         for idx in range(len(ocr_data)):
             res = ocr_data[idx]
@@ -130,12 +141,12 @@ class OCRAgentPaddle(OCRAgent):
                     continue
                 cleaned_text = text.strip()
                 if cleaned_text:
-                    text_region = TextRegion.from_coords(
+                    text_region = build_text_region_from_coords(
                         x1,
                         y1,
                         x2,
                         y2,
-                        cleaned_text,
+                        text=cleaned_text,
                         source=Source.OCR_PADDLE,
                     )
                     text_regions.append(text_region)

File: unstructured/partition/utils/ocr_models/tesseract_ocr.py

@@ -1,18 +1,13 @@
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import cv2
 import numpy as np
 import pandas as pd
 import unstructured_pytesseract
 from PIL import Image as PILImage
-from unstructured_inference.inference.elements import TextRegion
-from unstructured_inference.inference.layoutelement import (
-    LayoutElement,
-)
 from unstructured_pytesseract import Output
 
 from unstructured.logger import logger
-from unstructured.partition.pdf_image.ocr import get_elements_from_ocr_regions
 from unstructured.partition.utils.config import env_config
 from unstructured.partition.utils.constants import (
     IMAGE_COLOR_DEPTH,
@@ -21,6 +16,13 @@ from unstructured.partition.utils.constants import (
     Source,
 )
 from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
+from unstructured.utils import requires_dependencies
+
+if TYPE_CHECKING:
+    from unstructured_inference.inference.elements import TextRegion
+    from unstructured_inference.inference.layoutelement import (
+        LayoutElement,
+    )
 
 class OCRAgentTesseract(OCRAgent):
@@ -38,7 +40,7 @@ class OCRAgentTesseract(OCRAgent):
     def get_layout_from_image(
         self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List[TextRegion]:
+    ) -> List["TextRegion"]:
         """Get the OCR regions from image as a list of text regions with tesseract."""
         logger.info("Processing entire page OCR with tesseract...")
@@ -83,9 +85,14 @@ class OCRAgentTesseract(OCRAgent):
         return ocr_regions
 
+    @requires_dependencies("unstructured_inference")
     def get_layout_elements_from_image(
         self, image: PILImage, ocr_languages: str = "eng"
-    ) -> List[LayoutElement]:
+    ) -> List["LayoutElement"]:
+        from unstructured.partition.pdf_image.inference_utils import (
+            build_layout_elements_from_ocr_regions,
+        )
+
         ocr_regions = self.get_layout_from_image(
             image,
             ocr_languages=ocr_languages,
@@ -103,13 +110,14 @@ class OCRAgentTesseract(OCRAgent):
             ocr_languages=ocr_languages,
         )
 
-        return get_elements_from_ocr_regions(
+        return build_layout_elements_from_ocr_regions(
             ocr_regions=ocr_regions,
             ocr_text=ocr_text,
             group_by_ocr_text=True,
         )
 
-    def parse_data(self, ocr_data: pd.DataFrame, zoom: float = 1) -> List[TextRegion]:
+    @requires_dependencies("unstructured_inference")
+    def parse_data(self, ocr_data: pd.DataFrame, zoom: float = 1) -> List["TextRegion"]:
         """
         Parse the OCR result data to extract a list of TextRegion objects from
         tesseract.
@@ -137,6 +145,8 @@ class OCRAgentTesseract(OCRAgent):
         data frame will result in its associated bounding box being ignored.
         """
+        from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords
+
         if zoom <= 0:
             zoom = 1
@@ -153,7 +163,7 @@ class OCRAgentTesseract(OCRAgent):
             y1 = idtx.top / zoom
             x2 = (idtx.left + idtx.width) / zoom
             y2 = (idtx.top + idtx.height) / zoom
-            text_region = TextRegion.from_coords(
+            text_region = build_text_region_from_coords(
                 x1,
                 y1,
                 x2,

File: unstructured/partition/utils/processing_elements.py

@@ -1,11 +1,13 @@
 from collections import defaultdict
+from typing import TYPE_CHECKING
-from unstructured_inference.inference.layout import DocumentLayout
 
 from unstructured.partition.utils.constants import Source
+
+if TYPE_CHECKING:
+    from unstructured_inference.inference.layout import DocumentLayout
 
-def clean_pdfminer_inner_elements(document: DocumentLayout) -> DocumentLayout:
+def clean_pdfminer_inner_elements(document: "DocumentLayout") -> "DocumentLayout":
     """Clean pdfminer elements from inside tables and stores them in extra_info dictionary
     with the table id as key"""
     defaultdict(list)