Adding native support to convert pngs and jpgs to pdfs so the pipeline can work on them

This commit is contained in:
Jake Poznanski 2025-03-31 10:59:38 -07:00
parent 0892b1829b
commit cc8e4b1863
3 changed files with 59 additions and 5 deletions

40
olmocr/image_utils.py Normal file
View File

@ -0,0 +1,40 @@
import os
import tempfile
import subprocess
def convert_image_to_pdf_bytes(image_file: str) -> bytes:
try:
# Run img2pdf and capture its stdout directly as bytes
result = subprocess.run(
["img2pdf", image_file],
check=True,
capture_output=True
)
# Return the stdout content which contains the PDF data
return result.stdout
except subprocess.CalledProcessError as e:
# Raise error with stderr information if the conversion fails
raise RuntimeError(f"Error converting image to PDF: {e.stderr.decode('utf-8')}")
def is_png(file_path):
try:
with open(file_path, "rb") as f:
header = f.read(8)
return header == b"\x89PNG\r\n\x1a\n"
except Exception as e:
print(f"Error: {e}")
return False
def is_jpeg(file_path):
try:
with open(file_path, 'rb') as f:
header = f.read(2)
return header == b'\xff\xd8'
except Exception as e:
print(f"Error: {e}")
return False

View File

@ -40,6 +40,7 @@ from olmocr.filter.filter import Language, PdfFilter
from olmocr.metrics import MetricsKeeper, WorkerTracker
from olmocr.prompts import PageResponse, build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
from olmocr.image_utils import is_png, is_jpeg, convert_image_to_pdf_bytes
from olmocr.s3_utils import (
download_zstd_csv,
expand_s3_glob,
@ -326,6 +327,12 @@ async def process_pdf(args, worker_id: int, pdf_orig_path: str):
else:
raise
if is_png(tf.name) or is_jpeg(tf.name):
logger.info(f"Converting {pdf_orig_path} from image to PDF format...")
tf.seek(0)
tf.write(convert_image_to_pdf_bytes(tf.name))
tf.flush()
try:
reader = PdfReader(tf.name)
num_pages = reader.get_num_pages()
@ -988,13 +995,16 @@ async def main():
logger.info(f"Expanding s3 glob at {pdf_path}")
pdf_work_paths |= set(expand_s3_glob(pdf_s3, pdf_path))
elif os.path.exists(pdf_path):
if pdf_path.endswith(".pdf"):
if pdf_path.lower().endswith(".pdf") or pdf_path.lower().endswith(".png") or pdf_path.lower().endswith(".jpg") or pdf_path.lower().endswith(".jpeg"):
if open(pdf_path, "rb").read(4) == b"%PDF":
logger.info(f"Loading file at {pdf_path} as PDF document")
pdf_work_paths.add(pdf_path)
elif is_png(pdf_path) or is_jpeg(pdf_path):
logger.info(f"Loading file at {pdf_path} as image document")
pdf_work_paths.add(pdf_path)
else:
logger.warning(f"File at {pdf_path} is not a valid PDF")
elif pdf_path.endswith(".txt"):
elif pdf_path.lower().endswith(".txt"):
logger.info(f"Loading file at {pdf_path} as list of paths")
with open(pdf_path, "r") as f:
pdf_work_paths |= set(filter(None, (line.strip() for line in f)))
@ -1016,8 +1026,11 @@ async def main():
with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
tmp_file.write(get_s3_bytes(pdf_s3, pdf))
tmp_file.flush()
reader = PdfReader(tmp_file.name)
page_counts.append(len(reader.pages))
if is_png(tmp_file.name) or is_jpeg(tmp_file.name):
page_counts.append(1)
else:
reader = PdfReader(tmp_file.name)
page_counts.append(len(reader.pages))
except Exception as e:
logger.warning(f"Failed to read {pdf}: {e}")

View File

@ -37,6 +37,7 @@ dependencies = [
"httpx",
"torch>=2.5.1",
"transformers==4.46.2",
"img2pdf",
"beaker-py",
]
license = {file = "LICENSE"}
@ -90,7 +91,7 @@ bench = [
"mistralai",
"lxml",
"flask",
"img2pdf",
]
train = [