Yao You 8f2a719873
Feat/refactor layoutelement textregion to vectorized data structure (#3881)
This PR refactors the data structure for `list[LayoutElement]` and
`list[TextRegion]` used in partition pdf/image files.

- the new data structure replaces a list of objects with a single object
that stores its data in `numpy` arrays
- this only affects partition internal steps and it doesn't change input
or output signature of `partition` function itself, i.e., `partition`
still returns `list[Element]`
- internally `list[LayoutElement]` -> `LayoutElements`;
`list[TextRegion]` -> `TextRegions`
- the current refactor stops before the step that cleans up pdfminer
elements inside inferred layout elements: that cleanup algorithm needs to
be refactored before the data structure refactor can move forward. So the
current refactor converts the array data structure back into a list data
structure with an `element_array.as_list()` call. This is the last step
before turning `list[LayoutElement]` into the returned `list[Element]`
- a future PR will update this last step so that we build
`list[Element]` from `LayoutElements` data structure instead.

The goal of this PR is to replace the data structure as much as possible
without changing underlying logic. There are a few places where the
slicing or filtering logic was simple enough to be converted into vector
data structure operations. Those are refactored to be vector based. As a
result, some small improvements are observed in the ingest tests. This is
likely because the vector operations cleaned up some previous
inconsistencies in data types and operations.

---------

Co-authored-by: ryannikolaidis <1208590+ryannikolaidis@users.noreply.github.com>
Co-authored-by: badGarnet <badGarnet@users.noreply.github.com>
2025-01-23 17:11:38 +00:00

158 lines
5.0 KiB
Python

import numpy as np
import pytest
from unstructured_inference.inference.elements import TextRegions
from unstructured.documents.coordinates import PixelSpace
from unstructured.documents.elements import CoordinatesMetadata, Element, Text
from unstructured.partition.utils.constants import SORT_MODE_BASIC, SORT_MODE_XY_CUT
from unstructured.partition.utils.sorting import (
coord_has_valid_points,
coordinates_to_bbox,
shrink_bbox,
sort_page_elements,
sort_text_regions,
)
class MockCoordinatesMetadata(CoordinatesMetadata):
    """``CoordinatesMetadata`` test double bound to a 300x500 ``PixelSpace``."""

    def __init__(self, points):
        # Callers supply only the points; the coordinate system is fixed here.
        super().__init__(points, PixelSpace(width=300, height=500))
def test_coord_valid_coordinates():
    """Four two-component, non-negative points are reported as valid."""
    points = [(1, 2), (3, 4), (5, 6), (7, 8)]
    metadata = CoordinatesMetadata(points, PixelSpace)
    is_valid = coord_has_valid_points(metadata)
    assert is_valid is True
def test_coord_missing_incomplete_point():
    """A coordinate set holding only three points is reported as invalid."""
    points = [(1, 2), (3, 4), (5, 6)]
    metadata = CoordinatesMetadata(points, PixelSpace)
    is_valid = coord_has_valid_points(metadata)
    assert is_valid is False
def test_coord_negative_values():
    """A point with a negative component makes the coordinates invalid."""
    points = [(1, 2), (3, 4), (5, -6), (7, 8)]
    metadata = CoordinatesMetadata(points, PixelSpace)
    is_valid = coord_has_valid_points(metadata)
    assert is_valid is False
def test_coord_weird_values():
    """A point with a string component makes the coordinates invalid."""
    points = [(1, 2), ("3", 4), (5, 6), (7, 8)]
    metadata = CoordinatesMetadata(points, PixelSpace)
    is_valid = coord_has_valid_points(metadata)
    assert is_valid is False
def test_coord_invalid_point_structure():
    """A point with three components makes the coordinates invalid."""
    points = [(1, 2), (3, 4, 5), (6, 7), (8, 9)]
    metadata = CoordinatesMetadata(points, PixelSpace)
    is_valid = coord_has_valid_points(metadata)
    assert is_valid is False
@pytest.mark.parametrize("sort_mode", ["xy-cut", "basic"])
def test_sort_page_elements_without_coordinates(sort_mode):
    """Elements lacking coordinates come back equal to the input list.

    The parametrized ``sort_mode`` is forwarded to ``sort_page_elements`` so
    that each listed mode is actually exercised; previously the argument was
    accepted but never used, so every case ran with the default mode.
    """
    elements = [Element(str(idx)) for idx in range(5)]
    assert sort_page_elements(elements, sort_mode=sort_mode) == elements
def test_sort_xycut_neg_coordinates():
    """Run xy-cut sorting on two elements whose first point is ``(0, idx)``."""

    def _build(idx):
        # Attach the coordinates after construction, via the element metadata.
        element = Text(str(idx))
        element.metadata.coordinates = CoordinatesMetadata(
            [(0, idx), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )
        return element

    elements = [_build(idx) for idx in range(2)]

    # NOTE(review): the earlier comment on this assertion said xy-cut is not
    # attempted and the original list is returned, yet the assertion requires
    # a list object different from the input -- confirm which is intended.
    result = sort_page_elements(elements, sort_mode=SORT_MODE_XY_CUT)
    assert result is not elements
def test_sort_xycut_pos_coordinates():
    """xy-cut sorting of positively-placed elements yields a new list object."""

    def _build(idx):
        # Both elements share the same four points; only the text differs.
        element = Text(str(idx))
        element.metadata.coordinates = CoordinatesMetadata(
            [(1, 2), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )
        return element

    elements = [_build(idx) for idx in range(2)]

    # The result must not be the very same list object that was passed in.
    result = sort_page_elements(elements, sort_mode=SORT_MODE_XY_CUT)
    assert result is not elements
def test_sort_basic_neg_coordinates():
    """Basic sorting of elements with first point ``(1, -idx)`` yields 2, 1, 0."""

    def _build(idx):
        element = Text(str(idx))
        element.metadata.coordinates = CoordinatesMetadata(
            [(1, -idx), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )
        return element

    elements = [_build(idx) for idx in range(3)]

    ordered = sort_page_elements(elements, sort_mode=SORT_MODE_BASIC)
    joined_text = " ".join(str(element.text) for element in ordered)
    assert joined_text == "2 1 0"
def test_sort_basic_pos_coordinates():
    """Basic sorting returns a new list whose texts read 7, 8, 9."""

    def _build(idx):
        # Text and the first point's second component are both ``9 - idx``.
        value = 9 - idx
        element = Text(str(value))
        element.metadata.coordinates = CoordinatesMetadata(
            [(1, value), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )
        return element

    elements = [_build(idx) for idx in range(3)]

    ordered = sort_page_elements(elements, sort_mode=SORT_MODE_BASIC)
    assert ordered is not elements

    joined_text = " ".join(str(element.text) for element in ordered)
    assert joined_text == "7 8 9"
def test_sort_text_regions():
    """Basic-mode sorting reorders the three regions' texts to 2, 3, 1."""
    boxes = np.array([[1, 2, 2, 2], [1, 1, 2, 2], [3, 1, 4, 4]])
    regions = TextRegions(
        element_coords=boxes,
        texts=np.array(["1", "2", "3"]),
        sources=np.array(["foo"] * 3),
    )

    ordered = sort_text_regions(regions, sort_mode=SORT_MODE_BASIC)
    assert ordered.texts.tolist() == ["2", "3", "1"]
@pytest.mark.parametrize(
    "coords",
    [
        [[1, 2, 2, 2], [1, 1, 2, 2], [3, -1, 4, 4]],
        [[1, 2, 2, 2], [1, 1, 2, 2], [3, None, 4, 4]],
    ],
)
def test_sort_text_regions_with_invalid_coords_using_xy_cut_does_no_ops(coords):
    """With a negative or ``None`` coordinate, default sorting keeps the order."""
    # The raw coordinates are cast to float before building the regions.
    boxes = np.array(coords).astype(float)
    regions = TextRegions(
        element_coords=boxes,
        texts=np.array(["1", "2", "3"]),
        sources=np.array(["foo"] * 3),
    )

    # No sort_mode is passed, so the function's default mode is used.
    ordered = sort_text_regions(regions)
    assert ordered.texts.tolist() == ["1", "2", "3"]
def test_coordinates_to_bbox():
    """The four corner points convert to the bbox ``(10, 20, 100, 200)``."""
    corners = [(10, 20), (10, 200), (100, 200), (100, 20)]
    metadata = MockCoordinatesMetadata(corners)
    bbox = coordinates_to_bbox(metadata)
    assert bbox == (10, 20, 100, 200)
def test_shrink_bbox():
    """Shrinking by 0.9 gives the expected bbox for two sample inputs."""
    # Each case pairs an input bbox with the bbox expected after shrinking.
    cases = [
        ((0, 0, 200, 100), (0, 0, 180, 90)),
        ((20, 20, 320, 120), (20, 20, 290, 110)),
    ]
    for box, expected in cases:
        assert shrink_bbox(box, 0.9) == expected