Mirror of https://github.com/allenai/olmocr.git
commit 97291b3f6a (parent c8a4d14c57)

    Anchor is fixed to sample text elements better
@@ -87,11 +87,11 @@ def _transform_point(x, y, m):
     y_new = m[1]*x + m[3]*y + m[5]
     return x_new, y_new

-@dataclass
+@dataclass(frozen=True)
 class Element:
     pass

-@dataclass
+@dataclass(frozen=True)
 class BoundingBox:
     x0: float
     y0: float
@@ -102,23 +102,24 @@ class BoundingBox:
     def from_rectangle(rect: RectangleObject) -> "BoundingBox":
         return BoundingBox(rect[0], rect[1], rect[2], rect[3])

-@dataclass
+@dataclass(frozen=True)
 class TextElement(Element):
     text: str
     x: float
     y: float

-@dataclass
+@dataclass(frozen=True)
 class ImageElement(Element):
     name: str
     bbox: BoundingBox

-@dataclass
+@dataclass(frozen=True)
 class PageReport:
     mediabox: BoundingBox
     text_elements: List[TextElement]
     image_elements: List[ImageElement]

+
 def _pdf_report(local_pdf_path: str, page: int) -> PageReport:
     reader = PdfReader(local_pdf_path)
     page = reader.pages[page - 1]
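A note on the frozen=True switch above (a minimal sketch, not part of the commit): freezing the dataclasses makes instances hashable, which the new selection logic in _linearize_pdf_report relies on when it gathers edge elements into a set. A plain mutable @dataclass defines __eq__ and therefore has its __hash__ set to None.

# Minimal sketch, independent of the repo: why frozen dataclasses are needed
# once instances are stored in a set, as the new edge-element logic does.
from dataclasses import dataclass

@dataclass
class MutablePoint:
    x: float
    y: float

@dataclass(frozen=True)
class FrozenPoint:
    x: float
    y: float

try:
    {MutablePoint(1.0, 2.0)}  # eq=True without frozen=True leaves __hash__ = None
except TypeError as err:
    print("mutable dataclass is unhashable:", err)

# Frozen instances hash by field values, so set membership works as expected.
print({FrozenPoint(1.0, 2.0), FrozenPoint(1.0, 2.0)})  # collapses to one element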
@@ -219,27 +220,106 @@ def _merge_image_elements(images: List[ImageElement], tolerance: float=0.5) -> L
     return merged_images


-def _linearize_pdf_report(report: PageReport) -> str:
+def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
     result = ""

     result += f"Page dimensions: {report.mediabox.x1:.1f}x{report.mediabox.y1:.1f}\n"

     #images = report.image_elements
     images = _merge_image_elements(report.image_elements)

-    for index, element in enumerate(images):
-        result += f"[Image {element.bbox.x0:.0f}x{element.bbox.y0:.0f} to {element.bbox.x1:.0f}x{element.bbox.y1:.0f}]"
+    # Process image elements
+    image_strings = []
+    for element in images:
+        image_str = f"[Image {element.bbox.x0:.0f}x{element.bbox.y0:.0f} to {element.bbox.x1:.0f}x{element.bbox.y1:.0f}]"
+        # Use element's unique identifier (e.g., id or position) for comparison
+        image_strings.append((element, image_str))

-    for index, element in enumerate(report.text_elements):
+    # Process text elements
+    text_strings = []
+    for element in report.text_elements:
         if len(element.text.strip()) == 0:
             continue

         element_text = ftfy.fix_text(element.text)
-        # Replace square brackets with something else not to throw off the syntax
-        element_text = element_text.replace("[", "\[").replace("]", "\[")
+        # Replace square brackets with escaped brackets
+        element_text = element_text.replace("[", "\\[").replace("]", "\\]")

         # Need to use ftfy to fix text, because occasionally there are invalid surrogate pairs and other UTF issues that cause
         # pyarrow to fail to load the json later
-        result += f"[{element.x:.0f}x{element.y:.0f}]{element_text}"
+        text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}"
+        text_strings.append((element, text_str))
+
+    # Combine all elements with their positions for sorting
+    all_elements = []
+    for elem, s in image_strings:
+        position = (elem.bbox.x0, elem.bbox.y0)
+        all_elements.append(('image', elem, s, position))
+    for elem, s in text_strings:
+        position = (elem.x, elem.y)
+        all_elements.append(('text', elem, s, position))
+
+    # Calculate total length
+    total_length = len(result) + sum(len(s) for _, _, s, _ in all_elements)
+
+    if total_length <= max_length:
+        # Include all elements
+        for _, _, s, _ in all_elements:
+            result += s
+        return result
+
+    # Identify elements with min/max coordinates
+    edge_elements = set()
+
+    if images:
+        min_x0_image = min(images, key=lambda e: e.bbox.x0)
+        max_x1_image = max(images, key=lambda e: e.bbox.x1)
+        min_y0_image = min(images, key=lambda e: e.bbox.y0)
+        max_y1_image = max(images, key=lambda e: e.bbox.y1)
+        edge_elements.update([min_x0_image, max_x1_image, min_y0_image, max_y1_image])
+
+    if report.text_elements:
+        text_elements = [e for e in report.text_elements if len(e.text.strip()) > 0]
+        min_x_text = min(text_elements, key=lambda e: e.x)
+        max_x_text = max(text_elements, key=lambda e: e.x)
+        min_y_text = min(text_elements, key=lambda e: e.y)
+        max_y_text = max(text_elements, key=lambda e: e.y)
+        edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text])
+
+    # Keep track of element IDs to prevent duplication
+    selected_element_ids = set()
+    selected_elements = []
+
+    # Include edge elements first
+    for elem_type, elem, s, position in all_elements:
+        if elem in edge_elements and id(elem) not in selected_element_ids:
+            selected_elements.append((elem_type, elem, s, position))
+            selected_element_ids.add(id(elem))
+
+    # Calculate remaining length
+    current_length = len(result) + sum(len(s) for _, _, s, _ in selected_elements)
+    remaining_length = max_length - current_length
+
+    # Exclude edge elements from the pool
+    remaining_elements = [
+        (elem_type, elem, s, position) for elem_type, elem, s, position in all_elements
+        if id(elem) not in selected_element_ids
+    ]
+
+    # Sort remaining elements by their positions (e.g., x-coordinate and then y-coordinate)
+    remaining_elements.sort(key=lambda x: (x[3][0], x[3][1]))
+
+    # Add elements until reaching max_length
+    for elem_type, elem, s, position in remaining_elements:
+        if current_length + len(s) > max_length:
+            break
+        selected_elements.append((elem_type, elem, s, position))
+        selected_element_ids.add(id(elem))
+        current_length += len(s)
+
+    # Sort selected elements by their positions to maintain logical order
+    selected_elements.sort(key=lambda x: (x[3][0], x[3][1]))
+
+    # Build the final result
+    for _, _, s, _ in selected_elements:
+        result += s

     return result
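In short, the reworked _linearize_pdf_report now takes a character budget: when the full report fits within max_length it is emitted unchanged; otherwise the elements at the extreme page coordinates are always kept and the rest are added in positional order until the budget runs out. Below is a small self-contained sketch of that selection idea; the function name and the toy (x, y, text) tuples are hypothetical, not the repo's API.

# Toy illustration of the budget-based selection: keep the positional
# extremes, then fill the remaining space in (x, y) order.
def select_under_budget(items, max_length, header=""):
    # items: list of (x, y, text) tuples -- names here are hypothetical
    total = len(header) + sum(len(t) for _, _, t in items)
    if total <= max_length:
        return header + "".join(t for _, _, t in items)

    # Always keep the elements at the extreme x and y coordinates.
    edges = {min(items, key=lambda i: i[0]), max(items, key=lambda i: i[0]),
             min(items, key=lambda i: i[1]), max(items, key=lambda i: i[1])}
    selected = list(edges)
    used = len(header) + sum(len(t) for _, _, t in selected)

    # Greedily add the remaining items in positional order until the budget is hit.
    for item in sorted((i for i in items if i not in edges), key=lambda i: (i[0], i[1])):
        if used + len(item[2]) > max_length:
            break
        selected.append(item)
        used += len(item[2])

    selected.sort(key=lambda i: (i[0], i[1]))
    return header + "".join(t for _, _, t in selected)

print(select_under_budget([(0, 0, "[0x0]A"), (10, 5, "[10x5]B"), (700, 900, "[700x900]C")], max_length=16))
# -> "[0x0]A[700x900]C": the extremes survive, the middle element is dropped.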
@@ -185,10 +185,10 @@ def main():
         description="Transform JSONL files by extracting and renaming specific fields."
     )
     parser.add_argument(
-        '--rewrite_finetuning_prompt',
+        '--rewrite_prompt',
         action='store_true',
         default=False,
-        help="Rewrites the input prompt from standard OPENAI instruction format into our finetuned format"
+        help="Rewrites the input prompt by reloading the pdf from source"
     )
     parser.add_argument(
         'input_dir',

@@ -233,7 +233,7 @@ def main():
     # Process files in parallel
     with ProcessPoolExecutor(max_workers=max_jobs) as executor:
         future_to_file = {
-            executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt): input_file
+            executor.submit(process_file, input_file, output_file, args.rewrite_prompt): input_file
             for input_file, output_file in tasks
         }
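For reference, a minimal sketch of how the renamed flag behaves after this change (the add_argument calls mirror the diff; the sample argv and print are illustrative only):

import argparse

parser = argparse.ArgumentParser(
    description="Transform JSONL files by extracting and renaming specific fields."
)
parser.add_argument(
    '--rewrite_prompt',
    action='store_true',
    default=False,
    help="Rewrites the input prompt by reloading the pdf from source"
)
parser.add_argument('input_dir')

args = parser.parse_args(['some_input_dir', '--rewrite_prompt'])
print(args.rewrite_prompt)  # True -- this boolean is what gets forwarded to process_file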
@@ -73,6 +73,7 @@ class AnchorTest(unittest.TestCase):

         print(anchor_text)
+        print(len(anchor_text))
         self.assertLess(len(anchor_text), 1000)

     def testLargePromptHint2(self):
         local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "large_prompt_hint2.pdf")

@@ -81,6 +82,7 @@ class AnchorTest(unittest.TestCase):

         print(anchor_text)
+        print(len(anchor_text))
         self.assertLess(len(anchor_text), 4000)

     def testNewsPaperPromptHint(self):
         local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "newspaper.pdf")

@@ -89,6 +91,7 @@ class AnchorTest(unittest.TestCase):

         print(anchor_text)
+        print(len(anchor_text))
         self.assertLess(len(anchor_text), 4000)