Cleaning up code for image to pdf conversion

This commit is contained in:
Jake Poznanski 2025-03-31 13:28:30 -07:00
parent cc8e4b1863
commit b64fd19db3
3 changed files with 78 additions and 26 deletions

View File

@ -7,11 +7,11 @@ import os
import tempfile import tempfile
from functools import partial from functools import partial
import img2pdf
from pypdf import PdfReader from pypdf import PdfReader
from tqdm import tqdm from tqdm import tqdm
from olmocr.data.renderpdf import render_pdf_to_base64png from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.image_utils import convert_image_to_pdf_bytes
def parse_method_arg(method_arg): def parse_method_arg(method_arg):
@ -116,15 +116,39 @@ async def process_pdfs(config, pdf_directory, data_directory, repeats, remove_te
pdf_relative_dir = os.path.dirname(relative_pdf_path) pdf_relative_dir = os.path.dirname(relative_pdf_path)
if remove_text: if remove_text:
page_images = []
for page_num in range(1, num_pages + 1):
page_images.append(render_pdf_to_base64png(pdf_path, page_num, target_longest_image_dim=2048))
print(f"Converting {pdf_path} into images to remove text-content...") print(f"Converting {pdf_path} into images to remove text-content...")
temp_pdf = tempfile.NamedTemporaryFile("wb", suffix=".pdf", delete=False)
temp_pdf.write(img2pdf.convert([base64.b64decode(x) for x in page_images])) # Generate image files from each page
temp_pdf.flush() temp_image_files = []
pdf_path = temp_pdf.name try:
for page_num in range(1, num_pages + 1):
# Get base64 PNG data for the current page
base64_png = render_pdf_to_base64png(pdf_path, page_num, target_longest_image_dim=2048)
# Decode base64 and save to temporary file
temp_img = tempfile.NamedTemporaryFile("wb", suffix=".png", delete=False)
temp_img.write(base64.b64decode(base64_png))
temp_img.close()
temp_image_files.append(temp_img.name)
# Convert all images to a single PDF using our enhanced function
pdf_bytes = convert_image_to_pdf_bytes(temp_image_files)
# Write the PDF bytes to a temporary file
temp_pdf = tempfile.NamedTemporaryFile("wb", suffix=".pdf", delete=False)
temp_pdf.write(pdf_bytes)
temp_pdf.close()
# Update pdf_path to the new file
pdf_path = temp_pdf.name
finally:
# Clean up temporary image files
for temp_file in temp_image_files:
try:
os.remove(temp_file)
except Exception as e:
print(f"Warning: Failed to remove temporary file {temp_file}: {e}")
for repeat in range(1, repeats + 1): for repeat in range(1, repeats + 1):
for page_num in range(1, num_pages + 1): for page_num in range(1, num_pages + 1):

View File

@ -1,23 +1,45 @@
import os import os
import tempfile
import subprocess import subprocess
import tempfile
from typing import List, Union
def convert_image_to_pdf_bytes(image_file: str) -> bytes: def convert_image_to_pdf_bytes(image_files: Union[str, List[str]]) -> bytes:
"""
Convert one or multiple image files to PDF bytes.
Args:
image_files: A single image file path (str) or a list of image file paths
Returns:
bytes: The PDF content as bytes
Raises:
RuntimeError: If the conversion fails
ValueError: If invalid input is provided
"""
# Handle different input types
if isinstance(image_files, str):
# Single image case
image_files = [image_files]
elif not isinstance(image_files, list) or not image_files:
raise ValueError("image_files must be a non-empty string or list of strings")
# Validate files exist and are valid image formats
for image_file in image_files:
if not os.path.exists(image_file):
raise ValueError(f"File does not exist: {image_file}")
try: try:
# Run img2pdf and capture its stdout directly as bytes # Run img2pdf with all images as arguments
result = subprocess.run( result = subprocess.run(["img2pdf"] + image_files, check=True, capture_output=True)
["img2pdf", image_file],
check=True,
capture_output=True
)
# Return the stdout content which contains the PDF data # Return the stdout content which contains the PDF data
return result.stdout return result.stdout
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
# Raise error with stderr information if the conversion fails # Raise error with stderr information if the conversion fails
raise RuntimeError(f"Error converting image to PDF: {e.stderr.decode('utf-8')}") raise RuntimeError(f"Error converting image(s) to PDF: {e.stderr.decode('utf-8')}")
def is_png(file_path): def is_png(file_path):
@ -32,9 +54,9 @@ def is_png(file_path):
def is_jpeg(file_path): def is_jpeg(file_path):
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
header = f.read(2) header = f.read(2)
return header == b'\xff\xd8' return header == b"\xff\xd8"
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
return False return False

View File

@ -37,10 +37,10 @@ from olmocr.check import (
) )
from olmocr.data.renderpdf import render_pdf_to_base64png from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter.filter import Language, PdfFilter from olmocr.filter.filter import Language, PdfFilter
from olmocr.image_utils import convert_image_to_pdf_bytes, is_jpeg, is_png
from olmocr.metrics import MetricsKeeper, WorkerTracker from olmocr.metrics import MetricsKeeper, WorkerTracker
from olmocr.prompts import PageResponse, build_finetuning_prompt from olmocr.prompts import PageResponse, build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text from olmocr.prompts.anchor import get_anchor_text
from olmocr.image_utils import is_png, is_jpeg, convert_image_to_pdf_bytes
from olmocr.s3_utils import ( from olmocr.s3_utils import (
download_zstd_csv, download_zstd_csv,
expand_s3_glob, expand_s3_glob,
@ -89,7 +89,8 @@ process_pool = ProcessPoolExecutor(max_workers=min(multiprocessing.cpu_count() /
# Filter object, cached so it will only get loaded when/if you need it # Filter object, cached so it will only get loaded when/if you need it
get_pdf_filter = cache(lambda: PdfFilter(languages_to_keep={Language.ENGLISH, None}, apply_download_spam_check=True, apply_form_check=True)) get_pdf_filter = cache(lambda: PdfFilter(languages_to_keep={Language.ENGLISH, None}, apply_download_spam_check=True, apply_form_check=True))
SGLANG_SERVER_PORT = None # Specify a default port, but it can be overridden by args
SGLANG_SERVER_PORT = 30024
@dataclass(frozen=True) @dataclass(frozen=True)
@ -995,7 +996,12 @@ async def main():
logger.info(f"Expanding s3 glob at {pdf_path}") logger.info(f"Expanding s3 glob at {pdf_path}")
pdf_work_paths |= set(expand_s3_glob(pdf_s3, pdf_path)) pdf_work_paths |= set(expand_s3_glob(pdf_s3, pdf_path))
elif os.path.exists(pdf_path): elif os.path.exists(pdf_path):
if pdf_path.lower().endswith(".pdf") or pdf_path.lower().endswith(".png") or pdf_path.lower().endswith(".jpg") or pdf_path.lower().endswith(".jpeg"): if (
pdf_path.lower().endswith(".pdf")
or pdf_path.lower().endswith(".png")
or pdf_path.lower().endswith(".jpg")
or pdf_path.lower().endswith(".jpeg")
):
if open(pdf_path, "rb").read(4) == b"%PDF": if open(pdf_path, "rb").read(4) == b"%PDF":
logger.info(f"Loading file at {pdf_path} as PDF document") logger.info(f"Loading file at {pdf_path} as PDF document")
pdf_work_paths.add(pdf_path) pdf_work_paths.add(pdf_path)