mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-06-27 02:30:08 +00:00
Add functions for vectorized ops xy cut++
This commit is contained in:
parent
a46becc185
commit
06a1b66453
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -11,7 +11,7 @@ from unstructured.partition.utils.constants import SORT_MODE_BASIC, SORT_MODE_DO
|
|||||||
from unstructured.partition.utils.xycut import recursive_xy_cut, recursive_xy_cut_swapped
|
from unstructured.partition.utils.xycut import recursive_xy_cut, recursive_xy_cut_swapped
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from unstructured_inference.inference.elements import TextRegions
|
from unstructured_inference.inference.elements import TextRegion, TextRegions
|
||||||
|
|
||||||
|
|
||||||
def coordinates_to_bbox(coordinates: CoordinatesMetadata) -> tuple[int, int, int, int]:
|
def coordinates_to_bbox(coordinates: CoordinatesMetadata) -> tuple[int, int, int, int]:
|
||||||
@ -228,6 +228,110 @@ def sort_bboxes_by_xy_cut(
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def sort_bboxes_by_xy_cut_plus_plus(
|
||||||
|
bboxes: Sequence[Element], document_median_width: float, scaling_factor: float = 1.3
|
||||||
|
):
|
||||||
|
"""Sort bounding boxes using XY-cut-plus-plus algorithm."""
|
||||||
|
|
||||||
|
threshold_width = document_median_width * scaling_factor
|
||||||
|
shrunken_bboxes = []
|
||||||
|
for bbox in bboxes:
|
||||||
|
shrunken_bbox = shrink_bbox(bbox, shrink_factor)
|
||||||
|
shrunken_bboxes.append(shrunken_bbox)
|
||||||
|
|
||||||
|
|
||||||
|
def find_cross_layout_elements(text_regions: TextRegions, threshold_width: float) -> list[Element]:
|
||||||
|
"""Check if a bounding box is a cross-layout element. Returns a boolean mask selecting the cross
|
||||||
|
layout elements"""
|
||||||
|
textregion_widths = get_el_widths(text_regions)
|
||||||
|
return (textregion_widths > threshold_width) & (
|
||||||
|
elements_horizontally_overlap(text_regions=text_regions).sum(axis=-1) > 2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_textregion_widths(text_regions: TextRegions) -> np.ndarray:
|
||||||
|
"""Get the widths of the elements in a TextRegions object."""
|
||||||
|
return text_regions.element_coords[:, 2] - text_regions.element_coords[:, 0]
|
||||||
|
|
||||||
|
|
||||||
|
def textregions_horizontally_overlap(text_regions: TextRegions) -> np.ndarray:
|
||||||
|
"""Check if elements horizontally overlap."""
|
||||||
|
x_0s = text_regions.element_coords[:, [0]]
|
||||||
|
x_1s = text_regions.element_coords[:, [2]]
|
||||||
|
return (x_0s < x_1s.T) & (x_1s > x_0s.T)
|
||||||
|
|
||||||
|
|
||||||
|
def textregions_vertically_overlap(text_regions: TextRegions) -> np.ndarray:
|
||||||
|
"""Check if elements vertically overlap."""
|
||||||
|
y_0s = text_regions.element_coords[:, [1]]
|
||||||
|
y_1s = text_regions.element_coords[:, [3]]
|
||||||
|
return (y_0s < y_1s.T) & (y_1s > y_0s.T)
|
||||||
|
|
||||||
|
|
||||||
|
def textregions_overlap(text_regions: TextRegions) -> np.ndarray:
|
||||||
|
"""Check if elements overlap."""
|
||||||
|
return textregions_vertically_overlap(text_regions) & textregions_horizontally_overlap(
|
||||||
|
text_regions
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_textregion_centers(text_regions: TextRegions) -> np.array:
|
||||||
|
xs = text_regions.element_coords[:, [0, 2]].mean(axis=-1)
|
||||||
|
ys = text_regions.element_coords[:, [1, 3]].mean(axis=-1)
|
||||||
|
return np.column_stack((xs, ys))
|
||||||
|
|
||||||
|
|
||||||
|
def get_central_text_regions(
|
||||||
|
text_regions: TextRegions, page_center: np.ndarray, page_distance: float, threshold: float
|
||||||
|
) -> TextRegions:
|
||||||
|
"""
|
||||||
|
Geometric presegmentation of text regions.
|
||||||
|
"""
|
||||||
|
textregion_centers = get_textregion_centers(text_regions)
|
||||||
|
distances_to_page_center = np.linalg.norm(textregion_centers - page_center, axis=-1)
|
||||||
|
return text_regions.slice(distances_to_page_center / page_distance < threshold)
|
||||||
|
|
||||||
|
|
||||||
|
def get_textregion_distances(textregions: TextRegions) -> np.ndarray:
|
||||||
|
x_overlaps = textregions_horizontally_overlap(textregions)
|
||||||
|
y_overlaps = textregions_vertically_overlap(textregions)
|
||||||
|
overlaps = x_overlaps & y_overlaps
|
||||||
|
x_diffs = textregions[:, [0]] - textregions[:, [2]].T
|
||||||
|
y_diffs = textregions[:, [1]] - textregions[:, [3]].T
|
||||||
|
x_distances = np.stack([x_diffs, x_diffs.T], axis=-1).max(axis=-1).clip(0)
|
||||||
|
y_distances = np.stack([y_diffs, y_diffs.T], axis=-1).max(axis=-1).clip(0)
|
||||||
|
return np.linalg.norm(np.stack([x_distances, y_distances], axis=-1), axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_distances_to_nearest_non_overlapping_bounding_box(textregions: TextRegions) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Find the nearest non-overlapping bounding box.
|
||||||
|
"""
|
||||||
|
# Take the difference between each bounding box's left x-coordinate and the right x-coordinate
|
||||||
|
# of all other bounding boxes. When this difference is positive, that indicates a horizontal
|
||||||
|
# separation between the two bounding boxes. Do the same for the y-coordinates.
|
||||||
|
x_diffs = textregions.element_coords[:, [0]] - textregions.element_coords[:, [2]].T
|
||||||
|
y_diffs = textregions.element_coords[:, [1]] - textregions.element_coords[:, [3]].T
|
||||||
|
# For each pair A, B of bounding boxes, at most one of the above pairwise differences is
|
||||||
|
# positive. We capture this by taking the maximum of the two differences, and cull the negative
|
||||||
|
# values by clipping at 0.0.
|
||||||
|
x_distances = np.stack([x_diffs, x_diffs.T], axis=-1).max(axis=-1).clip(0.0)
|
||||||
|
y_distances = np.stack([y_diffs, y_diffs.T], axis=-1).max(axis=-1).clip(0.0)
|
||||||
|
# We now find the distance between the closest points in pairs of bounding boxes. By taking the
|
||||||
|
# norm of the x and y distances. This works because a distance of 0.0 for an axis indicates
|
||||||
|
# overlap or adjacency on the axis, in which case the distance is the distance in the other axis
|
||||||
|
# (exactly what the Euclidean norm in 2d gives when one coordinate is zero and the other is
|
||||||
|
# positive). If both distances are positive, then the distance between the bounding boxes is
|
||||||
|
# exactly the distance between the closest corners of the bounding boxes, which is exactly the
|
||||||
|
# Euclidean norm of the horizontal and vertical separations.
|
||||||
|
distances = np.linalg.norm(np.stack([x_distances, y_distances], axis=-1), axis=-1)
|
||||||
|
# If the distance is 0.0, that means the bounding boxes are overlapping or adjacent, so we
|
||||||
|
# set the distance to infinity as a way of 'disqualifying' the bounding box.
|
||||||
|
# We then take the minimum of the distances from each bounding box to the others to get the
|
||||||
|
# distance to the 'nearest non-overlapping (or adjacent) bounding box'.
|
||||||
|
return np.where(distances == 0.0, np.inf, distances).min(axis=-1)
|
||||||
|
|
||||||
|
|
||||||
def sort_text_regions(
|
def sort_text_regions(
|
||||||
elements: TextRegions,
|
elements: TextRegions,
|
||||||
sort_mode: str = SORT_MODE_XY_CUT,
|
sort_mode: str = SORT_MODE_XY_CUT,
|
||||||
@ -284,3 +388,10 @@ def sort_text_regions(
|
|||||||
sorted_elements = elements
|
sorted_elements = elements
|
||||||
|
|
||||||
return sorted_elements
|
return sorted_elements
|
||||||
|
|
||||||
|
|
||||||
|
def median_bounding_box_width(elements: TextRegions) -> float:
|
||||||
|
"""
|
||||||
|
Calculate the median bounding box width of a list of TextRegion elements.
|
||||||
|
"""
|
||||||
|
return float(np.median(elements.element_coords[:, 2] - elements.element_coords[:, 0]))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user