mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-07-13 20:15:54 +00:00
refactor: ocr modules (#2492)
The purpose of this PR is to refactor OCR-related modules to reduce unnecessary module imports to avoid potential issues (most likely due to a "circular import"). ### Summary - add `inference_utils` module (unstructured/partition/pdf_image/inference_utils.py) to define unstructured-inference library related utility functions, which will reduce importing unstructured-inference library functions in other files - add `conftest.py` in `test_unstructured/partition/pdf_image/` directory to define fixtures that are available to all tests in the same directory and its subdirectories ### Testing CI should pass
This commit is contained in:
parent
0f0b58dfe7
commit
29b9ea7ba6
@ -1,4 +1,4 @@
|
||||
## 0.12.4-dev5
|
||||
## 0.12.4-dev6
|
||||
|
||||
### Enhancements
|
||||
|
||||
|
78
test_unstructured/partition/pdf_image/conftest.py
Normal file
78
test_unstructured/partition/pdf_image/conftest.py
Normal file
@ -0,0 +1,78 @@
|
||||
import pytest
|
||||
from unstructured_inference.inference.elements import EmbeddedTextRegion
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_embedded_text_regions():
|
||||
return [
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=453.00277777777774,
|
||||
y1=317.319341111111,
|
||||
x2=711.5338541666665,
|
||||
y2=358.28571222222206,
|
||||
text="LayoutParser:",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=726.4778125,
|
||||
y1=317.319341111111,
|
||||
x2=760.3308594444444,
|
||||
y2=357.1698966666667,
|
||||
text="A",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=775.2748177777777,
|
||||
y1=317.319341111111,
|
||||
x2=917.3579885555555,
|
||||
y2=357.1698966666667,
|
||||
text="Unified",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=932.3019468888888,
|
||||
y1=317.319341111111,
|
||||
x2=1071.8426522222221,
|
||||
y2=357.1698966666667,
|
||||
text="Toolkit",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=1086.7866105555556,
|
||||
y1=317.319341111111,
|
||||
x2=1141.2105142777777,
|
||||
y2=357.1698966666667,
|
||||
text="for",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=1156.154472611111,
|
||||
y1=317.319341111111,
|
||||
x2=1256.334784222222,
|
||||
y2=357.1698966666667,
|
||||
text="Deep",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=437.83888888888885,
|
||||
y1=367.13322999999986,
|
||||
x2=610.0171992222222,
|
||||
y2=406.9837855555556,
|
||||
text="Learning",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=624.9611575555555,
|
||||
y1=367.13322999999986,
|
||||
x2=741.6754646666665,
|
||||
y2=406.9837855555556,
|
||||
text="Based",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=756.619423,
|
||||
y1=367.13322999999986,
|
||||
x2=958.3867708333332,
|
||||
y2=406.9837855555556,
|
||||
text="Document",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=973.3307291666665,
|
||||
y1=367.13322999999986,
|
||||
x2=1092.0535042777776,
|
||||
y2=406.9837855555556,
|
||||
text="Image",
|
||||
),
|
||||
]
|
@ -0,0 +1,37 @@
|
||||
from unstructured_inference.inference.elements import TextRegion
|
||||
from unstructured_inference.inference.layoutelement import LayoutElement
|
||||
|
||||
from unstructured.documents.elements import ElementType
|
||||
from unstructured.partition.pdf_image.inference_utils import (
|
||||
build_layout_elements_from_ocr_regions,
|
||||
merge_text_regions,
|
||||
)
|
||||
|
||||
|
||||
def test_merge_text_regions(mock_embedded_text_regions):
|
||||
expected = TextRegion.from_coords(
|
||||
x1=437.83888888888885,
|
||||
y1=317.319341111111,
|
||||
x2=1256.334784222222,
|
||||
y2=406.9837855555556,
|
||||
text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image",
|
||||
)
|
||||
|
||||
merged_text_region = merge_text_regions(mock_embedded_text_regions)
|
||||
assert merged_text_region == expected
|
||||
|
||||
|
||||
def test_build_layout_elements_from_ocr_regions(mock_embedded_text_regions):
|
||||
expected = [
|
||||
LayoutElement.from_coords(
|
||||
x1=437.83888888888885,
|
||||
y1=317.319341111111,
|
||||
x2=1256.334784222222,
|
||||
y2=406.9837855555556,
|
||||
text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image",
|
||||
type=ElementType.UNCATEGORIZED_TEXT,
|
||||
),
|
||||
]
|
||||
|
||||
elements = build_layout_elements_from_ocr_regions(mock_embedded_text_regions)
|
||||
assert elements == expected
|
@ -18,10 +18,6 @@ from unstructured.partition.pdf_image.ocr import pad_element_bboxes
|
||||
from unstructured.partition.utils.constants import (
|
||||
Source,
|
||||
)
|
||||
from unstructured.partition.utils.ocr_models.ocr_interface import (
|
||||
get_elements_from_ocr_regions,
|
||||
merge_text_regions,
|
||||
)
|
||||
from unstructured.partition.utils.ocr_models.paddle_ocr import OCRAgentPaddle
|
||||
from unstructured.partition.utils.ocr_models.tesseract_ocr import (
|
||||
OCRAgentTesseract,
|
||||
@ -231,35 +227,6 @@ def test_aggregate_ocr_text_by_block():
|
||||
assert text == expected
|
||||
|
||||
|
||||
def test_merge_text_regions(mock_embedded_text_regions):
|
||||
expected = TextRegion.from_coords(
|
||||
x1=437.83888888888885,
|
||||
y1=317.319341111111,
|
||||
x2=1256.334784222222,
|
||||
y2=406.9837855555556,
|
||||
text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image",
|
||||
)
|
||||
|
||||
merged_text_region = merge_text_regions(mock_embedded_text_regions)
|
||||
assert merged_text_region == expected
|
||||
|
||||
|
||||
def test_get_elements_from_ocr_regions(mock_embedded_text_regions):
|
||||
expected = [
|
||||
LayoutElement.from_coords(
|
||||
x1=437.83888888888885,
|
||||
y1=317.319341111111,
|
||||
x2=1256.334784222222,
|
||||
y2=406.9837855555556,
|
||||
text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image",
|
||||
type=ElementType.UNCATEGORIZED_TEXT,
|
||||
),
|
||||
]
|
||||
|
||||
elements = get_elements_from_ocr_regions(mock_embedded_text_regions)
|
||||
assert elements == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("zoom", [1, 0.1, 5, -1, 0])
|
||||
def test_zoom_image(zoom):
|
||||
image = Image.new("RGB", (100, 100))
|
||||
@ -280,82 +247,6 @@ def mock_layout(mock_embedded_text_regions):
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_embedded_text_regions():
|
||||
return [
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=453.00277777777774,
|
||||
y1=317.319341111111,
|
||||
x2=711.5338541666665,
|
||||
y2=358.28571222222206,
|
||||
text="LayoutParser:",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=726.4778125,
|
||||
y1=317.319341111111,
|
||||
x2=760.3308594444444,
|
||||
y2=357.1698966666667,
|
||||
text="A",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=775.2748177777777,
|
||||
y1=317.319341111111,
|
||||
x2=917.3579885555555,
|
||||
y2=357.1698966666667,
|
||||
text="Unified",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=932.3019468888888,
|
||||
y1=317.319341111111,
|
||||
x2=1071.8426522222221,
|
||||
y2=357.1698966666667,
|
||||
text="Toolkit",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=1086.7866105555556,
|
||||
y1=317.319341111111,
|
||||
x2=1141.2105142777777,
|
||||
y2=357.1698966666667,
|
||||
text="for",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=1156.154472611111,
|
||||
y1=317.319341111111,
|
||||
x2=1256.334784222222,
|
||||
y2=357.1698966666667,
|
||||
text="Deep",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=437.83888888888885,
|
||||
y1=367.13322999999986,
|
||||
x2=610.0171992222222,
|
||||
y2=406.9837855555556,
|
||||
text="Learning",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=624.9611575555555,
|
||||
y1=367.13322999999986,
|
||||
x2=741.6754646666665,
|
||||
y2=406.9837855555556,
|
||||
text="Based",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=756.619423,
|
||||
y1=367.13322999999986,
|
||||
x2=958.3867708333332,
|
||||
y2=406.9837855555556,
|
||||
text="Document",
|
||||
),
|
||||
EmbeddedTextRegion.from_coords(
|
||||
x1=973.3307291666665,
|
||||
y1=367.13322999999986,
|
||||
x2=1092.0535042777776,
|
||||
y2=406.9837855555556,
|
||||
text="Image",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def test_supplement_layout_with_ocr_elements(mock_layout, mock_ocr_regions):
|
||||
ocr_elements = [
|
||||
LayoutElement(text=r.text, source=None, type=ElementType.UNCATEGORIZED_TEXT, bbox=r.bbox)
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "0.12.4-dev5" # pragma: no cover
|
||||
__version__ = "0.12.4-dev6" # pragma: no cover
|
||||
|
@ -872,7 +872,6 @@ def convert_pdf_to_images(
|
||||
yield image
|
||||
|
||||
|
||||
@requires_dependencies("unstructured_pytesseract", "unstructured_inference")
|
||||
def _partition_pdf_or_image_with_ocr(
|
||||
filename: str = "",
|
||||
file: Optional[Union[bytes, IO[bytes]]] = None,
|
||||
|
114
unstructured/partition/pdf_image/inference_utils.py
Normal file
114
unstructured/partition/pdf_image/inference_utils.py
Normal file
@ -0,0 +1,114 @@
|
||||
from typing import TYPE_CHECKING, List, Optional, Union, cast
|
||||
|
||||
from unstructured_inference.constants import Source
|
||||
from unstructured_inference.inference.elements import TextRegion
|
||||
from unstructured_inference.inference.layoutelement import (
|
||||
LayoutElement,
|
||||
partition_groups_from_regions,
|
||||
)
|
||||
|
||||
from unstructured.documents.elements import ElementType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unstructured_inference.inference.elements import Rectangle
|
||||
|
||||
|
||||
def build_text_region_from_coords(
|
||||
x1: Union[int, float],
|
||||
y1: Union[int, float],
|
||||
x2: Union[int, float],
|
||||
y2: Union[int, float],
|
||||
text: Optional[str] = None,
|
||||
source: Optional[Source] = None,
|
||||
) -> TextRegion:
|
||||
""""""
|
||||
|
||||
return TextRegion.from_coords(
|
||||
x1,
|
||||
y1,
|
||||
x2,
|
||||
y2,
|
||||
text=text,
|
||||
source=source,
|
||||
)
|
||||
|
||||
|
||||
def build_layout_element(
|
||||
bbox: "Rectangle",
|
||||
text: Optional[str] = None,
|
||||
source: Optional[Source] = None,
|
||||
element_type: Optional[str] = None,
|
||||
) -> LayoutElement:
|
||||
""""""
|
||||
|
||||
return LayoutElement(bbox=bbox, text=text, source=source, type=element_type)
|
||||
|
||||
|
||||
def build_layout_elements_from_ocr_regions(
|
||||
ocr_regions: List[TextRegion],
|
||||
ocr_text: Optional[str] = None,
|
||||
group_by_ocr_text: bool = False,
|
||||
) -> List[LayoutElement]:
|
||||
"""
|
||||
Get layout elements from OCR regions
|
||||
"""
|
||||
|
||||
if group_by_ocr_text:
|
||||
text_sections = ocr_text.split("\n\n")
|
||||
grouped_regions = []
|
||||
for text_section in text_sections:
|
||||
regions = []
|
||||
words = text_section.replace("\n", " ").split()
|
||||
for ocr_region in ocr_regions:
|
||||
if not words:
|
||||
break
|
||||
if ocr_region.text in words:
|
||||
regions.append(ocr_region)
|
||||
words.remove(ocr_region.text)
|
||||
|
||||
if not regions:
|
||||
continue
|
||||
|
||||
for r in regions:
|
||||
ocr_regions.remove(r)
|
||||
|
||||
grouped_regions.append(regions)
|
||||
else:
|
||||
grouped_regions = cast(
|
||||
List[List[TextRegion]],
|
||||
partition_groups_from_regions(ocr_regions),
|
||||
)
|
||||
|
||||
merged_regions = [merge_text_regions(group) for group in grouped_regions]
|
||||
return [
|
||||
build_layout_element(
|
||||
bbox=r.bbox, text=r.text, source=r.source, element_type=ElementType.UNCATEGORIZED_TEXT
|
||||
)
|
||||
for r in merged_regions
|
||||
]
|
||||
|
||||
|
||||
def merge_text_regions(regions: List[TextRegion]) -> TextRegion:
|
||||
"""
|
||||
Merge a list of TextRegion objects into a single TextRegion.
|
||||
|
||||
Parameters:
|
||||
- group (List[TextRegion]): A list of TextRegion objects to be merged.
|
||||
|
||||
Returns:
|
||||
- TextRegion: A single merged TextRegion object.
|
||||
"""
|
||||
|
||||
if not regions:
|
||||
raise ValueError("The text regions to be merged must be provided.")
|
||||
|
||||
min_x1 = min([tr.bbox.x1 for tr in regions])
|
||||
min_y1 = min([tr.bbox.y1 for tr in regions])
|
||||
max_x2 = max([tr.bbox.x2 for tr in regions])
|
||||
max_y2 = max([tr.bbox.y2 for tr in regions])
|
||||
|
||||
merged_text = " ".join([tr.text for tr in regions if tr.text])
|
||||
sources = [tr.source for tr in regions]
|
||||
source = sources[0] if all(s == sources[0] for s in sources) else None
|
||||
|
||||
return TextRegion.from_coords(min_x1, min_y1, max_x2, max_y2, merged_text, source)
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
from typing import BinaryIO, Dict, List, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, BinaryIO, Dict, List, Optional, Union, cast
|
||||
|
||||
import pdf2image
|
||||
|
||||
@ -8,12 +8,6 @@ import pdf2image
|
||||
# unstructured.documents.elements.Image
|
||||
from PIL import Image as PILImage
|
||||
from PIL import ImageSequence
|
||||
from unstructured_inference.inference.elements import TextRegion
|
||||
from unstructured_inference.inference.layout import DocumentLayout, PageLayout
|
||||
from unstructured_inference.inference.layoutelement import (
|
||||
LayoutElement,
|
||||
)
|
||||
from unstructured_inference.models.tables import UnstructuredTableTransformerModel
|
||||
|
||||
from unstructured.documents.elements import ElementType
|
||||
from unstructured.logger import logger
|
||||
@ -29,8 +23,15 @@ from unstructured.partition.utils.constants import (
|
||||
)
|
||||
from unstructured.partition.utils.ocr_models.ocr_interface import (
|
||||
OCRAgent,
|
||||
get_elements_from_ocr_regions,
|
||||
)
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unstructured_inference.inference.elements import TextRegion
|
||||
from unstructured_inference.inference.layout import DocumentLayout, PageLayout
|
||||
from unstructured_inference.inference.layoutelement import LayoutElement
|
||||
from unstructured_inference.models.tables import UnstructuredTableTransformerModel
|
||||
|
||||
|
||||
# Force tesseract to be single threaded,
|
||||
# otherwise we see major performance problems
|
||||
@ -91,6 +92,7 @@ def process_data_with_ocr(
|
||||
return merged_layouts
|
||||
|
||||
|
||||
@requires_dependencies("unstructured_inference")
|
||||
def process_file_with_ocr(
|
||||
filename: str,
|
||||
out_layout: "DocumentLayout",
|
||||
@ -127,6 +129,9 @@ def process_file_with_ocr(
|
||||
Returns:
|
||||
DocumentLayout: The merged layout information obtained after OCR processing.
|
||||
"""
|
||||
|
||||
from unstructured_inference.inference.layout import DocumentLayout
|
||||
|
||||
merged_page_layouts = []
|
||||
try:
|
||||
if is_image:
|
||||
@ -175,6 +180,7 @@ def process_file_with_ocr(
|
||||
raise FileNotFoundError(f'File "{filename}" not found!') from e
|
||||
|
||||
|
||||
@requires_dependencies("unstructured_inference")
|
||||
def supplement_page_layout_with_ocr(
|
||||
page_layout: "PageLayout",
|
||||
image: PILImage,
|
||||
@ -198,7 +204,7 @@ def supplement_page_layout_with_ocr(
|
||||
ocr_languages=ocr_languages,
|
||||
)
|
||||
page_layout.elements[:] = merge_out_layout_with_ocr_layout(
|
||||
out_layout=cast(List[LayoutElement], page_layout.elements),
|
||||
out_layout=cast(List["LayoutElement"], page_layout.elements),
|
||||
ocr_layout=ocr_layout,
|
||||
)
|
||||
elif ocr_mode == OCRMode.INDIVIDUAL_BLOCKS.value:
|
||||
@ -236,7 +242,7 @@ def supplement_page_layout_with_ocr(
|
||||
raise RuntimeError("Unable to load table extraction agent.")
|
||||
|
||||
page_layout.elements[:] = supplement_element_with_table_extraction(
|
||||
elements=cast(List[LayoutElement], page_layout.elements),
|
||||
elements=cast(List["LayoutElement"], page_layout.elements),
|
||||
image=image,
|
||||
tables_agent=tables.tables_agent,
|
||||
ocr_languages=ocr_languages,
|
||||
@ -248,13 +254,13 @@ def supplement_page_layout_with_ocr(
|
||||
|
||||
|
||||
def supplement_element_with_table_extraction(
|
||||
elements: List[LayoutElement],
|
||||
elements: List["LayoutElement"],
|
||||
image: PILImage,
|
||||
tables_agent: "UnstructuredTableTransformerModel",
|
||||
ocr_languages: str = "eng",
|
||||
ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
|
||||
extracted_regions: Optional[List["TextRegion"]] = None,
|
||||
) -> List[LayoutElement]:
|
||||
) -> List["LayoutElement"]:
|
||||
"""Supplement the existing layout with table extraction. Any Table elements
|
||||
that are extracted will have a metadata field "text_as_html" where
|
||||
the table's text content is rendered into an html string.
|
||||
@ -324,10 +330,10 @@ def get_table_tokens(
|
||||
|
||||
|
||||
def merge_out_layout_with_ocr_layout(
|
||||
out_layout: List[LayoutElement],
|
||||
ocr_layout: List[TextRegion],
|
||||
out_layout: List["LayoutElement"],
|
||||
ocr_layout: List["TextRegion"],
|
||||
supplement_with_ocr_elements: bool = True,
|
||||
) -> List[LayoutElement]:
|
||||
) -> List["LayoutElement"]:
|
||||
"""
|
||||
Merge the out layout with the OCR-detected text regions on page level.
|
||||
|
||||
@ -356,8 +362,8 @@ def merge_out_layout_with_ocr_layout(
|
||||
|
||||
|
||||
def aggregate_ocr_text_by_block(
|
||||
ocr_layout: List[TextRegion],
|
||||
region: TextRegion,
|
||||
ocr_layout: List["TextRegion"],
|
||||
region: "TextRegion",
|
||||
subregion_threshold: float,
|
||||
) -> Optional[str]:
|
||||
"""Extracts the text aggregated from the regions of the ocr layout that lie within the given
|
||||
@ -376,10 +382,11 @@ def aggregate_ocr_text_by_block(
|
||||
return " ".join(extracted_texts) if extracted_texts else ""
|
||||
|
||||
|
||||
@requires_dependencies("unstructured_inference")
|
||||
def supplement_layout_with_ocr_elements(
|
||||
layout: List[LayoutElement],
|
||||
ocr_layout: List[TextRegion],
|
||||
) -> List[LayoutElement]:
|
||||
layout: List["LayoutElement"],
|
||||
ocr_layout: List["TextRegion"],
|
||||
) -> List["LayoutElement"]:
|
||||
"""
|
||||
Supplement the existing layout with additional OCR-derived elements.
|
||||
|
||||
@ -401,11 +408,16 @@ def supplement_layout_with_ocr_elements(
|
||||
Note:
|
||||
- The function relies on `is_almost_subregion_of()` method to determine if an OCR region
|
||||
is a subregion of an existing layout element.
|
||||
- It also relies on `get_elements_from_ocr_regions()` to convert OCR regions to layout elements.
|
||||
- It also relies on `build_layout_elements_from_ocr_regions()` to convert OCR regions to
|
||||
layout elements.
|
||||
- The `SUBREGION_THRESHOLD_FOR_OCR` constant is used to specify the subregion matching
|
||||
threshold.
|
||||
"""
|
||||
|
||||
from unstructured.partition.pdf_image.inference_utils import (
|
||||
build_layout_elements_from_ocr_regions,
|
||||
)
|
||||
|
||||
ocr_regions_to_remove = []
|
||||
for ocr_region in ocr_layout:
|
||||
for el in layout:
|
||||
@ -419,7 +431,7 @@ def supplement_layout_with_ocr_elements(
|
||||
|
||||
ocr_regions_to_add = [region for region in ocr_layout if region not in ocr_regions_to_remove]
|
||||
if ocr_regions_to_add:
|
||||
ocr_elements_to_add = get_elements_from_ocr_regions(ocr_regions_to_add)
|
||||
ocr_elements_to_add = build_layout_elements_from_ocr_regions(ocr_regions_to_add)
|
||||
final_layout = layout + ocr_elements_to_add
|
||||
else:
|
||||
final_layout = layout
|
||||
@ -427,7 +439,7 @@ def supplement_layout_with_ocr_elements(
|
||||
return final_layout
|
||||
|
||||
|
||||
def get_ocr_agent() -> str:
|
||||
def get_ocr_agent() -> OCRAgent:
|
||||
ocr_agent_module = env_config.OCR_AGENT
|
||||
message = (
|
||||
"OCR agent name %s is outdated and will be deprecated in a future release; please use %s "
|
||||
|
@ -1,18 +1,17 @@
|
||||
import functools
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from unstructured_inference.inference.elements import TextRegion
|
||||
from unstructured_inference.inference.layoutelement import (
|
||||
LayoutElement,
|
||||
partition_groups_from_regions,
|
||||
)
|
||||
|
||||
from unstructured.documents.elements import ElementType
|
||||
from unstructured.partition.utils.constants import OCR_AGENT_MODULES_WHITELIST
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from PIL import PILImage
|
||||
from unstructured_inference.inference.elements import TextRegion
|
||||
from unstructured_inference.inference.layoutelement import (
|
||||
LayoutElement,
|
||||
)
|
||||
|
||||
|
||||
class OCRAgent(ABC):
|
||||
def __init__(self):
|
||||
@ -27,19 +26,19 @@ class OCRAgent(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> str:
|
||||
def get_text_from_image(self, image: "PILImage", ocr_languages: str = "eng") -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_layout_from_image(
|
||||
self, image: PILImage, ocr_languages: str = "eng"
|
||||
) -> List[TextRegion]:
|
||||
self, image: "PILImage", ocr_languages: str = "eng"
|
||||
) -> List["TextRegion"]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_layout_elements_from_image(
|
||||
self, image: PILImage, ocr_languages: str = "eng"
|
||||
) -> List[LayoutElement]:
|
||||
self, image: "PILImage", ocr_languages: str = "eng"
|
||||
) -> List["LayoutElement"]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@ -55,73 +54,3 @@ class OCRAgent(ABC):
|
||||
f"Environment variable OCR_AGENT module name {module_name}",
|
||||
f" must be set to a whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}.",
|
||||
)
|
||||
|
||||
|
||||
def get_elements_from_ocr_regions(
|
||||
ocr_regions: List[TextRegion],
|
||||
ocr_text: Optional[str] = None,
|
||||
group_by_ocr_text: bool = False,
|
||||
) -> List[LayoutElement]:
|
||||
"""
|
||||
Get layout elements from OCR regions
|
||||
"""
|
||||
|
||||
if group_by_ocr_text:
|
||||
text_sections = ocr_text.split("\n\n")
|
||||
grouped_regions = []
|
||||
for text_section in text_sections:
|
||||
regions = []
|
||||
words = text_section.replace("\n", " ").split()
|
||||
for ocr_region in ocr_regions:
|
||||
if not words:
|
||||
break
|
||||
if ocr_region.text in words:
|
||||
regions.append(ocr_region)
|
||||
words.remove(ocr_region.text)
|
||||
|
||||
if not regions:
|
||||
continue
|
||||
|
||||
for r in regions:
|
||||
ocr_regions.remove(r)
|
||||
|
||||
grouped_regions.append(regions)
|
||||
else:
|
||||
grouped_regions = cast(
|
||||
List[List[TextRegion]],
|
||||
partition_groups_from_regions(ocr_regions),
|
||||
)
|
||||
|
||||
merged_regions = [merge_text_regions(group) for group in grouped_regions]
|
||||
return [
|
||||
LayoutElement(
|
||||
text=r.text, source=r.source, type=ElementType.UNCATEGORIZED_TEXT, bbox=r.bbox
|
||||
)
|
||||
for r in merged_regions
|
||||
]
|
||||
|
||||
|
||||
def merge_text_regions(regions: List[TextRegion]) -> TextRegion:
|
||||
"""
|
||||
Merge a list of TextRegion objects into a single TextRegion.
|
||||
|
||||
Parameters:
|
||||
- group (List[TextRegion]): A list of TextRegion objects to be merged.
|
||||
|
||||
Returns:
|
||||
- TextRegion: A single merged TextRegion object.
|
||||
"""
|
||||
|
||||
if not regions:
|
||||
raise ValueError("The text regions to be merged must be provided.")
|
||||
|
||||
min_x1 = min([tr.bbox.x1 for tr in regions])
|
||||
min_y1 = min([tr.bbox.y1 for tr in regions])
|
||||
max_x2 = max([tr.bbox.x2 for tr in regions])
|
||||
max_y2 = max([tr.bbox.y2 for tr in regions])
|
||||
|
||||
merged_text = " ".join([tr.text for tr in regions if tr.text])
|
||||
sources = [tr.source for tr in regions]
|
||||
source = sources[0] if all(s == sources[0] for s in sources) else None
|
||||
|
||||
return TextRegion.from_coords(min_x1, min_y1, max_x2, max_y2, merged_text, source)
|
||||
|
@ -1,11 +1,7 @@
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image as PILImage
|
||||
from unstructured_inference.inference.elements import TextRegion
|
||||
from unstructured_inference.inference.layoutelement import (
|
||||
LayoutElement,
|
||||
)
|
||||
|
||||
from unstructured.documents.elements import ElementType
|
||||
from unstructured.logger import logger
|
||||
@ -14,6 +10,11 @@ from unstructured.partition.utils.constants import (
|
||||
Source,
|
||||
)
|
||||
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unstructured_inference.inference.elements import TextRegion
|
||||
from unstructured_inference.inference.layoutelement import LayoutElement
|
||||
|
||||
|
||||
class OCRAgentPaddle(OCRAgent):
|
||||
@ -63,7 +64,7 @@ class OCRAgentPaddle(OCRAgent):
|
||||
|
||||
def get_layout_from_image(
|
||||
self, image: PILImage, ocr_languages: str = "eng"
|
||||
) -> List[TextRegion]:
|
||||
) -> List["TextRegion"]:
|
||||
"""Get the OCR regions from image as a list of text regions with paddle."""
|
||||
|
||||
logger.info("Processing entire page OCR with paddle...")
|
||||
@ -76,9 +77,12 @@ class OCRAgentPaddle(OCRAgent):
|
||||
|
||||
return ocr_regions
|
||||
|
||||
@requires_dependencies("unstructured_inference")
|
||||
def get_layout_elements_from_image(
|
||||
self, image: PILImage, ocr_languages: str = "eng"
|
||||
) -> List[LayoutElement]:
|
||||
) -> List["LayoutElement"]:
|
||||
from unstructured.partition.pdf_image.inference_utils import build_layout_element
|
||||
|
||||
ocr_regions = self.get_layout_from_image(
|
||||
image,
|
||||
ocr_languages=ocr_languages,
|
||||
@ -88,13 +92,17 @@ class OCRAgentPaddle(OCRAgent):
|
||||
# terms of grouping because we get ocr_text from `ocr_layout, so the first two grouping
|
||||
# and merging steps are not necessary.
|
||||
return [
|
||||
LayoutElement(
|
||||
bbox=r.bbox, text=r.text, source=r.source, type=ElementType.UNCATEGORIZED_TEXT
|
||||
build_layout_element(
|
||||
bbox=r.bbox,
|
||||
text=r.text,
|
||||
source=r.source,
|
||||
element_type=ElementType.UNCATEGORIZED_TEXT,
|
||||
)
|
||||
for r in ocr_regions
|
||||
]
|
||||
|
||||
def parse_data(self, ocr_data: list) -> List[TextRegion]:
|
||||
@requires_dependencies("unstructured_inference")
|
||||
def parse_data(self, ocr_data: list) -> List["TextRegion"]:
|
||||
"""
|
||||
Parse the OCR result data to extract a list of TextRegion objects from
|
||||
paddle.
|
||||
@ -114,6 +122,9 @@ class OCRAgentPaddle(OCRAgent):
|
||||
- An empty string or a None value for the 'text' key in the input
|
||||
dictionary will result in its associated bounding box being ignored.
|
||||
"""
|
||||
|
||||
from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords
|
||||
|
||||
text_regions = []
|
||||
for idx in range(len(ocr_data)):
|
||||
res = ocr_data[idx]
|
||||
@ -130,12 +141,12 @@ class OCRAgentPaddle(OCRAgent):
|
||||
continue
|
||||
cleaned_text = text.strip()
|
||||
if cleaned_text:
|
||||
text_region = TextRegion.from_coords(
|
||||
text_region = build_text_region_from_coords(
|
||||
x1,
|
||||
y1,
|
||||
x2,
|
||||
y2,
|
||||
cleaned_text,
|
||||
text=cleaned_text,
|
||||
source=Source.OCR_PADDLE,
|
||||
)
|
||||
text_regions.append(text_region)
|
||||
|
@ -1,18 +1,13 @@
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import unstructured_pytesseract
|
||||
from PIL import Image as PILImage
|
||||
from unstructured_inference.inference.elements import TextRegion
|
||||
from unstructured_inference.inference.layoutelement import (
|
||||
LayoutElement,
|
||||
)
|
||||
from unstructured_pytesseract import Output
|
||||
|
||||
from unstructured.logger import logger
|
||||
from unstructured.partition.pdf_image.ocr import get_elements_from_ocr_regions
|
||||
from unstructured.partition.utils.config import env_config
|
||||
from unstructured.partition.utils.constants import (
|
||||
IMAGE_COLOR_DEPTH,
|
||||
@ -21,6 +16,13 @@ from unstructured.partition.utils.constants import (
|
||||
Source,
|
||||
)
|
||||
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unstructured_inference.inference.elements import TextRegion
|
||||
from unstructured_inference.inference.layoutelement import (
|
||||
LayoutElement,
|
||||
)
|
||||
|
||||
|
||||
class OCRAgentTesseract(OCRAgent):
|
||||
@ -38,7 +40,7 @@ class OCRAgentTesseract(OCRAgent):
|
||||
|
||||
def get_layout_from_image(
|
||||
self, image: PILImage, ocr_languages: str = "eng"
|
||||
) -> List[TextRegion]:
|
||||
) -> List["TextRegion"]:
|
||||
"""Get the OCR regions from image as a list of text regions with tesseract."""
|
||||
|
||||
logger.info("Processing entire page OCR with tesseract...")
|
||||
@ -83,9 +85,14 @@ class OCRAgentTesseract(OCRAgent):
|
||||
|
||||
return ocr_regions
|
||||
|
||||
@requires_dependencies("unstructured_inference")
|
||||
def get_layout_elements_from_image(
|
||||
self, image: PILImage, ocr_languages: str = "eng"
|
||||
) -> List[LayoutElement]:
|
||||
) -> List["LayoutElement"]:
|
||||
from unstructured.partition.pdf_image.inference_utils import (
|
||||
build_layout_elements_from_ocr_regions,
|
||||
)
|
||||
|
||||
ocr_regions = self.get_layout_from_image(
|
||||
image,
|
||||
ocr_languages=ocr_languages,
|
||||
@ -103,13 +110,14 @@ class OCRAgentTesseract(OCRAgent):
|
||||
ocr_languages=ocr_languages,
|
||||
)
|
||||
|
||||
return get_elements_from_ocr_regions(
|
||||
return build_layout_elements_from_ocr_regions(
|
||||
ocr_regions=ocr_regions,
|
||||
ocr_text=ocr_text,
|
||||
group_by_ocr_text=True,
|
||||
)
|
||||
|
||||
def parse_data(self, ocr_data: pd.DataFrame, zoom: float = 1) -> List[TextRegion]:
|
||||
@requires_dependencies("unstructured_inference")
|
||||
def parse_data(self, ocr_data: pd.DataFrame, zoom: float = 1) -> List["TextRegion"]:
|
||||
"""
|
||||
Parse the OCR result data to extract a list of TextRegion objects from
|
||||
tesseract.
|
||||
@ -137,6 +145,8 @@ class OCRAgentTesseract(OCRAgent):
|
||||
data frame will result in its associated bounding box being ignored.
|
||||
"""
|
||||
|
||||
from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords
|
||||
|
||||
if zoom <= 0:
|
||||
zoom = 1
|
||||
|
||||
@ -153,7 +163,7 @@ class OCRAgentTesseract(OCRAgent):
|
||||
y1 = idtx.top / zoom
|
||||
x2 = (idtx.left + idtx.width) / zoom
|
||||
y2 = (idtx.top + idtx.height) / zoom
|
||||
text_region = TextRegion.from_coords(
|
||||
text_region = build_text_region_from_coords(
|
||||
x1,
|
||||
y1,
|
||||
x2,
|
||||
|
@ -1,11 +1,13 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from unstructured_inference.inference.layout import DocumentLayout
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from unstructured.partition.utils.constants import Source
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unstructured_inference.inference.layout import DocumentLayout
|
||||
|
||||
def clean_pdfminer_inner_elements(document: DocumentLayout) -> DocumentLayout:
|
||||
|
||||
def clean_pdfminer_inner_elements(document: "DocumentLayout") -> "DocumentLayout":
|
||||
"""Clean pdfminer elements from inside tables and stores them in extra_info dictionary
|
||||
with the table id as key"""
|
||||
defaultdict(list)
|
||||
|
Loading…
x
Reference in New Issue
Block a user