mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-12 08:43:32 +00:00
Adding native support to convert pngs and jpgs to pdfs so the pipeline can work on them
This commit is contained in:
parent
0892b1829b
commit
cc8e4b1863
40
olmocr/image_utils.py
Normal file
40
olmocr/image_utils.py
Normal file
@ -0,0 +1,40 @@
|
||||
import os
|
||||
import tempfile
|
||||
import subprocess
|
||||
|
||||
|
||||
def convert_image_to_pdf_bytes(image_file: str) -> bytes:
|
||||
try:
|
||||
# Run img2pdf and capture its stdout directly as bytes
|
||||
result = subprocess.run(
|
||||
["img2pdf", image_file],
|
||||
check=True,
|
||||
capture_output=True
|
||||
)
|
||||
|
||||
# Return the stdout content which contains the PDF data
|
||||
return result.stdout
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
# Raise error with stderr information if the conversion fails
|
||||
raise RuntimeError(f"Error converting image to PDF: {e.stderr.decode('utf-8')}")
|
||||
|
||||
|
||||
def is_png(file_path):
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
header = f.read(8)
|
||||
return header == b"\x89PNG\r\n\x1a\n"
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def is_jpeg(file_path):
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
header = f.read(2)
|
||||
return header == b'\xff\xd8'
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return False
|
@ -40,6 +40,7 @@ from olmocr.filter.filter import Language, PdfFilter
|
||||
from olmocr.metrics import MetricsKeeper, WorkerTracker
|
||||
from olmocr.prompts import PageResponse, build_finetuning_prompt
|
||||
from olmocr.prompts.anchor import get_anchor_text
|
||||
from olmocr.image_utils import is_png, is_jpeg, convert_image_to_pdf_bytes
|
||||
from olmocr.s3_utils import (
|
||||
download_zstd_csv,
|
||||
expand_s3_glob,
|
||||
@ -326,6 +327,12 @@ async def process_pdf(args, worker_id: int, pdf_orig_path: str):
|
||||
else:
|
||||
raise
|
||||
|
||||
if is_png(tf.name) or is_jpeg(tf.name):
|
||||
logger.info(f"Converting {pdf_orig_path} from image to PDF format...")
|
||||
tf.seek(0)
|
||||
tf.write(convert_image_to_pdf_bytes(tf.name))
|
||||
tf.flush()
|
||||
|
||||
try:
|
||||
reader = PdfReader(tf.name)
|
||||
num_pages = reader.get_num_pages()
|
||||
@ -988,13 +995,16 @@ async def main():
|
||||
logger.info(f"Expanding s3 glob at {pdf_path}")
|
||||
pdf_work_paths |= set(expand_s3_glob(pdf_s3, pdf_path))
|
||||
elif os.path.exists(pdf_path):
|
||||
if pdf_path.endswith(".pdf"):
|
||||
if pdf_path.lower().endswith(".pdf") or pdf_path.lower().endswith(".png") or pdf_path.lower().endswith(".jpg") or pdf_path.lower().endswith(".jpeg"):
|
||||
if open(pdf_path, "rb").read(4) == b"%PDF":
|
||||
logger.info(f"Loading file at {pdf_path} as PDF document")
|
||||
pdf_work_paths.add(pdf_path)
|
||||
elif is_png(pdf_path) or is_jpeg(pdf_path):
|
||||
logger.info(f"Loading file at {pdf_path} as image document")
|
||||
pdf_work_paths.add(pdf_path)
|
||||
else:
|
||||
logger.warning(f"File at {pdf_path} is not a valid PDF")
|
||||
elif pdf_path.endswith(".txt"):
|
||||
elif pdf_path.lower().endswith(".txt"):
|
||||
logger.info(f"Loading file at {pdf_path} as list of paths")
|
||||
with open(pdf_path, "r") as f:
|
||||
pdf_work_paths |= set(filter(None, (line.strip() for line in f)))
|
||||
@ -1016,8 +1026,11 @@ async def main():
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
|
||||
tmp_file.write(get_s3_bytes(pdf_s3, pdf))
|
||||
tmp_file.flush()
|
||||
reader = PdfReader(tmp_file.name)
|
||||
page_counts.append(len(reader.pages))
|
||||
if is_png(tmp_file.name) or is_jpeg(tmp_file.name):
|
||||
page_counts.append(1)
|
||||
else:
|
||||
reader = PdfReader(tmp_file.name)
|
||||
page_counts.append(len(reader.pages))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read {pdf}: {e}")
|
||||
|
||||
|
@ -37,6 +37,7 @@ dependencies = [
|
||||
"httpx",
|
||||
"torch>=2.5.1",
|
||||
"transformers==4.46.2",
|
||||
"img2pdf",
|
||||
"beaker-py",
|
||||
]
|
||||
license = {file = "LICENSE"}
|
||||
@ -90,7 +91,7 @@ bench = [
|
||||
"mistralai",
|
||||
"lxml",
|
||||
"flask",
|
||||
"img2pdf",
|
||||
|
||||
]
|
||||
|
||||
train = [
|
||||
|
Loading…
x
Reference in New Issue
Block a user