mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-11 07:58:10 +00:00
Adding some rotation retry control
This commit is contained in:
parent
7678f31aa9
commit
08d51b7183
@ -4,7 +4,7 @@ import boto3
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import orjson
|
import orjson
|
||||||
import argparse
|
import argparse
|
||||||
import uuid
|
import base64
|
||||||
import tempfile
|
import tempfile
|
||||||
import datetime
|
import datetime
|
||||||
import posixpath
|
import posixpath
|
||||||
@ -15,6 +15,8 @@ import urllib3.exceptions
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pypdf import PdfReader
|
from pypdf import PdfReader
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, List, Tuple, Dict, Callable, Any
|
from typing import Optional, List, Tuple, Dict, Callable, Any
|
||||||
@ -383,8 +385,23 @@ class BatchWriter:
|
|||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
|
|
||||||
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int) -> dict:
|
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
|
||||||
|
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
|
||||||
image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)
|
image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)
|
||||||
|
|
||||||
|
if image_rotation != 0:
|
||||||
|
image_bytes = base64.b64decode(image_base64)
|
||||||
|
with Image.open(BytesIO(image_bytes)) as img:
|
||||||
|
rotated_img = img.rotate(-image_rotation, expand=True)
|
||||||
|
|
||||||
|
# Save the rotated image to a bytes buffer
|
||||||
|
buffered = BytesIO()
|
||||||
|
rotated_img.save(buffered, format="PNG")
|
||||||
|
|
||||||
|
# Encode the rotated image back to base64
|
||||||
|
image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)
|
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -512,6 +529,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
|
|||||||
|
|
||||||
# TODO: If the rotation was previously invalid, then apply a rotation
|
# TODO: If the rotation was previously invalid, then apply a rotation
|
||||||
|
|
||||||
|
|
||||||
# TODO: Try to provide a smaller prompt hint
|
# TODO: Try to provide a smaller prompt hint
|
||||||
else:
|
else:
|
||||||
new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})
|
new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})
|
||||||
|
|||||||
@ -2,9 +2,12 @@ import unittest
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
# Adjust the import path to match where your code resides
|
# Adjust the import path to match where your code resides
|
||||||
from pdelfin.birrpipeline import build_dolma_doc, DatabaseManager
|
from pdelfin.birrpipeline import build_dolma_doc, DatabaseManager, build_finetuning_prompt, build_page_query
|
||||||
|
|
||||||
class TestBuildDolmaDoc(unittest.TestCase):
|
class TestBuildDolmaDoc(unittest.TestCase):
|
||||||
@patch('pdelfin.birrpipeline.DatabaseManager')
|
@patch('pdelfin.birrpipeline.DatabaseManager')
|
||||||
@ -121,6 +124,87 @@ class TestBuildDolmaDoc(unittest.TestCase):
|
|||||||
expected_id = hashlib.sha1(expected_text.encode()).hexdigest()
|
expected_id = hashlib.sha1(expected_text.encode()).hexdigest()
|
||||||
self.assertEqual(dolma_doc['id'], expected_id)
|
self.assertEqual(dolma_doc['id'], expected_id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildPageQuery(unittest.TestCase):
|
||||||
|
def testRotation(self):
|
||||||
|
# First, generate and save the non-rotated image
|
||||||
|
query = build_page_query(os.path.join(
|
||||||
|
os.path.dirname(__file__),
|
||||||
|
"gnarly_pdfs",
|
||||||
|
"edgar.pdf"
|
||||||
|
), "edgar.pdf", 1, 1024, 6000, 0)
|
||||||
|
|
||||||
|
# Extract the base64 image from the query
|
||||||
|
image_content = query["chat_messages"][0]["content"][1]
|
||||||
|
self.assertEqual(image_content["type"], "image_url")
|
||||||
|
image_url = image_content["image_url"]["url"]
|
||||||
|
|
||||||
|
# Extract base64 string from the data URL
|
||||||
|
prefix = "data:image/png;base64,"
|
||||||
|
self.assertTrue(image_url.startswith(prefix))
|
||||||
|
image_base64 = image_url[len(prefix):]
|
||||||
|
|
||||||
|
# Decode the base64 string
|
||||||
|
image_data = base64.b64decode(image_base64)
|
||||||
|
|
||||||
|
# Define the output file path for the non-rotated image
|
||||||
|
output_image_path = os.path.join(os.path.dirname(__file__), "test_renders", "output_image.png")
|
||||||
|
|
||||||
|
# Save the non-rotated image to a file
|
||||||
|
with open(output_image_path, "wb") as image_file:
|
||||||
|
image_file.write(image_data)
|
||||||
|
|
||||||
|
# Now, generate and save the rotated image (90 degrees clockwise)
|
||||||
|
query_rotated = build_page_query(os.path.join(
|
||||||
|
os.path.dirname(__file__),
|
||||||
|
"gnarly_pdfs",
|
||||||
|
"edgar.pdf"
|
||||||
|
), "edgar.pdf", 1, 1024, 6000, 90)
|
||||||
|
|
||||||
|
# Extract the base64 image from the rotated query
|
||||||
|
image_content_rotated = query_rotated["chat_messages"][0]["content"][1]
|
||||||
|
self.assertEqual(image_content_rotated["type"], "image_url")
|
||||||
|
image_url_rotated = image_content_rotated["image_url"]["url"]
|
||||||
|
|
||||||
|
# Extract base64 string from the data URL for the rotated image
|
||||||
|
self.assertTrue(image_url_rotated.startswith(prefix))
|
||||||
|
image_base64_rotated = image_url_rotated[len(prefix):]
|
||||||
|
|
||||||
|
# Decode the base64 string for the rotated image
|
||||||
|
image_data_rotated = base64.b64decode(image_base64_rotated)
|
||||||
|
|
||||||
|
# Define the output file path for the rotated image
|
||||||
|
output_image_rotated_path = os.path.join(os.path.dirname(__file__), "test_renders", "output_image_rotated90.png")
|
||||||
|
|
||||||
|
# Save the rotated image to a file
|
||||||
|
with open(output_image_rotated_path, "wb") as image_file_rotated:
|
||||||
|
image_file_rotated.write(image_data_rotated)
|
||||||
|
|
||||||
|
# Verification Step: Ensure the rotated image is 90 degrees clockwise rotated
|
||||||
|
|
||||||
|
# Open both images using PIL
|
||||||
|
with Image.open(output_image_path) as original_image:
|
||||||
|
with Image.open(output_image_rotated_path) as rotated_image:
|
||||||
|
|
||||||
|
# Compare pixel by pixel
|
||||||
|
original_pixels = original_image.load()
|
||||||
|
rotated_pixels = rotated_image.load()
|
||||||
|
width, height = original_image.size
|
||||||
|
|
||||||
|
self.assertEqual(width, rotated_image.size[1])
|
||||||
|
self.assertEqual(height, rotated_image.size[0])
|
||||||
|
|
||||||
|
for x in range(width):
|
||||||
|
for y in range(height):
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
original_pixels[x, y], rotated_pixels[height - 1 - y, x],
|
||||||
|
f"Pixel mismatch at ({x}, {y})"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Rotation verification passed: The rotated image is correctly rotated 90 degrees clockwise.")
|
||||||
|
|
||||||
|
|
||||||
# Run the test
|
# Run the test
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
BIN
tests/test_renders/output_image.png
Normal file
BIN
tests/test_renders/output_image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 357 KiB |
BIN
tests/test_renders/output_image_rotated90.png
Normal file
BIN
tests/test_renders/output_image_rotated90.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 362 KiB |
Loading…
x
Reference in New Issue
Block a user