Adding some rotation retry control

This commit is contained in:
Jake Poznanski 2024-10-28 20:16:06 +00:00
parent 7678f31aa9
commit 08d51b7183
4 changed files with 105 additions and 3 deletions

View File

@ -4,7 +4,7 @@ import boto3
import sqlite3
import orjson
import argparse
import uuid
import base64
import tempfile
import datetime
import posixpath
@ -15,6 +15,8 @@ import urllib3.exceptions
from dataclasses import dataclass
from pypdf import PdfReader
from io import BytesIO
from PIL import Image
from tqdm import tqdm
from functools import partial
from typing import Optional, List, Tuple, Dict, Callable, Any
@ -383,8 +385,23 @@ class BatchWriter:
thread.join()
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int) -> dict:
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)
if image_rotation != 0:
image_bytes = base64.b64decode(image_base64)
with Image.open(BytesIO(image_bytes)) as img:
rotated_img = img.rotate(-image_rotation, expand=True)
# Save the rotated image to a bytes buffer
buffered = BytesIO()
rotated_img.save(buffered, format="PNG")
# Encode the rotated image back to base64
image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)
return {
@ -511,6 +528,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})
# TODO: If the rotation was previously invalid, then apply a rotation
# TODO: Try to provide a smaller prompt hint
else:

View File

@ -2,9 +2,12 @@ import unittest
from unittest.mock import MagicMock, patch
import hashlib
import json
import os
import base64
from PIL import Image
# Adjust the import path to match where your code resides
from pdelfin.birrpipeline import build_dolma_doc, DatabaseManager
from pdelfin.birrpipeline import build_dolma_doc, DatabaseManager, build_finetuning_prompt, build_page_query
class TestBuildDolmaDoc(unittest.TestCase):
@patch('pdelfin.birrpipeline.DatabaseManager')
@ -121,6 +124,87 @@ class TestBuildDolmaDoc(unittest.TestCase):
expected_id = hashlib.sha1(expected_text.encode()).hexdigest()
self.assertEqual(dolma_doc['id'], expected_id)
class TestBuildPageQuery(unittest.TestCase):
    """Tests for the ``image_rotation`` argument of ``build_page_query``."""

    # Data-URL prefix expected in front of the base64-encoded page image.
    DATA_URL_PREFIX = "data:image/png;base64,"

    def _render_page_png(self, image_rotation: int, output_name: str) -> str:
        """Build a page query, save its embedded PNG, and return the saved path.

        Args:
            image_rotation: Rotation value forwarded to ``build_page_query``.
            output_name: File name to write inside the ``test_renders``
                directory that sits next to this test module.

        Returns:
            Path of the PNG file that was written.
        """
        tests_dir = os.path.dirname(__file__)
        pdf_path = os.path.join(tests_dir, "gnarly_pdfs", "edgar.pdf")
        query = build_page_query(pdf_path, "edgar.pdf", 1, 1024, 6000, image_rotation)

        # The image is expected as the second content entry of the first message.
        image_content = query["chat_messages"][0]["content"][1]
        self.assertEqual(image_content["type"], "image_url")
        image_url = image_content["image_url"]["url"]

        # Strip the data-URL prefix and decode the base64 payload.
        self.assertTrue(image_url.startswith(self.DATA_URL_PREFIX))
        image_data = base64.b64decode(image_url[len(self.DATA_URL_PREFIX):])

        # Create the output directory on demand so the test does not fail
        # with FileNotFoundError on a checkout that lacks it.
        output_dir = os.path.join(tests_dir, "test_renders")
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, output_name)
        with open(output_path, "wb") as image_file:
            image_file.write(image_data)
        return output_path

    def testRotation(self):
        """A rotation of 90 must yield the page image turned 90 degrees clockwise."""
        output_image_path = self._render_page_png(0, "output_image.png")
        output_image_rotated_path = self._render_page_png(90, "output_image_rotated90.png")

        with Image.open(output_image_path) as original_image, \
                Image.open(output_image_rotated_path) as rotated_image:
            original_pixels = original_image.load()
            rotated_pixels = rotated_image.load()
            width, height = original_image.size

            # A quarter turn swaps width and height.
            self.assertEqual(width, rotated_image.size[1])
            self.assertEqual(height, rotated_image.size[0])

            # Under a 90 degree clockwise rotation, pixel (x, y) of the original
            # lands at (height - 1 - y, x). Look for the first violation and
            # assert once instead of calling assertEqual for every pixel.
            mismatch = next(
                (
                    (x, y)
                    for x in range(width)
                    for y in range(height)
                    if original_pixels[x, y] != rotated_pixels[height - 1 - y, x]
                ),
                None,
            )
            self.assertIsNone(mismatch, f"Pixel mismatch at {mismatch}")
# Run the test
if __name__ == '__main__':
unittest.main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 357 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 362 KiB