Adding native support to convert pngs and jpgs to pdfs so the pipeline can work on them

This commit is contained in:
Jake Poznanski 2025-03-31 10:59:38 -07:00
parent 0892b1829b
commit cc8e4b1863
3 changed files with 59 additions and 5 deletions

40
olmocr/image_utils.py Normal file
View File

@ -0,0 +1,40 @@
import os
import tempfile
import subprocess
def convert_image_to_pdf_bytes(image_file: str) -> bytes:
try:
# Run img2pdf and capture its stdout directly as bytes
result = subprocess.run(
["img2pdf", image_file],
check=True,
capture_output=True
)
# Return the stdout content which contains the PDF data
return result.stdout
except subprocess.CalledProcessError as e:
# Raise error with stderr information if the conversion fails
raise RuntimeError(f"Error converting image to PDF: {e.stderr.decode('utf-8')}")
def is_png(file_path):
try:
with open(file_path, "rb") as f:
header = f.read(8)
return header == b"\x89PNG\r\n\x1a\n"
except Exception as e:
print(f"Error: {e}")
return False
def is_jpeg(file_path):
try:
with open(file_path, 'rb') as f:
header = f.read(2)
return header == b'\xff\xd8'
except Exception as e:
print(f"Error: {e}")
return False

View File

@ -40,6 +40,7 @@ from olmocr.filter.filter import Language, PdfFilter
from olmocr.metrics import MetricsKeeper, WorkerTracker from olmocr.metrics import MetricsKeeper, WorkerTracker
from olmocr.prompts import PageResponse, build_finetuning_prompt from olmocr.prompts import PageResponse, build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text from olmocr.prompts.anchor import get_anchor_text
from olmocr.image_utils import is_png, is_jpeg, convert_image_to_pdf_bytes
from olmocr.s3_utils import ( from olmocr.s3_utils import (
download_zstd_csv, download_zstd_csv,
expand_s3_glob, expand_s3_glob,
@ -326,6 +327,12 @@ async def process_pdf(args, worker_id: int, pdf_orig_path: str):
else: else:
raise raise
if is_png(tf.name) or is_jpeg(tf.name):
logger.info(f"Converting {pdf_orig_path} from image to PDF format...")
tf.seek(0)
tf.write(convert_image_to_pdf_bytes(tf.name))
tf.flush()
try: try:
reader = PdfReader(tf.name) reader = PdfReader(tf.name)
num_pages = reader.get_num_pages() num_pages = reader.get_num_pages()
@ -988,13 +995,16 @@ async def main():
logger.info(f"Expanding s3 glob at {pdf_path}") logger.info(f"Expanding s3 glob at {pdf_path}")
pdf_work_paths |= set(expand_s3_glob(pdf_s3, pdf_path)) pdf_work_paths |= set(expand_s3_glob(pdf_s3, pdf_path))
elif os.path.exists(pdf_path): elif os.path.exists(pdf_path):
if pdf_path.endswith(".pdf"): if pdf_path.lower().endswith(".pdf") or pdf_path.lower().endswith(".png") or pdf_path.lower().endswith(".jpg") or pdf_path.lower().endswith(".jpeg"):
if open(pdf_path, "rb").read(4) == b"%PDF": if open(pdf_path, "rb").read(4) == b"%PDF":
logger.info(f"Loading file at {pdf_path} as PDF document") logger.info(f"Loading file at {pdf_path} as PDF document")
pdf_work_paths.add(pdf_path) pdf_work_paths.add(pdf_path)
elif is_png(pdf_path) or is_jpeg(pdf_path):
logger.info(f"Loading file at {pdf_path} as image document")
pdf_work_paths.add(pdf_path)
else: else:
logger.warning(f"File at {pdf_path} is not a valid PDF") logger.warning(f"File at {pdf_path} is not a valid PDF")
elif pdf_path.endswith(".txt"): elif pdf_path.lower().endswith(".txt"):
logger.info(f"Loading file at {pdf_path} as list of paths") logger.info(f"Loading file at {pdf_path} as list of paths")
with open(pdf_path, "r") as f: with open(pdf_path, "r") as f:
pdf_work_paths |= set(filter(None, (line.strip() for line in f))) pdf_work_paths |= set(filter(None, (line.strip() for line in f)))
@ -1016,8 +1026,11 @@ async def main():
with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file: with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
tmp_file.write(get_s3_bytes(pdf_s3, pdf)) tmp_file.write(get_s3_bytes(pdf_s3, pdf))
tmp_file.flush() tmp_file.flush()
reader = PdfReader(tmp_file.name) if is_png(tmp_file.name) or is_jpeg(tmp_file.name):
page_counts.append(len(reader.pages)) page_counts.append(1)
else:
reader = PdfReader(tmp_file.name)
page_counts.append(len(reader.pages))
except Exception as e: except Exception as e:
logger.warning(f"Failed to read {pdf}: {e}") logger.warning(f"Failed to read {pdf}: {e}")

View File

@ -37,6 +37,7 @@ dependencies = [
"httpx", "httpx",
"torch>=2.5.1", "torch>=2.5.1",
"transformers==4.46.2", "transformers==4.46.2",
"img2pdf",
"beaker-py", "beaker-py",
] ]
license = {file = "LICENSE"} license = {file = "LICENSE"}
@ -90,7 +91,7 @@ bench = [
"mistralai", "mistralai",
"lxml", "lxml",
"flask", "flask",
"img2pdf",
] ]
train = [ train = [