2025-01-23 11:11:38 -06:00
|
|
|
import numpy as np
|
2023-08-31 22:15:10 -05:00
|
|
|
import pytest
|
2025-01-23 11:11:38 -06:00
|
|
|
from unstructured_inference.inference.elements import TextRegions
|
2023-08-31 22:15:10 -05:00
|
|
|
|
2023-09-10 19:29:49 -07:00
|
|
|
from unstructured.documents.coordinates import PixelSpace
|
|
|
|
from unstructured.documents.elements import CoordinatesMetadata, Element, Text
|
|
|
|
from unstructured.partition.utils.constants import SORT_MODE_BASIC, SORT_MODE_XY_CUT
|
|
|
|
from unstructured.partition.utils.sorting import (
|
|
|
|
coord_has_valid_points,
|
2023-09-28 20:48:02 -07:00
|
|
|
coordinates_to_bbox,
|
|
|
|
shrink_bbox,
|
2023-09-10 19:29:49 -07:00
|
|
|
sort_page_elements,
|
2025-01-23 11:11:38 -06:00
|
|
|
sort_text_regions,
|
2023-09-10 19:29:49 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
|
2023-09-28 20:48:02 -07:00
|
|
|
class MockCoordinatesMetadata(CoordinatesMetadata):
    """CoordinatesMetadata test double bound to a fixed 300x500 ``PixelSpace``."""

    def __init__(self, points):
        # Callers supply only the points; the coordinate system is always the same.
        super().__init__(points, PixelSpace(width=300, height=500))
|
|
|
|
|
|
|
|
|
2023-09-10 19:29:49 -07:00
|
|
|
def test_coord_valid_coordinates():
    """Four non-negative numeric (x, y) points are accepted."""
    points = [(1, 2), (3, 4), (5, 6), (7, 8)]
    metadata = CoordinatesMetadata(points, PixelSpace)

    assert coord_has_valid_points(metadata) is True
|
|
|
|
|
|
|
|
|
|
|
|
def test_coord_missing_incomplete_point():
    """A coordinate set with only three points is rejected."""
    three_points = [(1, 2), (3, 4), (5, 6)]
    metadata = CoordinatesMetadata(three_points, PixelSpace)

    assert coord_has_valid_points(metadata) is False
|
|
|
|
|
|
|
|
|
|
|
|
def test_coord_negative_values():
    """A single negative component (the -6) makes the whole set invalid."""
    points = [(1, 2), (3, 4), (5, -6), (7, 8)]
    metadata = CoordinatesMetadata(points, PixelSpace)

    assert coord_has_valid_points(metadata) is False
|
|
|
|
|
|
|
|
|
|
|
|
def test_coord_weird_values():
    """A string component (the "3") makes the whole set invalid."""
    points = [(1, 2), ("3", 4), (5, 6), (7, 8)]
    metadata = CoordinatesMetadata(points, PixelSpace)

    assert coord_has_valid_points(metadata) is False
|
|
|
|
|
|
|
|
|
|
|
|
def test_coord_invalid_point_structure():
    """A point with three components instead of two makes the set invalid."""
    points = [(1, 2), (3, 4, 5), (6, 7), (8, 9)]
    metadata = CoordinatesMetadata(points, PixelSpace)

    assert coord_has_valid_points(metadata) is False
|
2023-08-31 22:15:10 -05:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("sort_mode", ["xy-cut", "basic"])
def test_sort_page_elements_without_coordinates(sort_mode):
    """Elements that were never given coordinates compare equal to the input after sorting.

    Runs once per parametrized ``sort_mode``.
    """
    elements = [Element(str(idx)) for idx in range(5)]

    # Forward the parametrized mode; without it every case would silently run
    # sort_page_elements with its default mode and "basic" would never be covered.
    assert sort_page_elements(elements, sort_mode=sort_mode) == elements
|
2023-09-10 19:29:49 -07:00
|
|
|
|
|
|
|
|
|
|
|
def test_sort_xycut_neg_coordinates():
    """XY-cut sorting of two elements whose first point is (0, idx) yields a new list object."""
    elements = [Text(str(idx)) for idx in range(2)]
    for idx, elem in enumerate(elements):
        elem.metadata.coordinates = CoordinatesMetadata(
            [(0, idx), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )

    result = sort_page_elements(elements, sort_mode=SORT_MODE_XY_CUT)

    # NOTE(review): an earlier comment here (attributed to crag) said xycut is not
    # attempted and the original list is returned, yet the assertion requires a
    # *different* list object, and despite the test name none of these points is
    # negative. Confirm which scenario this test is meant to pin down.
    assert result is not elements
|
|
|
|
|
|
|
|
|
|
|
|
def test_sort_xycut_pos_coordinates():
    """XY-cut sorting of elements with all-positive points yields a new list object."""
    elements = [Text(str(idx)) for idx in range(2)]
    for elem in elements:
        # Every element gets the same all-positive set of points.
        elem.metadata.coordinates = CoordinatesMetadata(
            [(1, 2), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )

    result = sort_page_elements(elements, sort_mode=SORT_MODE_XY_CUT)

    # NOTE(crag): xycut ran, so the returned list is not the input list object.
    assert result is not elements
|
|
|
|
|
|
|
|
|
|
|
|
def test_sort_basic_neg_coordinates():
    """Basic sorting still orders elements whose first point has a non-positive y.

    The first-point y values are 0, -1, -2 for elements "0", "1", "2"; the
    expected output lists them as "2 1 0".
    """
    elements = [Text(str(idx)) for idx in range(3)]
    for idx, elem in enumerate(elements):
        elem.metadata.coordinates = CoordinatesMetadata(
            [(1, -idx), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )

    ordered = sort_page_elements(elements, sort_mode=SORT_MODE_BASIC)
    texts_in_order = " ".join(str(elem.text) for elem in ordered)

    assert texts_in_order == "2 1 0"
|
|
|
|
|
|
|
|
|
|
|
|
def test_sort_basic_pos_coordinates():
    """Basic sorting returns a new list ordered "7 8 9" for input given as 9, 8, 7.

    Each element's text equals the y value of its first point.
    """
    elements = []
    for value in (9, 8, 7):
        elem = Text(str(value))
        elem.metadata.coordinates = CoordinatesMetadata(
            [(1, value), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )
        elements.append(elem)

    ordered = sort_page_elements(elements, sort_mode=SORT_MODE_BASIC)

    # The sorted result must be a distinct list object, not the input itself.
    assert ordered is not elements

    texts_in_order = " ".join(str(elem.text) for elem in ordered)
    assert texts_in_order == "7 8 9"
|
2023-09-28 20:48:02 -07:00
|
|
|
|
|
|
|
|
2025-01-23 11:11:38 -06:00
|
|
|
def test_sort_text_regions():
    """Basic-mode sorting of three text regions reorders texts to ["2", "3", "1"]."""
    boxes = np.array([[1, 2, 2, 2], [1, 1, 2, 2], [3, 1, 4, 4]])
    regions = TextRegions(
        element_coords=boxes,
        texts=np.array(["1", "2", "3"]),
        sources=np.array(["foo"] * 3),
    )

    ordered = sort_text_regions(regions, sort_mode=SORT_MODE_BASIC)

    assert ordered.texts.tolist() == ["2", "3", "1"]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "coords",
    [
        [[1, 2, 2, 2], [1, 1, 2, 2], [3, -1, 4, 4]],
        [[1, 2, 2, 2], [1, 1, 2, 2], [3, None, 4, 4]],
    ],
)
def test_sort_text_regions_with_invalid_coords_using_xy_cut_does_no_ops(coords):
    """With a negative or ``None`` coordinate, default-mode sorting keeps the input order."""
    # Cast to float after building the array, exactly as the parametrized data requires.
    boxes = np.array(coords).astype(float)
    regions = TextRegions(
        element_coords=boxes,
        texts=np.array(["1", "2", "3"]),
        sources=np.array(["foo"] * 3),
    )

    # No sort_mode is passed, so sort_text_regions runs with its default mode.
    ordered = sort_text_regions(regions)

    assert ordered.texts.tolist() == ["1", "2", "3"]
|
|
|
|
|
|
|
|
|
2023-09-28 20:48:02 -07:00
|
|
|
def test_coordinates_to_bbox():
    """Four corner points collapse to the expected 4-tuple (10, 20, 100, 200)."""
    corners = [(10, 20), (10, 200), (100, 200), (100, 20)]
    metadata = MockCoordinatesMetadata(corners)

    assert coordinates_to_bbox(metadata) == (10, 20, 100, 200)
|
|
|
|
|
|
|
|
|
|
|
|
def test_shrink_bbox():
    """Shrinking by 0.9 keeps the top-left corner and scales width and height.

    Checked for a box anchored at the origin and for one offset to (20, 20).
    """
    factor = 0.9
    cases = [
        # (input bbox, expected bbox)
        ((0, 0, 200, 100), (0, 0, 180, 90)),
        ((20, 20, 320, 120), (20, 20, 290, 110)),
    ]

    for bbox, expected in cases:
        assert shrink_bbox(bbox, factor) == expected
|