mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-06-27 02:30:08 +00:00
Add functions for vectorized ops xy cut++
This commit is contained in:
parent
a46becc185
commit
06a1b66453
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -11,7 +11,7 @@ from unstructured.partition.utils.constants import SORT_MODE_BASIC, SORT_MODE_DO
|
||||
from unstructured.partition.utils.xycut import recursive_xy_cut, recursive_xy_cut_swapped
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unstructured_inference.inference.elements import TextRegions
|
||||
from unstructured_inference.inference.elements import TextRegion, TextRegions
|
||||
|
||||
|
||||
def coordinates_to_bbox(coordinates: CoordinatesMetadata) -> tuple[int, int, int, int]:
|
||||
@ -228,6 +228,110 @@ def sort_bboxes_by_xy_cut(
|
||||
return res
|
||||
|
||||
|
||||
def sort_bboxes_by_xy_cut_plus_plus(
|
||||
bboxes: Sequence[Element], document_median_width: float, scaling_factor: float = 1.3
|
||||
):
|
||||
"""Sort bounding boxes using XY-cut-plus-plus algorithm."""
|
||||
|
||||
threshold_width = document_median_width * scaling_factor
|
||||
shrunken_bboxes = []
|
||||
for bbox in bboxes:
|
||||
shrunken_bbox = shrink_bbox(bbox, shrink_factor)
|
||||
shrunken_bboxes.append(shrunken_bbox)
|
||||
|
||||
|
||||
def find_cross_layout_elements(text_regions: TextRegions, threshold_width: float) -> list[Element]:
|
||||
"""Check if a bounding box is a cross-layout element. Returns a boolean mask selecting the cross
|
||||
layout elements"""
|
||||
textregion_widths = get_el_widths(text_regions)
|
||||
return (textregion_widths > threshold_width) & (
|
||||
elements_horizontally_overlap(text_regions=text_regions).sum(axis=-1) > 2
|
||||
)
|
||||
|
||||
|
||||
def get_textregion_widths(text_regions: TextRegions) -> np.ndarray:
|
||||
"""Get the widths of the elements in a TextRegions object."""
|
||||
return text_regions.element_coords[:, 2] - text_regions.element_coords[:, 0]
|
||||
|
||||
|
||||
def textregions_horizontally_overlap(text_regions: TextRegions) -> np.ndarray:
|
||||
"""Check if elements horizontally overlap."""
|
||||
x_0s = text_regions.element_coords[:, [0]]
|
||||
x_1s = text_regions.element_coords[:, [2]]
|
||||
return (x_0s < x_1s.T) & (x_1s > x_0s.T)
|
||||
|
||||
|
||||
def textregions_vertically_overlap(text_regions: TextRegions) -> np.ndarray:
|
||||
"""Check if elements vertically overlap."""
|
||||
y_0s = text_regions.element_coords[:, [1]]
|
||||
y_1s = text_regions.element_coords[:, [3]]
|
||||
return (y_0s < y_1s.T) & (y_1s > y_0s.T)
|
||||
|
||||
|
||||
def textregions_overlap(text_regions: TextRegions) -> np.ndarray:
|
||||
"""Check if elements overlap."""
|
||||
return textregions_vertically_overlap(text_regions) & textregions_horizontally_overlap(
|
||||
text_regions
|
||||
)
|
||||
|
||||
|
||||
def get_textregion_centers(text_regions: TextRegions) -> np.array:
|
||||
xs = text_regions.element_coords[:, [0, 2]].mean(axis=-1)
|
||||
ys = text_regions.element_coords[:, [1, 3]].mean(axis=-1)
|
||||
return np.column_stack((xs, ys))
|
||||
|
||||
|
||||
def get_central_text_regions(
|
||||
text_regions: TextRegions, page_center: np.ndarray, page_distance: float, threshold: float
|
||||
) -> TextRegions:
|
||||
"""
|
||||
Geometric presegmentation of text regions.
|
||||
"""
|
||||
textregion_centers = get_textregion_centers(text_regions)
|
||||
distances_to_page_center = np.linalg.norm(textregion_centers - page_center, axis=-1)
|
||||
return text_regions.slice(distances_to_page_center / page_distance < threshold)
|
||||
|
||||
|
||||
def get_textregion_distances(textregions: TextRegions) -> np.ndarray:
|
||||
x_overlaps = textregions_horizontally_overlap(textregions)
|
||||
y_overlaps = textregions_vertically_overlap(textregions)
|
||||
overlaps = x_overlaps & y_overlaps
|
||||
x_diffs = textregions[:, [0]] - textregions[:, [2]].T
|
||||
y_diffs = textregions[:, [1]] - textregions[:, [3]].T
|
||||
x_distances = np.stack([x_diffs, x_diffs.T], axis=-1).max(axis=-1).clip(0)
|
||||
y_distances = np.stack([y_diffs, y_diffs.T], axis=-1).max(axis=-1).clip(0)
|
||||
return np.linalg.norm(np.stack([x_distances, y_distances], axis=-1), axis=-1)
|
||||
|
||||
|
||||
def get_distances_to_nearest_non_overlapping_bounding_box(textregions: TextRegions) -> np.ndarray:
|
||||
"""
|
||||
Find the nearest non-overlapping bounding box.
|
||||
"""
|
||||
# Take the difference between each bounding box's left x-coordinate and the right x-coordinate
|
||||
# of all other bounding boxes. When this difference is positive, that indicates a horizontal
|
||||
# separation between the two bounding boxes. Do the same for the y-coordinates.
|
||||
x_diffs = textregions.element_coords[:, [0]] - textregions.element_coords[:, [2]].T
|
||||
y_diffs = textregions.element_coords[:, [1]] - textregions.element_coords[:, [3]].T
|
||||
# For each pair A, B of bounding boxes, at most one of the above pairwise differences is
|
||||
# positive. We capture this by taking the maximum of the two differences, and cull the negative
|
||||
# values by clipping at 0.0.
|
||||
x_distances = np.stack([x_diffs, x_diffs.T], axis=-1).max(axis=-1).clip(0.0)
|
||||
y_distances = np.stack([y_diffs, y_diffs.T], axis=-1).max(axis=-1).clip(0.0)
|
||||
# We now find the distance between the closest points in pairs of bounding boxes. By taking the
|
||||
# norm of the x and y distances. This works because a distance of 0.0 for an axis indicates
|
||||
# overlap or adjacency on the axis, in which case the distance is the distance in the other axis
|
||||
# (exactly what the Euclidean norm in 2d gives when one coordinate is zero and the other is
|
||||
# positive). If both distances are positive, then the distance between the bounding boxes is
|
||||
# exactly the distance between the closest corners of the bounding boxes, which is exactly the
|
||||
# Euclidean norm of the horizontal and vertical separations.
|
||||
distances = np.linalg.norm(np.stack([x_distances, y_distances], axis=-1), axis=-1)
|
||||
# If the distance is 0.0, that means the bounding boxes are overlapping or adjacent, so we
|
||||
# set the distance to infinity as a way of 'disqualifying' the bounding box.
|
||||
# We then take the minimum of the distances from each bounding box to the others to get the
|
||||
# distance to the 'nearest non-overlapping (or adjacent) bounding box'.
|
||||
return np.where(distances == 0.0, np.inf, distances).min(axis=-1)
|
||||
|
||||
|
||||
def sort_text_regions(
|
||||
elements: TextRegions,
|
||||
sort_mode: str = SORT_MODE_XY_CUT,
|
||||
@ -284,3 +388,10 @@ def sort_text_regions(
|
||||
sorted_elements = elements
|
||||
|
||||
return sorted_elements
|
||||
|
||||
|
||||
def median_bounding_box_width(elements: TextRegions) -> float:
|
||||
"""
|
||||
Calculate the median bounding box width of a list of TextRegion elements.
|
||||
"""
|
||||
return float(np.median(elements.element_coords[:, 2] - elements.element_coords[:, 0]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user