Christine Straub 5d14a2aea0
feat: shrink bboxes by top left (#1633)
Closes #1573.
### Summary
- update `shrink_bbox()` to keep top left rather than center
### Evaluation
Run the following command on this
[PDF](https://utic-dev-tech-fixtures.s3.us-east-2.amazonaws.com/pastebin/patent-11723901-page2.pdf):
```
PYTHONPATH=. python examples/custom-layout-order/evaluate_xy_cut_sorting.py <file_path> <strategy>
```

---------

Co-authored-by: ryannikolaidis <1208590+ryannikolaidis@users.noreply.github.com>
2023-10-06 05:16:11 +00:00

128 lines
4.1 KiB
Python

import pytest
from unstructured.documents.coordinates import PixelSpace
from unstructured.documents.elements import CoordinatesMetadata, Element, Text
from unstructured.partition.utils.constants import SORT_MODE_BASIC, SORT_MODE_XY_CUT
from unstructured.partition.utils.sorting import (
coord_has_valid_points,
coordinates_to_bbox,
shrink_bbox,
sort_page_elements,
)
class MockCoordinatesMetadata(CoordinatesMetadata):
    """CoordinatesMetadata bound to a fixed 300x500 PixelSpace for tests."""

    def __init__(self, points):
        # Only the points vary between tests; the coordinate system is pinned.
        super().__init__(points, PixelSpace(width=300, height=500))
def test_coord_valid_coordinates():
    """Expect True for four 2-value points with no negative numbers."""
    points = [(1, 2), (3, 4), (5, 6), (7, 8)]
    result = coord_has_valid_points(CoordinatesMetadata(points, PixelSpace))
    assert result is True
def test_coord_missing_incomplete_point():
    """Expect False when only three points are supplied."""
    points = [(1, 2), (3, 4), (5, 6)]
    result = coord_has_valid_points(CoordinatesMetadata(points, PixelSpace))
    assert result is False
def test_coord_negative_values():
    """Expect False when one point carries a negative value."""
    points = [(1, 2), (3, 4), (5, -6), (7, 8)]
    result = coord_has_valid_points(CoordinatesMetadata(points, PixelSpace))
    assert result is False
def test_coord_weird_values():
    """Expect False when a point holds a string instead of a number."""
    points = [(1, 2), ("3", 4), (5, 6), (7, 8)]
    result = coord_has_valid_points(CoordinatesMetadata(points, PixelSpace))
    assert result is False
def test_coord_invalid_point_structure():
    """Expect False when one of the points has three values instead of two."""
    points = [(1, 2), (3, 4, 5), (6, 7), (8, 9)]
    result = coord_has_valid_points(CoordinatesMetadata(points, PixelSpace))
    assert result is False
@pytest.mark.parametrize("sort_mode", [SORT_MODE_XY_CUT, SORT_MODE_BASIC])
def test_sort_page_elements_without_coordinates(sort_mode):
    """Expect elements that have no coordinates to compare equal after sorting.

    The parametrized ``sort_mode`` is forwarded to ``sort_page_elements`` so
    that each sort mode is actually exercised; previously the argument was
    ignored and every case ran with the function's default mode.
    """
    elements = [Element(str(idx)) for idx in range(5)]
    assert sort_page_elements(elements, sort_mode=sort_mode) == elements
def test_sort_xycut_neg_coordinates():
    """Expect xy-cut sorting to hand back a list object other than the input."""
    elements = [Text(str(idx)) for idx in range(2)]
    for idx, elem in enumerate(elements):
        elem.metadata.coordinates = CoordinatesMetadata(
            [(0, idx), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )

    # NOTE(review): an earlier note here said xy-cut is skipped and the input
    # list is returned unchanged, but the assertion below requires a different
    # list object. None of these points is negative despite the test name --
    # confirm which behavior is intended.
    result = sort_page_elements(elements, sort_mode=SORT_MODE_XY_CUT)
    assert result is not elements
def test_sort_xycut_pos_coordinates():
    """Expect xy-cut sorting to hand back a list object other than the input."""
    elements = [Text(str(idx)) for idx in range(2)]
    for elem in elements:
        # Both elements share the same all-positive points.
        elem.metadata.coordinates = CoordinatesMetadata(
            [(1, 2), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )

    # NOTE(crag): xycut ran, so different list reference returned from input list
    result = sort_page_elements(elements, sort_mode=SORT_MODE_XY_CUT)
    assert result is not elements
def test_sort_basic_neg_coordinates():
    """Expect basic sorting to order texts "2 1 0" for first-point y of 0, -1, -2."""
    elements = [Text(str(idx)) for idx in range(3)]
    for idx, elem in enumerate(elements):
        elem.metadata.coordinates = CoordinatesMetadata(
            [(1, -idx), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )

    ordered = sort_page_elements(elements, sort_mode=SORT_MODE_BASIC)
    joined_text = " ".join(str(elem.text) for elem in ordered)
    assert joined_text == "2 1 0"
def test_sort_basic_pos_coordinates():
    """Expect basic sorting to return a new list ordered "7 8 9"."""
    elements = []
    # Elements are created in descending order (9, 8, 7); each one's text
    # matches the y value of its first point.
    for value in (9, 8, 7):
        elem = Text(str(value))
        elem.metadata.coordinates = CoordinatesMetadata(
            [(1, value), (3, 4), (6, 7), (8, 9)],
            PixelSpace,
        )
        elements.append(elem)

    ordered = sort_page_elements(elements, sort_mode=SORT_MODE_BASIC)
    assert ordered is not elements

    joined_text = " ".join(str(elem.text) for elem in ordered)
    assert joined_text == "7 8 9"
def test_coordinates_to_bbox():
    """Expect the four corner points to yield the tuple (10, 20, 100, 200)."""
    corners = [(10, 20), (10, 200), (100, 200), (100, 20)]
    bbox = coordinates_to_bbox(MockCoordinatesMetadata(corners))
    assert bbox == (10, 20, 100, 200)
def test_shrink_bbox():
    """Expect a 0.9 shrink to keep the first corner and scale width and height."""
    shrink_factor = 0.9
    # (input bbox, expected bbox): the first two values stay unchanged in both.
    cases = (
        ((0, 0, 200, 100), (0, 0, 180, 90)),
        ((20, 20, 320, 120), (20, 20, 290, 110)),
    )
    for bbox, expected_result in cases:
        assert shrink_bbox(bbox, shrink_factor) == expected_result