From cc8e4b186302f6a2d3d0d46ff4c60ece1e97dd87 Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Mon, 31 Mar 2025 10:59:38 -0700 Subject: [PATCH] Adding native support to convert pngs and jpgs to pdfs so the pipeline can work on them --- olmocr/image_utils.py | 40 ++++++++++++++++++++++++++++++++++++++++ olmocr/pipeline.py | 21 +++++++++++++++++---- pyproject.toml | 3 ++- 3 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 olmocr/image_utils.py diff --git a/olmocr/image_utils.py b/olmocr/image_utils.py new file mode 100644 index 0000000..960b505 --- /dev/null +++ b/olmocr/image_utils.py @@ -0,0 +1,40 @@ +import os +import tempfile +import subprocess + + +def convert_image_to_pdf_bytes(image_file: str) -> bytes: + try: + # Run img2pdf and capture its stdout directly as bytes + result = subprocess.run( + ["img2pdf", image_file], + check=True, + capture_output=True + ) + + # Return the stdout content which contains the PDF data + return result.stdout + + except subprocess.CalledProcessError as e: + # Raise error with stderr information if the conversion fails + raise RuntimeError(f"Error converting image to PDF: {e.stderr.decode('utf-8')}") + + +def is_png(file_path): + try: + with open(file_path, "rb") as f: + header = f.read(8) + return header == b"\x89PNG\r\n\x1a\n" + except Exception as e: + print(f"Error: {e}") + return False + + +def is_jpeg(file_path): + try: + with open(file_path, 'rb') as f: + header = f.read(2) + return header == b'\xff\xd8' + except Exception as e: + print(f"Error: {e}") + return False \ No newline at end of file diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py index c4e3c29..5bf30c3 100644 --- a/olmocr/pipeline.py +++ b/olmocr/pipeline.py @@ -40,6 +40,7 @@ from olmocr.filter.filter import Language, PdfFilter from olmocr.metrics import MetricsKeeper, WorkerTracker from olmocr.prompts import PageResponse, build_finetuning_prompt from olmocr.prompts.anchor import get_anchor_text +from olmocr.image_utils import is_png, is_jpeg, convert_image_to_pdf_bytes from olmocr.s3_utils import ( download_zstd_csv, expand_s3_glob, @@ -326,6 +327,12 @@ async def process_pdf(args, worker_id: int, pdf_orig_path: str): else: raise + if is_png(tf.name) or is_jpeg(tf.name): + logger.info(f"Converting {pdf_orig_path} from image to PDF format...") + tf.seek(0) + tf.write(convert_image_to_pdf_bytes(tf.name)) + tf.flush() + try: reader = PdfReader(tf.name) num_pages = reader.get_num_pages() @@ -988,13 +995,16 @@ async def main(): logger.info(f"Expanding s3 glob at {pdf_path}") pdf_work_paths |= set(expand_s3_glob(pdf_s3, pdf_path)) elif os.path.exists(pdf_path): - if pdf_path.endswith(".pdf"): + if pdf_path.lower().endswith(".pdf") or pdf_path.lower().endswith(".png") or pdf_path.lower().endswith(".jpg") or pdf_path.lower().endswith(".jpeg"): if open(pdf_path, "rb").read(4) == b"%PDF": logger.info(f"Loading file at {pdf_path} as PDF document") pdf_work_paths.add(pdf_path) + elif is_png(pdf_path) or is_jpeg(pdf_path): + logger.info(f"Loading file at {pdf_path} as image document") + pdf_work_paths.add(pdf_path) else: logger.warning(f"File at {pdf_path} is not a valid PDF") - elif pdf_path.endswith(".txt"): + elif pdf_path.lower().endswith(".txt"): logger.info(f"Loading file at {pdf_path} as list of paths") with open(pdf_path, "r") as f: pdf_work_paths |= set(filter(None, (line.strip() for line in f))) @@ -1016,8 +1026,11 @@ async def main(): with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file: tmp_file.write(get_s3_bytes(pdf_s3, pdf)) tmp_file.flush() - reader = PdfReader(tmp_file.name) - page_counts.append(len(reader.pages)) + if is_png(tmp_file.name) or is_jpeg(tmp_file.name): + page_counts.append(1) + else: + reader = PdfReader(tmp_file.name) + page_counts.append(len(reader.pages)) except Exception as e: logger.warning(f"Failed to read {pdf}: {e}") diff --git a/pyproject.toml b/pyproject.toml index 9795f89..f23f645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "httpx", "torch>=2.5.1", "transformers==4.46.2", + "img2pdf", "beaker-py", ] license = {file = "LICENSE"} @@ -90,7 +91,7 @@ bench = [ "mistralai", "lxml", "flask", - "img2pdf", + ] train = [