Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-29 16:50:49 +00:00

Commit 202d81cece: Merge branch 'main' of https://github.com/allenai/pdelfin into main

.gitignore (vendored): 4 changes
@@ -2,7 +2,9 @@
wandb/
*histogram.png
*.json

dolma_previews/*
s2_previews/*
gnarly_previews/*
/*.html
@@ -1,320 +0,0 @@
import os
import hashlib
import boto3
import sqlite3
import json
import argparse
import glob
import tempfile
import posixpath

from dataclasses import dataclass
from pypdf import PdfReader
from tqdm import tqdm
from typing import Optional, List, Tuple, Dict
from urllib.parse import urlparse
from concurrent.futures import ProcessPoolExecutor, as_completed

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text

# Global s3 client for the whole script, feel free to adjust params if you need it
s3 = boto3.client('s3')

class DatabaseManager:
    def __init__(self, s3_workspace: str):
        cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
        home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
        os.makedirs(home_cache_dir, exist_ok=True)
        self.db_path = os.path.join(home_cache_dir, 'index.db')

        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self._initialize_tables()

    def _initialize_tables(self):
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS page_results (
                s3_path TEXT,
                page_num INTEGER,
                start_index BIGINT,
                length BIGINT,
                finish_reason TEXT,
                error TEXT
            )
        """)
        self.cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_path ON page_results(s3_path)
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS pdfs (
                s3_path TEXT PRIMARY KEY,
                num_pages INTEGER,
                status TEXT DEFAULT 'pending'
            )
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS processed_files (
                s3_path TEXT PRIMARY KEY,
                etag TEXT
            )
        """)
        # Generic metadata such as current round
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        """)

        self.conn.commit()

    def get_metadata(self, key: str) -> Optional[str]:
        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
        result = self.cursor.fetchone()
        return result[0] if result else None

    def get_current_round(self):
        round_value = self.get_metadata("round")
        return int(round_value) if round_value else 0

    def is_file_processed(self, s3_path, etag):
        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
        result = self.cursor.fetchone()
        return result is not None and result[0] == etag

    def add_index_entries(self, index_entries: List['BatchInferenceLine']):
        if index_entries:
            self.cursor.executemany("""
                INSERT INTO page_results (s3_path, page_num, start_index, length, finish_reason, error)
                VALUES (?, ?, ?, ?, ?, ?)
            """, [(entry.s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
            self.conn.commit()

    def update_processed_file(self, s3_path, etag):
        self.cursor.execute("""
            INSERT INTO processed_files (s3_path, etag)
            VALUES (?, ?)
            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
        """, (s3_path, etag))
        self.conn.commit()

    def pdf_exists(self, s3_path: str) -> bool:
        self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
        return self.cursor.fetchone() is not None

    def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None:
        try:
            self.cursor.execute("""
                INSERT INTO pdfs (s3_path, num_pages, status)
                VALUES (?, ?, ?)
            """, (s3_path, num_pages, status))
            self.conn.commit()
        except sqlite3.IntegrityError:
            print(f"PDF with s3_path '{s3_path}' already exists.")

    def get_pdf_status(self, s3_path: str) -> Optional[str]:
        self.cursor.execute("SELECT status FROM pdfs WHERE s3_path = ?", (s3_path,))
        result = self.cursor.fetchone()
        return result[0] if result else None

    def close(self):
        self.conn.close()


@dataclass(frozen=True)
class BatchInferenceLine:
    s3_path: str
    page_num: int  # 1 indexed!
    start_index: int
    length: int
    finish_reason: str
    error: Optional[str]


def parse_s3_path(s3_path):
    if not s3_path.startswith('s3://'):
        raise ValueError('s3_path must start with s3://')
    path = s3_path[5:]
    bucket, _, prefix = path.partition('/')
    return bucket, prefix

def expand_s3_glob(s3_glob: str) -> Dict[str, str]:
    parsed = urlparse(s3_glob)
    bucket_name = parsed.netloc
    prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
    pattern = os.path.basename(parsed.path)

    paginator = s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    matched_files = {}
    for page in page_iterator:
        for obj in page.get('Contents', []):
            key = obj['Key']
            if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                matched_files[f"s3://{bucket_name}/{key}"] = obj['ETag'].strip('"')

    return matched_files

def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "custom_id": f"{pretty_pdf_path}-{page}",
        "chat_messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
                ],
            }
        ],
        "temperature": 0.8,
        "max_tokens": 6000,
    }


def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
    bucket, key = parse_s3_path(s3_path)

    # Build the range header if start_index and/or end_index are specified
    range_header = None
    if start_index is not None or end_index is not None:
        range_value = f"bytes={start_index or 0}-"
        if end_index is not None:
            range_value += str(end_index)
        range_header = {'Range': range_value}

    if range_header:
        obj = s3.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
    else:
        obj = s3.get_object(Bucket=bucket, Key=key)

    return obj['Body'].read()


def parse_custom_id(custom_id: str) -> Tuple[str, int]:
    s3_path = custom_id[:custom_id.rindex("-")]
    page_num = int(custom_id[custom_id.rindex("-") + 1:])
    return s3_path, page_num

def process_jsonl_content(s3_path) -> List[BatchInferenceLine]:
    content = get_s3_bytes(s3_path).decode("utf-8")

    start_index = 0
    index_entries = []
    lines = content.splitlines(keepends=True)
    for line in lines:
        line_length = len(line)

        try:
            data = json.loads(line)
            s3_path, page_num = parse_custom_id(data["custom_id"])

            assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"

            index_entries.append(BatchInferenceLine(
                s3_path=s3_path,
                page_num=page_num,
                start_index=start_index,
                length=line_length,
                finish_reason=data["outputs"][0]["finish_reason"],
                error=data.get("completion_error", None)
            ))
        except json.JSONDecodeError:
            pass  # Handle JSON decode errors if necessary
        except Exception as e:
            print(f"Error processing line: {e}")

        start_index += line_length

    return index_entries

def get_pdf_num_pages(s3_path: str) -> Optional[int]:
    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(s3_path))
            tf.flush()

            reader = PdfReader(tf.name)
            return reader.get_num_pages()
    except Exception as ex:
        print(f"Warning, could not add {s3_path} due to {ex}")

    return None

def get_pdf_batch_inference_lines(s3_path: str, pages: list[int]) -> list[dict]:
    results = []
    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(s3_path))
            tf.flush()

            for page in pages:
                results.append(build_page_query(tf.name, s3_path, page))
    except Exception as ex:
        print(f"Warning, could not get batch inferences lines for {s3_path} due to {ex}")

    return results

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
    parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
    parser.add_argument('--file_size_limit', type=int, default=250, help='Max file size in MB')
    args = parser.parse_args()

    db = DatabaseManager(args.workspace)
    print(f"Loaded db at {db.db_path}")
    print(f"Current round is {db.get_current_round()}\n")

    # One shared executor to rule them all
    executor = ProcessPoolExecutor()

    # If you have new PDFs, add them to the list
    if args.add_pdfs:
        assert args.add_pdfs.startswith("s3://"), "PDFs must live on s3"

        print(f"Querying all PDFs at {args.add_pdfs}")

        all_pdfs = expand_s3_glob(args.add_pdfs)
        print(f"Found {len(all_pdfs)} total pdf paths")

        all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
        print(f"Need to import {len(all_pdfs)} total new pdf paths")

        future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}
        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            s3_path = future_to_path[future]
            num_pages = future.result()
            if num_pages and not db.pdf_exists(s3_path):
                db.add_pdf(s3_path, num_pages, "pending")

        print("\n")

    # Now build an index of all the pages that were processed within the workspace so far
    print("Indexing all batch inference sent to this workspace")
    inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")

    inference_output_paths = [
        (s3_path, etag) for s3_path, etag in inference_output_paths.items()
        if not db.is_file_processed(s3_path, etag)
    ]

    print(f"Found {len(inference_output_paths)} new batch inference results to index")

    future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        s3_path, etag = future_to_path[future]
        inference_lines = future.result()

        db.add_index_entries(inference_lines)
        db.update_processed_file(s3_path, etag=etag)

    # Now query each pdf, if you have all of the pages needed (all pages present, error is null and finish_reason is stop), then you assemble it into a dolma document and output it
    # If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use
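As a side note on the deleted script above: the `custom_id` it writes in `build_page_query` and later splits in `parse_custom_id` is just the PDF's S3 path with the 1-indexed page number appended after a final dash. A minimal sketch of that round trip, using a hypothetical path, is:

# Hedged sketch of the custom_id round trip; the S3 path and page number are made up.
custom_id = f"{'s3://bucket/docs/example.pdf'}-{3}"            # as built in build_page_query
s3_path = custom_id[:custom_id.rindex("-")]                    # "s3://bucket/docs/example.pdf"
page_num = int(custom_id[custom_id.rindex("-") + 1:])          # 3
assert (s3_path, page_num) == ("s3://bucket/docs/example.pdf", 3)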
pdelfin/birrpipeline.py (new file, 687 lines)
@@ -0,0 +1,687 @@
import os
import hashlib
import boto3
import sqlite3
import json
import argparse
import glob
import tempfile
import datetime
import posixpath
import threading
import logging
import boto3.session
import urllib3.exceptions

from dataclasses import dataclass
from pypdf import PdfReader
from tqdm import tqdm
from functools import partial
from typing import Optional, List, Tuple, Dict, Callable, Any
from urllib.parse import urlparse
from concurrent.futures import ProcessPoolExecutor, as_completed

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.s3_utils import parse_custom_id, expand_s3_glob, get_s3_bytes, put_s3_bytes


# Global s3 client for the whole script, feel free to adjust params if you need it
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

# Quiet logs from pypdf and smart open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)

class DatabaseManager:
    @dataclass(frozen=True)
    class BatchInferenceRecord:
        inference_s3_path: str
        pdf_s3_path: str
        page_num: int  # 1 indexed!
        round: int
        start_index: int
        length: int
        finish_reason: str
        error: Optional[str]

        def is_usable(self):
            return self.error is None and self.finish_reason == "stop"

    @dataclass(frozen=True)
    class PDFRecord:
        s3_path: str
        num_pages: int
        status: str

    def __init__(self, s3_workspace: str):
        cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
        home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
        os.makedirs(home_cache_dir, exist_ok=True)
        self.db_path = os.path.join(home_cache_dir, 'index.db')

        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self._initialize_tables()

    def _initialize_tables(self):
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS page_results (
                inference_s3_path TEXT,
                pdf_s3_path TEXT,
                page_num INTEGER,
                round INTEGER,
                start_index BIGINT,
                length BIGINT,
                finish_reason TEXT,
                error TEXT
            )
        """)
        self.cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_path ON page_results(pdf_s3_path)
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS pdfs (
                s3_path TEXT PRIMARY KEY,
                num_pages INTEGER,
                status TEXT DEFAULT 'pending'
            )
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS processed_files (
                s3_path TEXT PRIMARY KEY,
                etag TEXT
            )
        """)
        # Generic metadata such as current round
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        """)

        self.conn.commit()

    def get_metadata(self, key: str) -> Optional[str]:
        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
        result = self.cursor.fetchone()
        return result[0] if result else None

    def set_metadata(self, key: str, value: str) -> None:
        self.cursor.execute("""
            INSERT INTO metadata (key, value)
            VALUES (?, ?)
            ON CONFLICT(key) DO UPDATE SET value=excluded.value
        """, (key, value))
        self.conn.commit()

    def is_file_processed(self, s3_path, etag):
        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
        result = self.cursor.fetchone()
        return result is not None and result[0] == etag

    def update_processed_file(self, s3_path, etag):
        self.cursor.execute("""
            INSERT INTO processed_files (s3_path, etag)
            VALUES (?, ?)
            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
        """, (s3_path, etag))
        self.conn.commit()

    def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
        if index_entries:
            self.cursor.executemany("""
                INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.round, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
            self.conn.commit()

    def get_index_entries(self, pdf_s3_path: str) -> List[BatchInferenceRecord]:
        self.cursor.execute("""
            SELECT inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error
            FROM page_results
            WHERE pdf_s3_path = ?
            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
        """, (pdf_s3_path,))

        rows = self.cursor.fetchall()

        return [
            self.BatchInferenceRecord(
                inference_s3_path=row[0],
                pdf_s3_path=row[1],
                page_num=row[2],
                round=row[3],
                start_index=row[4],
                length=row[5],
                finish_reason=row[6],
                error=row[7]
            )
            for row in rows
        ]

    def get_last_indexed_round(self) -> int:
        self.cursor.execute("""
            SELECT MAX(round)
            FROM page_results
        """)

        result = self.cursor.fetchone()
        return -1 if result[0] is None else result[0]

    def pdf_exists(self, s3_path: str) -> bool:
        self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
        return self.cursor.fetchone() is not None

    def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None:
        try:
            self.cursor.execute("""
                INSERT INTO pdfs (s3_path, num_pages, status)
                VALUES (?, ?, ?)
            """, (s3_path, num_pages, status))
            self.conn.commit()
        except sqlite3.IntegrityError:
            print(f"PDF with s3_path '{s3_path}' already exists.")

    def update_pdf_status(self, s3_path: str, new_status: str) -> None:
        self.cursor.execute("""
            UPDATE pdfs
            SET status = ?
            WHERE s3_path = ?
        """, (new_status, s3_path))
        self.conn.commit()

    def get_pdf(self, s3_path: str) -> Optional[PDFRecord]:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status
            FROM pdfs
            WHERE s3_path = ?
        """, (s3_path,))

        row = self.cursor.fetchone()

        if row:
            return self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2]
            )
        return None

    def get_pdfs_by_status(self, status: str) -> List[PDFRecord]:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status
            FROM pdfs
            WHERE status == ?
            ORDER BY s3_path DESC, num_pages DESC
        """, (status, ))

        rows = self.cursor.fetchall()

        return [
            self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2]
            )
            for row in rows
        ]

    def close(self):
        self.conn.close()


# Writes batches of lines out to a set of files, keeping each file below some maximum size
class BatchWriter:
    def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
        self.output_prefix = output_prefix
        self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
        self.batch = []
        self.batch_size = 0
        self.after_flush = after_flush
        self.threads = []

        parsed = urlparse(output_prefix)
        self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')

        if not self.is_s3:
            os.makedirs(output_prefix, exist_ok=True)

    def _compute_hash(self, content: str) -> str:
        """Compute a 20-character SHA1 hash of the given content."""
        sha1 = hashlib.sha1()
        sha1.update(content.encode('utf-8'))
        return sha1.hexdigest()[:20]

    def _get_output_path(self, hash_str: str) -> str:
        """Generate the full output path with hash in the filename."""
        parsed = urlparse(self.output_prefix)
        if self.is_s3:
            bucket = parsed.netloc
            key = parsed.path.lstrip('/')
            if key and not key.endswith('/'):
                key += '/'
            full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
            return f"s3://{bucket}/{full_key}"
        else:
            filename = f"output_{hash_str}.jsonl"
            return os.path.join(self.output_prefix, filename)

    def write_line(self, line: Optional[str]):
        if line is None or not line.strip():
            return

        line_size = len(line.encode('utf-8')) + 1  # +1 for newline

        if self.batch_size + line_size > self.max_size:
            self._write_batch()

        self.batch.append(line)
        self.batch_size += line_size

    def _write_batch(self):
        if not self.batch:
            return

        batch_lines = self.batch.copy()
        batch_content = "\n".join(batch_lines) + "\n"
        hash_str = self._compute_hash(batch_content)
        output_path = self._get_output_path(hash_str)

        # Start a new thread to write the batch
        thread = threading.Thread(
            target=self._write_batch_to_file,
            args=(batch_content, output_path, batch_lines)
        )
        thread.start()
        self.threads.append(thread)

        # Clear the batch and batch_size
        self.batch = []
        self.batch_size = 0

    def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
        if self.is_s3:
            put_s3_bytes(workspace_s3, output_path, batch_content.encode("utf-8"))
        else:
            with open(output_path, 'w', encoding='utf-8') as f_out:
                f_out.write(batch_content)

        # After writing, call the after_flush callback if it is set
        if self.after_flush:
            self.after_flush(batch_lines)

    def close(self):
        self._write_batch()
        # Wait for all threads to finish
        for thread in self.threads:
            thread.join()


def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "custom_id": f"{pretty_pdf_path}-{page}",
        "chat_messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
                ],
            }
        ],
    }


def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
    content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)

    start_index = 0
    index_entries = []
    lines = content_bytes.splitlines(keepends=True)  # Split content into lines as bytes
    for line in lines:
        line_length = len(line)  # Length in bytes

        try:
            # Decode the line for JSON processing
            line_str = line.decode('utf-8')
            data = json.loads(line_str)
            pdf_s3_path, page_num = parse_custom_id(data["custom_id"])

            if data.get("completion_error", None) is not None:
                index_entries.append(DatabaseManager.BatchInferenceRecord(
                    inference_s3_path=inference_s3_path,
                    pdf_s3_path=pdf_s3_path,
                    page_num=page_num,
                    round=data["round"],
                    start_index=start_index,  # Byte offset in the original file
                    length=line_length,  # Length in bytes
                    finish_reason="completion_error",
                    error=data.get("completion_error", None)
                ))
            else:
                # Try to parse the actual model response JSON
                assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"

                try:
                    model_response_json = json.loads(data["outputs"][0]["text"])

                    index_entries.append(DatabaseManager.BatchInferenceRecord(
                        inference_s3_path=inference_s3_path,
                        pdf_s3_path=pdf_s3_path,
                        page_num=page_num,
                        round=data["round"],
                        start_index=start_index,  # Byte offset in the original file
                        length=line_length,  # Length in bytes
                        finish_reason=data["outputs"][0]["finish_reason"],
                        error=data.get("completion_error", None)
                    ))
                except json.JSONDecodeError:
                    index_entries.append(DatabaseManager.BatchInferenceRecord(
                        inference_s3_path=inference_s3_path,
                        pdf_s3_path=pdf_s3_path,
                        page_num=page_num,
                        round=data["round"],
                        start_index=start_index,  # Byte offset in the original file
                        length=line_length,  # Length in bytes
                        finish_reason=data["outputs"][0]["finish_reason"],
                        error="Could not parse model JSON output",
                    ))

        except json.JSONDecodeError:
            print(f"Error with JSON Decoding of inference in {inference_s3_path}")
            # TODO Maybe this needs to add an index error that this json is bad
        except Exception as e:
            print(f"Error processing line: {e}")

        start_index += line_length  # Increment by the number of bytes

    return index_entries


def get_pdf_num_pages(s3_path: str) -> Optional[int]:
    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(pdf_s3, s3_path))
            tf.flush()

            reader = PdfReader(tf.name)
            return reader.get_num_pages()
    except Exception as ex:
        print(f"Warning, could not add {s3_path} due to {ex}")

    return None

def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_round: int) -> list[dict]:
    db = DatabaseManager(s3_workspace)

    existing_pages = db.get_index_entries(pdf.s3_path)
    new_queries = []

    # Shortcut out of downloading the actual PDF
    if set(page.page_num for page in existing_pages if page.is_usable()) == set(range(1, pdf.num_pages + 1)):
        return []

    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
            tf.flush()

            for target_page_num in range(1, pdf.num_pages + 1):
                # Is there an existing page that has no error
                if any(page.is_usable() and page.page_num == target_page_num for page in existing_pages):
                    continue

                has_errored_previously = sum(page.page_num == target_page_num for page in existing_pages)

                if has_errored_previously:
                    # Retry the page up to 3 times
                    for _ in range(3):
                        new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num), "round": cur_round})

                    # Optionally, you can implement more complex retry logic here
                else:
                    new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num), "round": cur_round})
    except Exception as ex:
        print(f"Warning, could not get batch inferences lines for {pdf.s3_path} due to {ex}")

    return new_queries

def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Optional[dict]:
    db = DatabaseManager(s3_workspace)
    existing_pages = db.get_index_entries(pdf.s3_path)
    document_text = ""
    last_page_start_index = 0
    pdf_page_spans = []

    for target_page_num in range(1, pdf.num_pages + 1):
        target_pages = [page for page in existing_pages if page.is_usable() and page.page_num == target_page_num]

        if len(target_pages) == 0:
            return None

        target_page = target_pages[0]

        target_row = get_s3_bytes(workspace_s3, target_page.inference_s3_path,
                                  start_index=target_page.start_index,
                                  end_index=target_page.start_index + target_page.length - 1)

        target_data = json.loads(target_row.decode("utf-8"))

        target_output = json.loads(target_data["outputs"][0]["text"])

        if target_output["natural_text"] is not None:
            document_text += target_output["natural_text"] + "\n"

        pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
        last_page_start_index = len(document_text)

    metadata = {
        "Source-File": pdf.s3_path,
        "pdf-total-pages": pdf.num_pages,
    }
    id_ = hashlib.sha1(document_text.encode()).hexdigest()

    dolma_doc = {
        "id": id_,
        "text": document_text,
        "source": "pdelfin",
        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
        "metadata": metadata,
        "attributes": {
            "pdf_page_numbers": pdf_page_spans
        }
    }

    return dolma_doc

def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
    db = DatabaseManager(s3_workspace)

    for line in dolma_doc_lines:
        db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")

def get_current_round(s3_workspace: str) -> int:
    path = s3_workspace[5:]
    bucket, _, prefix = path.partition('/')

    inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')
    paginator = workspace_s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')

    round_numbers = []
    for page in page_iterator:
        for common_prefix in page.get('CommonPrefixes', []):
            round_prefix = common_prefix.get('Prefix')
            # Extract 'round_X' from the prefix
            round_dir = posixpath.basename(posixpath.dirname(round_prefix))
            if round_dir.startswith('round_'):
                try:
                    round_num = int(round_dir[len('round_'):])
                    round_numbers.append(round_num)
                except ValueError:
                    pass
    if round_numbers:
        current_round = max(round_numbers) + 1
    else:
        current_round = 0
    return current_round


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
    parser.add_argument('--add_pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
    args = parser.parse_args()

    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

    db = DatabaseManager(args.workspace)
    print(f"Loaded db at {db.db_path}")

    current_round = get_current_round(args.workspace)
    print(f"Current round is {current_round}\n")

    # One shared executor to rule them all
    executor = ProcessPoolExecutor()

    # If you have new PDFs, step one is to add them to the list
    if args.add_pdfs:
        if args.add_pdfs.startswith("s3://"):
            print(f"Querying all PDFs at {args.add_pdfs}")

            all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
            print(f"Found {len(all_pdfs):,} total pdf paths")
        elif os.path.exists(args.add_pdfs):
            with open(args.add_pdfs, "r") as f:
                all_pdfs = [line.strip() for line in f.readlines() if len(line.strip()) > 0]
        else:
            raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")

        all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
        print(f"Need to import {len(all_pdfs):,} total new pdf paths")

        future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}
        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            s3_path = future_to_path[future]
            num_pages = future.result()
            if num_pages and not db.pdf_exists(s3_path):
                db.add_pdf(s3_path, num_pages, "pending")

        print("\n")

    # Now build an index of all the pages that were processed within the workspace so far
    print("Indexing all batch inference sent to this workspace")
    inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")

    inference_output_paths = [
        (s3_path, etag) for s3_path, etag in inference_output_paths.items()
        if not db.is_file_processed(s3_path, etag)
    ]

    print(f"Found {len(inference_output_paths):,} new batch inference results to index")
    future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        s3_path, etag = future_to_path[future]

        try:
            inference_records = future.result()

            db.add_index_entries(inference_records)
            db.update_processed_file(s3_path, etag=etag)
        except urllib3.exceptions.SSLError:
            print(f"Cannot load inference file {s3_path} due to SSL error, will retry another time")


    # Now query each pdf, if you have all of the pages needed (all pages present, error is null and finish_reason is stop), then you assemble it into a dolma document and output it
    # If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use
    if db.get_last_indexed_round() < current_round - 1:
        print(f"WARNING: No new batch inference results found, you need to run batch inference on {args.workspace}/inference_inputs/round_{current_round - 1}")
        potentially_done_pdfs = db.get_pdfs_by_status("pending")
    else:
        print(f"\nCreating batch inference files for new PDFs")
        future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf, current_round): pdf for pdf in db.get_pdfs_by_status("pending")}
        potentially_done_pdfs = []
        lines_written = 0
        new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)

        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            pdf = future_to_path[future]
            inference_lines = future.result()

            if len(inference_lines) == 0:
                potentially_done_pdfs.append(pdf)

            for line in inference_lines:
                lines_written += 1

                if line is not None:
                    new_inference_writer.write_line(json.dumps(line))

        new_inference_writer.close()

        if lines_written > 0:
            print(f"Added {lines_written:,} new batch inference requests")

    # Now, finally, assemble any potentially done docs into dolma documents
    print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
    future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb, after_flush=partial(mark_pdfs_done, args.workspace))

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        pdf = future_to_path[future]
        dolma_doc = future.result()

        if dolma_doc is not None:
            new_output_writer.write_line(json.dumps(dolma_doc))

    new_output_writer.close()

    print("\nFinal statistics:")

    # Output the number of PDFs in each status "pending" and "completed"
    pending_pdfs = db.get_pdfs_by_status("pending")
    completed_pdfs = db.get_pdfs_by_status("completed")

    print(f"Pending PDFs: {len(pending_pdfs):,} ({sum(doc.num_pages for doc in pending_pdfs):,} pages)")
    print(f"Completed PDFs: {len(completed_pdfs):,} ({sum(doc.num_pages for doc in completed_pdfs):,} pages)")

    # For each round, outputs a report of how many pages were processed, how many had errors, and a breakdown by (error, finish_reason)
    total_rounds = db.get_last_indexed_round() + 1
    for round_num in range(total_rounds):
        db.cursor.execute("""
            SELECT COUNT(*), error, finish_reason
            FROM page_results
            WHERE round = ?
            GROUP BY error, finish_reason
        """, (round_num,))

        results = db.cursor.fetchall()

        total_pages = sum(count for count, _, _ in results)
        print(f"\nInference Round {round_num} - {total_pages:,} pages processed:")

        for count, error, finish_reason in results:
            error_str = error if error is not None else "None"
            print(f"  (error: {error_str}, finish_reason: {finish_reason}) -> {count:,} pages")

    print("\nWork finished, waiting for all workers to finish cleaning up")
    executor.shutdown(wait=True)
    db.close()
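For orientation, `process_jsonl_content` above only reads a handful of fields from each batch inference output line. A hedged sketch of what one such line appears to look like, with the field names inferred from those reads and all values made up (the real schema comes from the mise/birr batch runner and is not shown in this diff):

# Hedged sketch of one inference_outputs JSONL line as the indexer appears to expect it.
example_line = {
    "custom_id": "s3://bucket/docs/example.pdf-3",          # pdf s3 path + 1-indexed page
    "round": 0,                                              # inference round that produced it
    "outputs": [
        {"text": "{\"natural_text\": \"Page text...\"}",    # model JSON, read later by build_dolma_doc
         "finish_reason": "stop"}
    ],
    "completion_error": None,                                # non-None marks the page as errored
}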
@@ -2,14 +2,42 @@ import subprocess
 import base64
 import io
-from pypdf import PdfReader

 from PIL import Image


-def render_pdf_to_base64png(local_pdf_path: str, page: int, target_longest_image_dim: int=2048):
-    pdf = PdfReader(local_pdf_path)
-    pdf_page = pdf.pages[page - 1]
-    longest_dim = max(pdf_page.mediabox.width, pdf_page.mediabox.height)
+def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[float, float]:
+    """
+    Get the MediaBox dimensions for a specific page in a PDF file using the pdfinfo command.
+
+    :param pdf_file: Path to the PDF file
+    :param page_num: The page number for which to extract MediaBox dimensions
+    :return: A dictionary containing MediaBox dimensions or None if not found
+    """
+    # Construct the pdfinfo command to extract info for the specific page
+    command = ['pdfinfo', '-f', str(page_num), '-l', str(page_num), '-box', local_pdf_path]
+
+    # Run the command using subprocess
+    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+
+    # Check if there is any error in executing the command
+    if result.returncode != 0:
+        raise ValueError(f"Error running pdfinfo: {result.stderr}")
+
+    # Parse the output to find MediaBox
+    output = result.stdout
+    media_box = None
+
+    for line in output.splitlines():
+        if 'MediaBox' in line:
+            media_box = line.split(':')[1].strip().split()
+            media_box = [float(x) for x in media_box]
+            return abs(media_box[0] - media_box[2]), abs(media_box[3] - media_box[1])
+
+    raise ValueError("MediaBox not found in the PDF info.")
+
+
+def render_pdf_to_base64png(local_pdf_path: str, page_num: int, target_longest_image_dim: int=2048):
+    longest_dim = max(get_pdf_media_box_width_height(local_pdf_path, page_num))

     # Convert PDF page to PNG using pdftoppm
     pdftoppm_result = subprocess.run(
@@ -17,9 +45,9 @@ def render_pdf_to_base64png(local_pdf_path: str, page: int, target_longest_image
         "pdftoppm",
         "-png",
         "-f",
-        str(page),
+        str(page_num),
         "-l",
-        str(page),
+        str(page_num),
         "-r",
         str(target_longest_image_dim * 72 / longest_dim),  # 72 pixels per point is the conversion factor
         local_pdf_path,
@@ -35,8 +63,67 @@ def render_pdf_to_base64png(local_pdf_path: str, page: int, target_longest_image
 def render_pdf_to_base64webp(local_pdf_path: str, page: int, target_longest_image_dim: int=1024):
     base64_png = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim)

-    png_image = Image.open(io.BytesIO(base64_png.encode("utf-8")))
+    png_image = Image.open(io.BytesIO(base64.b64decode(base64_png)))
     webp_output = io.BytesIO()
     png_image.save(webp_output, format="WEBP")

-    return base64.b64encode(webp_output.getvalue()).decode("utf-8")
+    return base64.b64encode(webp_output.getvalue()).decode("utf-8")
+
+
+def get_png_dimensions_from_base64(base64_data) -> tuple[int, int]:
+    """
+    Returns the (width, height) of a PNG image given its base64-encoded data,
+    without base64-decoding the entire data or loading the PNG itself
+
+    Should be really fast to support filtering
+
+    Parameters:
+    - base64_data (str): Base64-encoded PNG image data.
+
+    Returns:
+    - tuple: (width, height) of the image.
+
+    Raises:
+    - ValueError: If the data is not a valid PNG image or the required bytes are not found.
+    """
+    # PNG signature is 8 bytes
+    png_signature_base64 = base64.b64encode(b'\x89PNG\r\n\x1a\n').decode('ascii')
+    if not base64_data.startswith(png_signature_base64[:8]):
+        raise ValueError('Not a valid PNG file')
+
+    # Positions in the binary data where width and height are stored
+    width_start = 16   # Byte position where width starts (0-based indexing)
+    width_end = 20     # Byte position where width ends (exclusive)
+    height_start = 20
+    height_end = 24
+
+    # Compute the byte range needed (from width_start to height_end)
+    start_byte = width_start
+    end_byte = height_end
+
+    # Calculate base64 character positions
+    # Each group of 3 bytes corresponds to 4 base64 characters
+    base64_start = (start_byte // 3) * 4
+    base64_end = ((end_byte + 2) // 3) * 4  # Add 2 to ensure we cover partial groups
+
+    # Extract the necessary base64 substring
+    base64_substring = base64_data[base64_start:base64_end]
+
+    # Decode only the necessary bytes
+    decoded_bytes = base64.b64decode(base64_substring)
+
+    # Compute the offset within the decoded bytes
+    offset = start_byte % 3
+
+    # Extract width and height bytes
+    width_bytes = decoded_bytes[offset:offset+4]
+    height_bytes = decoded_bytes[offset+4:offset+8]
+
+    if len(width_bytes) < 4 or len(height_bytes) < 4:
+        raise ValueError('Insufficient data to extract dimensions')
+
+    # Convert bytes to integers
+    width = int.from_bytes(width_bytes, 'big')
+    height = int.from_bytes(height_bytes, 'big')
+
+    return width, height
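A quick worked example of the `-r` (DPI) value passed to pdftoppm above: the renderer picks a resolution so that the longest side of the page lands near target_longest_image_dim pixels. The page size here is an assumed US-letter page, not something from the diff:

# Worked example of the pdftoppm -r computation, assuming a 612x792 pt (US-letter) page.
target_longest_image_dim = 1024
longest_dim = max(612, 792)                          # 792 points on the longest side
dpi = target_longest_image_dim * 72 / longest_dim    # ~93.1 dpi
# 792 pt / 72 pt-per-inch = 11 in; 11 in * ~93.1 dpi ~= 1024 px on the longest side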
@@ -10,9 +10,11 @@
 # coherency score best of these three
 import subprocess
 import re
+import random
 import ftfy
 from dataclasses import dataclass
 from typing import Literal, List
+from functools import lru_cache

 import pypdfium2 as pdfium
 import pymupdf
@@ -24,7 +26,7 @@ from pypdf.generic import RectangleObject
 from pdelfin.prompts._adv_anchor import mult


-def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"]) -> str:
+def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"], target_length: int=4000) -> str:
     assert page > 0, "Pages are 1-indexed in pdf-land"

     if pdf_engine == "pdftotext":
@@ -52,7 +54,7 @@ def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftote

         return best_option
     elif pdf_engine == "pdfreport":
-        return _linearize_pdf_report(_pdf_report(local_pdf_path, page))
+        return _linearize_pdf_report(_pdf_report(local_pdf_path, page), max_length=target_length)
     else:
         raise NotImplementedError("Unknown engine")

@@ -119,10 +121,14 @@ class PageReport:
     text_elements: List[TextElement]
     image_elements: List[ImageElement]

+@lru_cache(maxsize=5)
+def _get_cached_pdf_reader(local_pdf_path: str) -> PdfReader:
+    # Cached, because you are going to often iterate through a whole pdf, so this will make it a lot faster on subsequent iterations
+    return PdfReader(local_pdf_path)

-def _pdf_report(local_pdf_path: str, page: int) -> PageReport:
-    reader = PdfReader(local_pdf_path)
-    page = reader.pages[page - 1]
+def _pdf_report(local_pdf_path: str, page_num: int) -> PageReport:
+    reader = _get_cached_pdf_reader(local_pdf_path)
+    page = reader.pages[page_num - 1]
     resources = page.get("/Resources", {})
     xobjects = resources.get("/XObject", {})
     text_elements, image_elements = [], []
@@ -330,7 +336,10 @@ def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
     ]

     # Sort remaining elements by their positions (e.g., x-coordinate and then y-coordinate)
-    remaining_elements.sort(key=lambda x: (x[3][0], x[3][1]))
+    # remaining_elements.sort(key=lambda x: (x[3][0], x[3][1]))
+
+    # Shuffle remaining elements randomly
+    random.shuffle(remaining_elements)

     # Add elements until reaching max_length
     for elem_type, elem, s, position in remaining_elements:

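A hedged usage sketch of the new target_length parameter introduced above; judging from this hunk, only the "pdfreport" engine forwards it (as max_length to _linearize_pdf_report), and the path below is hypothetical:

# Hypothetical call; only the "pdfreport" engine appears to use target_length.
anchor_text = get_anchor_text("example.pdf", page=1, pdf_engine="pdfreport", target_length=4000)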
@ -1,275 +0,0 @@
|
||||
# The way this script works is it gets a list of pdfs to process
|
||||
# and an output/scratch folder location either locally or in s3 to work with
|
||||
# On the first run, with an empty output folder, it will queue up each page in each pdf to go into a VLM
|
||||
# Then, the user queues up that task in birr, and it outputs to a new subfolder in the same location
|
||||
# Then, you run your script again, and it will see that you have some valid output files
|
||||
# If so, then it will check those output files, and if it has a complete document, it will build a dolma doc for it, and that's considered done
|
||||
# For any remaining pages that got errored out, or failed due to stop_reason not being "stop" (ex. over length)
|
||||
# Then, it will queue up another set of tasks, hopefully much smaller, to send into batch inference again
|
||||
# This process will keep going, until you run it with the --fallback option, at which point it will
|
||||
# just use a basic text extraction on any remaining pages, and assemble the rest of the dolma docs
|
||||
#
|
||||
#
|
||||
#
|
||||
import os
|
||||
import glob
|
||||
import random
|
||||
import argparse
|
||||
import boto3
|
||||
import json
|
||||
import hashlib
|
||||
from pypdf import PdfReader
|
||||
from tqdm import tqdm
|
||||
from typing import Generator, List
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pdelfin.data.renderpdf import render_pdf_to_base64png
|
||||
from pdelfin.prompts import build_finetuning_prompt
|
||||
from pdelfin.prompts.anchor import get_anchor_text
|
||||
from pdelfin.filter import PdfFilter
|
||||
|
||||
import logging
|
||||
import smart_open
|
||||
import posixpath # Import posixpath for S3 path handling
|
||||
|
||||
logging.getLogger("pypdf").setLevel(logging.ERROR)
|
||||
|
||||
pdf_filter = PdfFilter()
|
||||
|
||||
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
|
||||
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
|
||||
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
|
||||
|
||||
return {
|
||||
"custom_id": f"{pretty_pdf_path}-{page}",
|
||||
"chat_messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": build_finetuning_prompt(anchor_text)},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
|
||||
],
|
||||
}
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 6000,
|
||||
}
|
||||
|
||||
def fetch_s3_file(s3_url: str, local_path: str) -> str:
|
||||
parsed = urlparse(s3_url)
|
||||
bucket_name = parsed.netloc
|
||||
key = parsed.path.lstrip('/')
|
||||
|
||||
s3 = boto3.client('s3')
|
||||
s3.download_file(bucket_name, key, local_path)
|
||||
return local_path
|
||||
|
||||
def process_pdf(pdf_path: str, no_filter: bool) -> List[dict]:
|
||||
if pdf_path.startswith("s3://"):
|
||||
local_pdf_path = os.path.join("/tmp", os.path.basename(pdf_path))
|
||||
fetch_s3_file(pdf_path, local_pdf_path)
|
||||
else:
|
||||
local_pdf_path = pdf_path
|
||||
|
||||
if (not no_filter) and pdf_filter.filter_out_pdf(local_pdf_path):
|
||||
print(f"Skipping {local_pdf_path} due to common filter")
|
||||
return []
|
||||
|
||||
pretty_pdf_path = pdf_path
|
||||
|
||||
pdf = PdfReader(local_pdf_path)
|
||||
num_pages = len(pdf.pages)
|
||||
|
||||
sample_pages = list(range(1, num_pages + 1))
|
||||
result = []
|
||||
for page in sample_pages:
|
||||
try:
|
||||
query = build_page_query(local_pdf_path, pretty_pdf_path, page)
|
||||
result.append(query)
|
||||
except Exception as e:
|
||||
print(f"Error processing page {page} of {pdf_path}: {e}")
|
||||
|
||||
return result
|
||||
|
||||
def is_glob_pattern(path: str) -> bool:
|
||||
return any(char in path for char in ['*', '?', '[', ']'])
|
||||
|
||||
def expand_s3_glob(s3_glob: str) -> list:
|
||||
parsed = urlparse(s3_glob)
|
||||
bucket_name = parsed.netloc
|
||||
prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
|
||||
pattern = os.path.basename(parsed.path)
|
||||
|
||||
s3 = boto3.client('s3')
|
||||
paginator = s3.get_paginator('list_objects_v2')
|
||||
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
|
||||
|
||||
matched_files = []
|
||||
for page in page_iterator:
|
||||
for obj in page.get('Contents', []):
|
||||
key = obj['Key']
|
||||
if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
|
||||
matched_files.append(f"s3://{bucket_name}/{key}")
|
||||
|
||||
return matched_files
|
||||
|
||||
def compute_hash(content: str) -> str:
|
||||
"""Compute a 20-character SHA1 hash of the given content."""
|
||||
sha1 = hashlib.sha1()
|
||||
sha1.update(content.encode('utf-8'))
|
||||
return sha1.hexdigest()[:20]
|
||||
|
||||
def get_smart_open_write_path(output_path: str, hash_str: str) -> str:
|
||||
"""Generate the full output path with hash in the filename."""
|
||||
parsed = urlparse(output_path)
|
||||
if parsed.scheme in ('s3', 's3a', 's3n'):
|
||||
bucket = parsed.netloc
|
||||
key = parsed.path.lstrip('/')
|
||||
# Ensure the key is treated as a directory by appending a slash if not present
|
||||
if key and not key.endswith('/'):
|
||||
key += '/'
|
||||
# Use posixpath to correctly join S3 paths
|
||||
full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
|
||||
return f"s3://{bucket}/{full_key}"
|
||||
else:
|
||||
dir_path = output_path
|
||||
filename = f"output_{hash_str}.jsonl"
|
||||
return os.path.join(dir_path, filename)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Given a bunch of PDFs, prepares a mise/birr workflow to run them through a conversion mechanism"
|
||||
)
|
||||
parser.add_argument(
|
||||
"pdf_paths",
|
||||
nargs='*',
|
||||
help=(
|
||||
"List of PDF paths to process. If a single argument contains glob patterns (e.g., *.pdf or s3://bucket/pdfs/*.pdf), "
|
||||
"it will be expanded accordingly."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--path_list",
|
||||
type=str,
|
||||
help="Path to a file containing paths to PDFs, one per line."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_size_mb",
|
||||
type=int,
|
||||
default=250,
|
||||
help="Max number of MBs of entries to put in each birr workitem"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_filter",
|
||||
action="store_true",
|
||||
help="Disables the basic spam/language filtering so that ALL pdfs listed are used"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="mise_batch_data",
|
||||
help="Output destination (can be a local path or an S3 URI)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
pdf_paths = []
|
||||
|
||||
# Load PDF paths from positional arguments or path_list
|
||||
if args.pdf_paths:
|
||||
for path in args.pdf_paths:
|
||||
if is_glob_pattern(path):
|
||||
glob_path = path
|
||||
if glob_path.startswith("s3://"):
|
||||
# Handle S3 globbing
|
||||
expanded_paths = expand_s3_glob(glob_path)
|
||||
pdf_paths.extend(expanded_paths)
|
||||
else:
|
||||
# Handle local filesystem globbing
|
||||
expanded_paths = glob.glob(glob_path, recursive=True)
|
||||
pdf_paths.extend(expanded_paths)
|
||||
else:
|
||||
pdf_paths.append(path)
|
||||
|
||||
if args.path_list:
|
||||
with open(args.path_list, 'r') as f:
|
||||
for line in f:
|
||||
path = line.strip()
|
||||
if path:
|
||||
pdf_paths.append(path)
|
||||
|
||||
# Remove duplicates and shuffle
|
||||
pdf_paths = list(set(pdf_paths))
|
||||
random.shuffle(pdf_paths)
|
||||
|
||||
print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")
|
||||
|
||||
# Prepare for output
|
||||
output_dir = args.output
|
||||
max_file_size = args.max_size_mb * 1024 * 1024 # Convert MB to bytes
|
||||
|
||||
# Determine if output is S3
|
||||
parsed_output = urlparse(output_dir)
|
||||
is_s3 = parsed_output.scheme in ('s3', 's3a', 's3n')
|
||||
|
||||
# Initialize variables for batching
|
||||
batch = []
|
||||
batch_size = 0
|
||||
pdfs_with_output = 0
|
||||
|
||||
# Function to write a batch
|
||||
def write_batch(batch: List[dict]):
|
||||
nonlocal output_dir
|
||||
if not batch:
|
||||
return
|
||||
batch_content = "\n".join(json.dumps(entry) for entry in batch) + "\n"
|
||||
hash_str = compute_hash(batch_content)
|
||||
output_path_with_hash = get_smart_open_write_path(output_dir, hash_str)
|
||||
with smart_open.open(output_path_with_hash, 'w') as f_out:
|
||||
f_out.write(batch_content)
|
||||
print(f"Wrote batch to {output_path_with_hash}")
|
||||
|
||||
# Using ProcessPoolExecutor to process files concurrently
|
||||
with ProcessPoolExecutor() as executor:
|
||||
futures = []
|
||||
|
||||
with tqdm(desc="Processing PDFs", leave=False, total=len(pdf_paths)) as pb:
|
||||
for pdf_path in pdf_paths:
|
||||
futures.append(executor.submit(process_pdf, pdf_path, args.no_filter))
|
||||
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
request_results = future.result() # Get the result from the process
|
||||
|
||||
if request_results:
|
||||
pdfs_with_output += 1 # Increment if there's at least one result
|
||||
|
||||
for request_obj in request_results:
|
||||
request_json = json.dumps(request_obj)
|
||||
request_size = len(request_json.encode('utf-8')) + 1 # +1 for newline
|
||||
|
||||
# Check if adding this entry would exceed the max size
|
||||
if batch_size + request_size > max_file_size:
|
||||
# Write the current batch
|
||||
write_batch(batch)
|
||||
# Reset the batch
|
||||
batch = []
|
||||
batch_size = 0
|
||||
|
||||
# Add the entry to the batch
|
||||
batch.append(request_obj)
|
||||
batch_size += request_size
|
||||
|
||||
pb.update(1)
|
||||
except Exception as e:
|
||||
print(f"Error processing a PDF: {str(e)}")
|
||||
|
||||
# Write any remaining batch
|
||||
write_batch(batch)
|
||||
|
||||
# Print the number of PDFs that resulted in at least one output
|
||||
print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")
|
||||
print(f"Now you should run these prompts through mise/birr")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
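For orientation, a hypothetical invocation of this script; only --max_size_mb, --no_filter, and --output are confirmed by the argparse calls above, while the script name and the glob argument are placeholders:

    python build_batch_data.py "s3://my-bucket/pdfs/*.pdf" --max_size_mb 250 --output s3://my-bucket/mise_batch_data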
75
pdelfin/s3_utils.py
Normal file
@ -0,0 +1,75 @@
import os
import glob
import posixpath

from typing import Optional
from urllib.parse import urlparse


def parse_s3_path(s3_path: str) -> tuple[str, str]:
    if not s3_path.startswith('s3://'):
        raise ValueError('s3_path must start with s3://')
    parsed = urlparse(s3_path)
    bucket = parsed.netloc
    key = parsed.path.lstrip('/')

    return bucket, key


def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:
    parsed = urlparse(s3_glob)
    bucket_name = parsed.netloc
    prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
    pattern = os.path.basename(parsed.path)

    paginator = s3_client.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    matched_files = {}
    for page in page_iterator:
        for obj in page.get('Contents', []):
            key = obj['Key']
            if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                matched_files[f"s3://{bucket_name}/{key}"] = obj['ETag'].strip('"')

    return matched_files

def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
    bucket, key = parse_s3_path(s3_path)

    # Build the range header if start_index and/or end_index are specified
    range_header = None
    if start_index is not None and end_index is not None:
        # Range: bytes=start_index-end_index
        range_value = f"bytes={start_index}-{end_index}"
        range_header = {'Range': range_value}
    elif start_index is not None and end_index is None:
        # Range: bytes=start_index-
        range_value = f"bytes={start_index}-"
        range_header = {'Range': range_value}
    elif start_index is None and end_index is not None:
        # Range: bytes=-end_index (last end_index bytes)
        range_value = f"bytes=-{end_index}"
        range_header = {'Range': range_value}

    if range_header:
        obj = s3_client.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
    else:
        obj = s3_client.get_object(Bucket=bucket, Key=key)

    return obj['Body'].read()

def put_s3_bytes(s3_client, s3_path: str, data: bytes):
    bucket, key = parse_s3_path(s3_path)

    s3_client.put_object(
        Bucket=bucket,
        Key=key,
        Body=data,
        ContentType='text/plain; charset=utf-8'
    )

def parse_custom_id(custom_id: str) -> tuple[str, int]:
    s3_path = custom_id[:custom_id.rindex("-")]
    page_num = int(custom_id[custom_id.rindex("-") + 1:])
    return s3_path, page_num
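For reference, a minimal usage sketch of the helpers in this new file; the bucket, key, and custom_id values are made up for illustration:

    import boto3
    from pdelfin.s3_utils import parse_s3_path, expand_s3_glob, get_s3_bytes, parse_custom_id

    s3 = boto3.client("s3")
    bucket, key = parse_s3_path("s3://my-bucket/pdfs/doc.pdf")    # ("my-bucket", "pdfs/doc.pdf")
    matched = expand_s3_glob(s3, "s3://my-bucket/pdfs/*.pdf")     # {s3_path: etag, ...}
    head = get_s3_bytes(s3, "s3://my-bucket/pdfs/doc.pdf", start_index=0, end_index=1023)  # first 1 KiB
    doc, page = parse_custom_id("s3://my-bucket/pdfs/doc.pdf-4")  # ("s3://my-bucket/pdfs/doc.pdf", 4)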
@ -7,43 +7,33 @@ wandb:
  project: pdelfin
  entity: ai2-llm

# TODO This is not used
format:
  instruction_template: "Original:"
  response_template: "Rewritten:"
  # Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30
  chat_template: |
    {% for message in messages %}
    {{'<|im_start|>' + message['role'] + '\n' + message['content']}}
    {% if loop.last %}
    {{ '<|im_end|>'}}
    {% else %}
    {{ '<|im_end|>\n' }}
    {% endif %}
    {% endfor %}

generate:
  max_length: 4096
  max_length: 8192

train_data:
  seed: 1337
  sources:
    # These tend to be really big, so it's only practical to host them as parquets on weka, otherwise you may OOM or just never finish dataloading
    - name: openai_batch_data_v5_1_train
      parquet_path: /data/jakep/pdfdata/openai_batch_data_v5_1_parquet/*.parquet
    - name: openai_batch_data_v5_1_train
      parquet_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_parquet/*.parquet
      response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
      target_longest_image_dim: 1024
      target_anchor_text_len: 6000
    - name: openai_batch_data_v5_1_iabooks_train
      response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
      target_longest_image_dim: 1024
      target_anchor_text_len: 6000

valid_data:
  metric_for_best_model: openai_batch_data_v5_1_eval_loss
  sources:
    # These tend to be small, so you can load from s3 it's no big deal
    - name: openai_batch_data_v5_1_eval
      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
      target_longest_image_dim: 1024
      target_anchor_text_len: 6000
    - name: openai_batch_data_v5_1_iabooks_eval
      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_eval/*.jsonl
      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
      target_longest_image_dim: 1024
      target_anchor_text_len: 6000

@ -22,28 +22,6 @@ class ModelConfig:
    model_revision: Optional[str] = field(help="The model revision to use for the model.", default=None)


@dataclass
class FormatConfig:
    """Configuration for formatting the text that is input to the model."""

    new_line_symbol: str = field(
        help="The symbol to use for new lines in the text; default is '\\n'.",
        default="\n",
    )
    system_message: Optional[str] = field(
        help="The system message to use for formatting the text; default is no system message.",
        default=None,
    )
    instruction_template: str = field(
        help="The template to use for formatting the input text", default="Original:"
    )
    response_template: str = field(help="The template to use for formatting the output text", default="Rewrite:")
    chat_template: Optional[str] = field(
        help="The template to use for formatting the chat text. If None, the default chat template will be used.",
        default=None,
    )


@dataclass
class GenerateConfig:
    max_length: int = field(help="The maximum length of the generated text", default=4096)
@ -75,9 +53,9 @@ class AwsConfig:
@dataclass
class SourceConfig:
    name: str = field(help="The name of the source")
    parquet_path: Optional[str] = field(help="The s3/glob path to a bunch of parquet files for a preprocessed dataset.", default=None)
    query_glob_path: Optional[str] = field(help="The s3 bucket pointing to the inputs sent to OpenAI to generate the silver data", default=None)
    response_glob_path: Optional[str] = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai", default=None)
    response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")
    target_longest_image_dim: int = field(help="Dimensions to render the pdf page image to")
    target_anchor_text_len: int = field(help="Maximum amount of anchor text (aka prompt hint)")


@dataclass
@ -141,7 +119,6 @@ class TrainConfig:
    lora: Optional[LoraConfig] = field(default=None, help="The LoRA configuration")
    aws: AwsConfig = field(default=AwsConfig(), help="Configuration for AWS S3")
    wandb: WandbConfig = field(default=WandbConfig(), help="Configuration for Weights and Biases")
    format: FormatConfig = field(default=FormatConfig(), help="Configuration for formatting the input/output text")
    train_data: DataConfig = field(default=DataConfig(), help="Configuration for the training data")
    valid_data: DataConfig = field(default=DataConfig(), help="Configuration for the validation data")
    generate: GenerateConfig = field(default=GenerateConfig(), help="Configuration for text generation")
@ -158,5 +135,4 @@ class DemoConfig:
    share: bool = field(default=False, help="Share the demo publicly.")

    model: ModelConfig = field(default=ModelConfig())
    format: FormatConfig = field(default=FormatConfig())
    generate: GenerateConfig = field(default=GenerateConfig())
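Read against the YAML above, a single validation source would map onto the updated SourceConfig roughly like this; a sketch that assumes the new form where response_glob_path is required and query_glob_path is dropped:

    src = SourceConfig(
        name="openai_batch_data_v5_1_eval",
        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
        target_longest_image_dim=1024,
        target_anchor_text_len=6000,
    )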
@ -5,21 +5,27 @@ import re
|
||||
import os
|
||||
import base64
|
||||
import glob
|
||||
import pypdf, pypdf.errors
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional
|
||||
from logging import Logger
|
||||
from filelock import FileLock
|
||||
|
||||
import boto3
|
||||
from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
|
||||
from .core.config import DataConfig, SourceConfig
|
||||
|
||||
from pdelfin.prompts.anchor import get_anchor_text
|
||||
from pdelfin.s3_utils import parse_custom_id, get_s3_bytes, parse_s3_path
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Quiet logs from pypdf and smart open
|
||||
logging.getLogger("pypdf").setLevel(logging.ERROR)
|
||||
logging.getLogger("smart_open").setLevel(logging.ERROR)
|
||||
|
||||
def list_dataset_files(s3_glob_path: str):
|
||||
"""
|
||||
@ -67,142 +73,12 @@ def load_jsonl_into_ds(s3_glob_path: str, first_n_files: int = None) -> Dataset:
|
||||
return dataset
|
||||
|
||||
|
||||
def get_png_dimensions_from_base64(base64_data) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the (width, height) of a PNG image given its base64-encoded data,
|
||||
without base64-decoding the entire data or loading the PNG itself
|
||||
|
||||
Should be really fast to support filtering
|
||||
|
||||
Parameters:
|
||||
- base64_data (str): Base64-encoded PNG image data.
|
||||
|
||||
Returns:
|
||||
- tuple: (width, height) of the image.
|
||||
|
||||
Raises:
|
||||
- ValueError: If the data is not a valid PNG image or the required bytes are not found.
|
||||
"""
|
||||
# PNG signature is 8 bytes
|
||||
png_signature_base64 = base64.b64encode(b'\x89PNG\r\n\x1a\n').decode('ascii')
|
||||
if not base64_data.startswith(png_signature_base64[:8]):
|
||||
raise ValueError('Not a valid PNG file')
|
||||
|
||||
# Positions in the binary data where width and height are stored
|
||||
width_start = 16 # Byte position where width starts (0-based indexing)
|
||||
width_end = 20 # Byte position where width ends (exclusive)
|
||||
height_start = 20
|
||||
height_end = 24
|
||||
|
||||
# Compute the byte range needed (from width_start to height_end)
|
||||
start_byte = width_start
|
||||
end_byte = height_end
|
||||
|
||||
# Calculate base64 character positions
|
||||
# Each group of 3 bytes corresponds to 4 base64 characters
|
||||
base64_start = (start_byte // 3) * 4
|
||||
base64_end = ((end_byte + 2) // 3) * 4 # Add 2 to ensure we cover partial groups
|
||||
|
||||
# Extract the necessary base64 substring
|
||||
base64_substring = base64_data[base64_start:base64_end]
|
||||
|
||||
# Decode only the necessary bytes
|
||||
decoded_bytes = base64.b64decode(base64_substring)
|
||||
|
||||
# Compute the offset within the decoded bytes
|
||||
offset = start_byte % 3
|
||||
|
||||
# Extract width and height bytes
|
||||
width_bytes = decoded_bytes[offset:offset+4]
|
||||
height_bytes = decoded_bytes[offset+4:offset+8]
|
||||
|
||||
if len(width_bytes) < 4 or len(height_bytes) < 4:
|
||||
raise ValueError('Insufficient data to extract dimensions')
|
||||
|
||||
# Convert bytes to integers
|
||||
width = int.from_bytes(width_bytes, 'big')
|
||||
height = int.from_bytes(height_bytes, 'big')
|
||||
|
||||
return width, height
|
||||
|
||||
|
||||
|
||||
|
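# A quick self-check of the header-only parsing above (a sketch, assuming
# Pillow is available): render a tiny PNG, base64-encode it, and confirm the
# reported dimensions without decoding the whole payload.
#
#   import base64, io
#   from PIL import Image
#
#   buf = io.BytesIO()
#   Image.new("RGB", (640, 480)).save(buf, format="PNG")
#   b64 = base64.b64encode(buf.getvalue()).decode("ascii")
#   assert get_png_dimensions_from_base64(b64) == (640, 480)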
||||
def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Extracts necessary fields from a query entry passed to openai's batch API for vision LMs
|
||||
"""
|
||||
custom_id = query.get("custom_id", "")
|
||||
body = query.get("body", {})
|
||||
messages = body.get("messages", [])
|
||||
|
||||
input_prompt_text = ""
|
||||
input_prompt_image_base64 = ""
|
||||
|
||||
for message in messages:
|
||||
if message.get("role") != "user":
|
||||
continue # We are only interested in user messages
|
||||
|
||||
contents = message.get("content", [])
|
||||
for content_item in contents:
|
||||
if content_item.get("type") == "text":
|
||||
input_prompt_text = content_item.get("text", "")
|
||||
elif content_item.get("type") == "image_url":
|
||||
image_url = content_item.get("image_url", {}).get("url", "")
|
||||
if image_url.startswith("data:image"):
|
||||
# Extract base64 part from data URL
|
||||
try:
|
||||
base64_data = image_url.split(",", 1)[1]
|
||||
input_prompt_image_base64 = base64_data
|
||||
except IndexError:
|
||||
input_prompt_image_base64 = ""
|
||||
|
||||
# This code builds the finetuning prompt from the original openai prompt by extracting the "pdf_report hint anchor text" from that
|
||||
# and reusing it
|
||||
# # At this point, the input_prompt_text is the raw text that was passed to the OpenAI model
|
||||
# # to generate our silver data. But, we want to have a simplified prompt for this fine tune,
|
||||
# # so we're going to extract out just the raw extracted prompt text
|
||||
# pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
|
||||
|
||||
# # Use re.DOTALL to ensure that the dot matches newline characters
|
||||
# match = re.search(pattern, input_prompt_text, re.DOTALL)
|
||||
|
||||
# if match:
|
||||
# raw_page_text = match.group(1).strip()
|
||||
# else:
|
||||
# raw_page_text = ""
|
||||
|
||||
|
||||
# This code builds the finetuning prompt by redownloading the PDF and extracting its report one more time
|
||||
s3_path = custom_id[:custom_id.rindex("-")]
|
||||
page_num = int(custom_id[custom_id.rindex("-") + 1:])
|
||||
|
||||
s3_client = boto3.client(
|
||||
's3',
|
||||
aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
|
||||
aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY')
|
||||
)
|
||||
|
||||
# Split the s3_path into bucket and key
|
||||
bucket_name = s3_path.split('s3://')[1].split('/')[0]
|
||||
s3_key = '/'.join(s3_path.split('s3://')[1].split('/')[1:])
|
||||
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
||||
s3_client.download_fileobj(bucket_name, s3_key, tf)
|
||||
|
||||
raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
|
||||
|
||||
return {
|
||||
"custom_id": custom_id,
|
||||
"input_prompt_text": input_prompt_text,
|
||||
"input_prompt_image_base64": input_prompt_image_base64,
|
||||
"raw_page_text": raw_page_text,
|
||||
}
|
||||
|
||||
|
||||
def extract_openai_batch_response(example):
|
||||
custom_id = example.get("custom_id", None)
|
||||
|
||||
# Parse the custom id into an s3 document path and page number (1-indexed)
|
||||
s3_path, page_num = parse_custom_id(custom_id)
|
||||
|
||||
response_body = example.get("response", {}).get("body", {})
|
||||
choices = response_body.get("choices", [])
|
||||
response = ""
|
||||
@ -213,58 +89,71 @@ def extract_openai_batch_response(example):
|
||||
response = message.get("content", "")
|
||||
finish_reason = first_choice.get("finish_reason", "")
|
||||
|
||||
return {"custom_id": custom_id, "response": response, "finish_reason": finish_reason}
|
||||
# TODO Maybe in the future we can parse the response (which is a structured JSON document itself)
|
||||
# into its own columns
|
||||
|
||||
return {"s3_path": s3_path, "page_num": page_num, "response": response, "finish_reason": finish_reason}
|
||||
|
||||
def merge_query_response(query_example, response_data: Dataset, response_map: dict[str, int]):
|
||||
custom_id = query_example["custom_id"]
|
||||
def _cache_s3_file(s3_path: str, local_cache_dir: str):
|
||||
"""
|
||||
Downloads an S3 object to a local cache directory, ensuring no two writers corrupt the same file.
|
||||
"""
|
||||
bucket, key = parse_s3_path(s3_path)
|
||||
|
||||
if custom_id not in response_map:
|
||||
return {
|
||||
"response": None,
|
||||
"finish_reason": None,
|
||||
}
|
||||
# Define the local file path
|
||||
local_file_path = os.path.join(local_cache_dir, key.replace("/", "_"))
|
||||
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
||||
lock_file = f"{local_file_path}.lock"
|
||||
|
||||
response_row = response_data[response_map[custom_id]]
|
||||
# Use a file lock to prevent concurrent writes
|
||||
with FileLock(lock_file):
|
||||
if not os.path.exists(local_file_path):
|
||||
logger.info(f"Downloading {s3_path} to {local_file_path}")
|
||||
s3_client = boto3.client(
|
||||
's3',
|
||||
aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
|
||||
aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY')
|
||||
)
|
||||
s3_client.download_file(bucket, key, local_file_path)
|
||||
else:
|
||||
logger.info(f"File {local_file_path} already exists, skipping download.")
|
||||
|
||||
return {"response": response_row["response"], "finish_reason": response_row["finish_reason"]}
|
||||
return local_file_path
|
||||
|
||||
def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32) -> Dataset:
|
||||
"""
|
||||
Caches all S3 paths in the dataset to the local cache directory.
|
||||
"""
|
||||
# Define the download function to use in parallel processing
|
||||
def cache_file(example):
|
||||
s3_path = example["s3_path"]
|
||||
if s3_path:
|
||||
# Download the file and cache it locally
|
||||
local_path = _cache_s3_file(s3_path, pdf_cache_location)
|
||||
return {"local_pdf_path": local_path}
|
||||
return {"local_pdf_path": None}
|
||||
|
||||
def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str, num_proc: int=32) -> Dataset:
|
||||
logger.info("Loading query and response datasets")
|
||||
query_data = load_jsonl_into_ds(query_glob_path)
|
||||
# Map the caching function to the dataset (with parallelism if needed)
|
||||
dataset = dataset.map(cache_file, num_proc=num_proc, load_from_cache_file=False)
|
||||
|
||||
return dataset
|
||||
|
||||
def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str]=None, num_proc: int=32) -> Dataset:
|
||||
if pdf_cache_location is None:
|
||||
pdf_cache_location = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin_pdfs')
|
||||
|
||||
logger.info("Loading fine tuning dataset from OpenAI style batch responses")
|
||||
response_data = load_jsonl_into_ds(response_glob_path)
|
||||
|
||||
# Map the datasets down to the core fields that we're going to need to make them easier to process
|
||||
logger.info("Mapping query data")
|
||||
query_data = query_data["train"]
|
||||
query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc)
|
||||
|
||||
logger.info("Mapping response data")
|
||||
response_data = response_data["train"]
|
||||
|
||||
response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc)
|
||||
|
||||
# What we're going to do, is build an in-memory map for the response data from custom_id to row
|
||||
# This will let us do quick lookups when we do a merge step, but it will not scale past a certain point
|
||||
logger.info("Building custom_id to row map")
|
||||
custom_id_to_response_row = {}
|
||||
for row_id, entry in enumerate(response_data):
|
||||
custom_id_to_response_row[entry["custom_id"]] = row_id
|
||||
|
||||
logger.info("Running merge map")
|
||||
final_dataset = query_data.map(
|
||||
partial(merge_query_response, response_data=response_data, response_map=custom_id_to_response_row),
|
||||
num_proc=num_proc
|
||||
)
|
||||
|
||||
# Don't include data where the model cut off due to a length issue, or moderation issue
|
||||
final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)
|
||||
logger.info("Filtering on finish_reason == stop")
|
||||
final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)
|
||||
|
||||
# Pick things that have a reasonable image size only
|
||||
def pick_image_sizes(x):
|
||||
width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"])
|
||||
return 1800 <= max(width, height) <= 2200
|
||||
|
||||
final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)
|
||||
# Cache all the s3_paths that were accessed to a local storage location,
|
||||
final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc)
|
||||
|
||||
return final_dataset
|
||||
|
||||
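The unit tests later in this diff exercise the new loader roughly like this; the glob path is the eval set from the config above:

    ds = build_finetuning_dataset(
        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
    )
    print(ds)  # rows carry s3_path, page_num, response, finish_reason, local_pdf_path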
@ -4,19 +4,15 @@ from PIL import Image
|
||||
import base64
|
||||
import torch # Make sure to import torch as it's used in the DataCollator
|
||||
|
||||
from pdelfin.prompts.anchor import get_anchor_text
|
||||
from pdelfin.prompts import build_finetuning_prompt
|
||||
|
||||
def filter_by_max_seq_len(example, processor, max_prompt_len: int=2200, max_response_len: int=2200):
|
||||
if len(processor.tokenizer.tokenize(example["input_prompt_text"])) > max_prompt_len:
|
||||
return False
|
||||
|
||||
if len(processor.tokenizer.tokenize(example["response"])) > max_response_len:
|
||||
return False
|
||||
|
||||
return True
|
||||
from pdelfin.data.renderpdf import render_pdf_to_base64png
|
||||
|
||||
|
||||
def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
|
||||
def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: int, target_anchor_text_len: int):
|
||||
anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
|
||||
base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)
|
||||
|
||||
# Prepare messages
|
||||
messages = [
|
||||
{
|
||||
@ -24,9 +20,9 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": example["input_prompt_image_base64"] # Placeholder
|
||||
"image": base64_page_image
|
||||
},
|
||||
{"type": "text", "text": build_finetuning_prompt(example["raw_page_text"])},
|
||||
{"type": "text", "text": build_finetuning_prompt(anchor_text)},
|
||||
],
|
||||
}
|
||||
]
|
||||
@ -36,14 +32,7 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
|
||||
)
|
||||
|
||||
# Decode image from base64
|
||||
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
|
||||
|
||||
# Right now, we are going to downsample to 1024 on the longest dimension, because
|
||||
# 2048 as we passed to OpenAI is too large for training
|
||||
width, height = main_image.size
|
||||
assert 1800 <= max(width, height) <= 2200, f"Image size {width}x{height} invalid"
|
||||
main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
|
||||
|
||||
main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))
|
||||
|
||||
# Process inputs using processor
|
||||
inputs = processor(
|
||||
@ -84,36 +73,30 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
|
||||
labels_full = np.full_like(input_ids, fill_value=-100)
|
||||
labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]
|
||||
|
||||
# TODO Maybe cap the max length
|
||||
|
||||
# Return as dict, including pixel_values
|
||||
if add_batch_dim:
|
||||
return {
|
||||
"input_ids": input_ids[np.newaxis, ...],
|
||||
"attention_mask": attention_mask[np.newaxis, ...],
|
||||
"labels": labels_full[np.newaxis, ...],
|
||||
"pixel_values": inputs.pixel_values[np.newaxis, ...],
|
||||
"image_grid_thw": inputs["image_grid_thw"]
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels_full,
|
||||
"pixel_values": inputs.pixel_values,
|
||||
"image_grid_thw": inputs["image_grid_thw"][0]
|
||||
}
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels_full,
|
||||
"pixel_values": inputs.pixel_values,
|
||||
"image_grid_thw": inputs["image_grid_thw"][0]
|
||||
}
|
||||
|
||||
|
||||
def batch_prepare_data_for_qwen2_training(batch, processor):
|
||||
def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: int, target_anchor_text_len: int):
|
||||
# Process each example in the batch using the helper function
|
||||
processed_examples = []
|
||||
for i in range(len(batch["input_prompt_image_base64"])):
|
||||
for i in range(len(batch["response"])):
|
||||
example = {
|
||||
"input_prompt_image_base64": batch["input_prompt_image_base64"][i],
|
||||
"input_prompt_text": batch["input_prompt_text"][i],
|
||||
"raw_page_text": batch["raw_page_text"][i],
|
||||
"local_pdf_path": batch["local_pdf_path"][i],
|
||||
"page_num": batch["page_num"][i],
|
||||
"response": batch["response"][i]
|
||||
}
|
||||
processed_example = prepare_data_for_qwen2_training(example, processor)
|
||||
processed_example = prepare_data_for_qwen2_training(example, processor,
|
||||
target_longest_image_dim=target_longest_image_dim,
|
||||
target_anchor_text_len=target_anchor_text_len)
|
||||
processed_examples.append(processed_example)
|
||||
|
||||
return {
|
||||
@ -124,96 +107,3 @@ def batch_prepare_data_for_qwen2_training(batch, processor):
|
||||
"image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
|
||||
}
|
||||
|
||||
|
||||
def prepare_data_for_qwen2_inference(example, processor):
|
||||
# Prepare messages
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": example["input_prompt_image_base64"] # Placeholder
|
||||
},
|
||||
{"type": "text", "text": example["input_prompt_text"]},
|
||||
],
|
||||
}
|
||||
]
|
||||
# Apply chat template to get the text
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Decode image from base64
|
||||
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
|
||||
|
||||
# Right now, we are going to downsample to 1024 on the longest dimension, because
|
||||
# 2048 as we passed to OpenAI is too large for training
|
||||
width, height = main_image.size
|
||||
if 1800 <= max(width, height) <= 2200:
|
||||
main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
|
||||
|
||||
|
||||
# Process inputs using processor
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=[main_image],
|
||||
padding=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
input_ids = inputs["input_ids"][0]
|
||||
|
||||
# All columns will participate in attention fully
|
||||
attention_mask = np.ones_like(input_ids)
|
||||
|
||||
# Return as dict, including pixel_values
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": inputs.pixel_values,
|
||||
"image_grid_thw": inputs["image_grid_thw"][0]
|
||||
}
|
||||
|
||||
|
||||
def batch_prepare_data_for_qwen2_inference(batch, processor):
|
||||
# Process each example in the batch using the helper function
|
||||
processed_examples = []
|
||||
for i in range(len(batch["input_prompt_image_base64"])):
|
||||
example = {
|
||||
"input_prompt_image_base64": batch["input_prompt_image_base64"][i],
|
||||
"input_prompt_text": batch["input_prompt_text"][i],
|
||||
"raw_page_text": batch["raw_page_text"][i],
|
||||
}
|
||||
processed_example = prepare_data_for_qwen2_inference(example, processor)
|
||||
processed_examples.append(processed_example)
|
||||
|
||||
return {
|
||||
"input_ids": [x["input_ids"] for x in processed_examples],
|
||||
"attention_mask": [x["attention_mask"] for x in processed_examples],
|
||||
"pixel_values": [x["pixel_values"] for x in processed_examples],
|
||||
"image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
|
||||
}
|
||||
|
||||
# Define a custom data collator
|
||||
class DataCollatorForVisionLanguageModeling:
|
||||
def __init__(self, processor):
|
||||
self.processor = processor
|
||||
|
||||
def __call__(self, features):
|
||||
input_ids = [f['input_ids'] for f in features]
|
||||
attention_mask = [f['attention_mask'] for f in features]
|
||||
labels = [f['labels'] for f in features]
|
||||
pixel_values = [f['pixel_values'] for f in features]
|
||||
|
||||
# Pad input_ids, attention_mask, labels
|
||||
batch = self.processor.pad(
|
||||
{"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels},
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
|
||||
# Stack pixel_values
|
||||
batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])
|
||||
|
||||
return batch
|
||||
|
||||
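Illustrative wiring of the collator above into a PyTorch DataLoader; `formatted_dataset` is a placeholder for a dataset already transformed with batch_prepare_data_for_qwen2_training:

    from torch.utils.data import DataLoader
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    collator = DataCollatorForVisionLanguageModeling(processor)
    loader = DataLoader(formatted_dataset, batch_size=4, collate_fn=collator)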
120
pdelfin/viewer/dolmaviewer.py
Normal file
@ -0,0 +1,120 @@
import os
|
||||
import json
|
||||
import html
|
||||
import argparse
|
||||
import boto3
|
||||
import tempfile
|
||||
from botocore.exceptions import NoCredentialsError, PartialCredentialsError
|
||||
from jinja2 import Template
|
||||
import smart_open
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import markdown2
|
||||
|
||||
from pdelfin.s3_utils import get_s3_bytes
|
||||
from pdelfin.data.renderpdf import render_pdf_to_base64webp
|
||||
|
||||
def read_jsonl(path):
|
||||
with smart_open.smart_open(path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
yield line.strip()
|
||||
|
||||
def parse_s3_path(path):
|
||||
# s3://bucket_name/key_name
|
||||
path = path[5:] # Remove 's3://'
|
||||
bucket_name, key_name = path.split('/', 1)
|
||||
return bucket_name, key_name
|
||||
|
||||
def generate_presigned_url(s3_client, bucket_name, key_name):
|
||||
try:
|
||||
response = s3_client.generate_presigned_url('get_object',
|
||||
Params={'Bucket': bucket_name, 'Key': key_name},
|
||||
ExpiresIn=3600) # Link expires in 1 hour
|
||||
return response
|
||||
except (NoCredentialsError, PartialCredentialsError):
|
||||
print("Error: AWS credentials not found or incomplete.")
|
||||
return None
|
||||
|
||||
def process_document(data, s3_client, template, output_dir):
|
||||
id_ = data.get('id')
|
||||
text = data.get('text', '')
|
||||
attributes = data.get('attributes', {})
|
||||
pdf_page_numbers = attributes.get('pdf_page_numbers', [])
|
||||
metadata = data.get('metadata', {})
|
||||
source_file = metadata.get('Source-File')
|
||||
|
||||
# Generate base64 image of the corresponding PDF page
|
||||
if source_file and source_file.startswith('s3://'):
|
||||
local_pdf = tempfile.NamedTemporaryFile("wb+", suffix=".pdf")
|
||||
local_pdf.write(get_s3_bytes(s3_client, source_file))
|
||||
local_pdf.flush()
|
||||
else:
|
||||
raise ValueError("Expecting s3 files only")
|
||||
|
||||
pages = []
|
||||
for span in pdf_page_numbers:
|
||||
start_index, end_index, page_num = span
|
||||
page_text = text[start_index:end_index]
|
||||
|
||||
# Detect and convert Markdown to HTML
|
||||
page_text = html.escape(page_text, quote=True).replace('&lt;br&gt;', '<br>')
|
||||
page_text = markdown2.markdown(page_text, extras=["tables"])
|
||||
|
||||
base64_image = render_pdf_to_base64webp(local_pdf.name, page_num)
|
||||
|
||||
pages.append({'page_num': page_num, 'text': page_text, 'image': base64_image})
|
||||
|
||||
local_pdf.close()
|
||||
|
||||
# Generate pre-signed URL if source_file is an S3 path
|
||||
s3_link = None
|
||||
if source_file and source_file.startswith('s3://'):
|
||||
bucket_name, key_name = parse_s3_path(source_file)
|
||||
s3_link = generate_presigned_url(s3_client, bucket_name, key_name)
|
||||
|
||||
# Render the HTML using the Jinja template
|
||||
html_content = template.render(id=id_, pages=pages, s3_link=s3_link)
|
||||
|
||||
# Write the HTML content to a file
|
||||
filename = f'{id_}.html'
|
||||
filepath = os.path.join(output_dir, filename)
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
f.write(html_content)
|
||||
|
||||
def main(jsonl_path, output_dir, template_path):
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# Load the Jinja template
|
||||
with open(os.path.join(os.path.dirname(__file__), template_path), 'r', encoding='utf-8') as template_file:
|
||||
template_content = template_file.read()
|
||||
template = Template(template_content)
|
||||
|
||||
# Initialize S3 client for generating presigned URLs
|
||||
workspace_session = boto3.Session(profile_name="s2")
|
||||
s3_client = workspace_session.client("s3")
|
||||
|
||||
# Create ThreadPoolExecutor
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = []
|
||||
for line in read_jsonl(jsonl_path):
|
||||
if not line:
|
||||
continue
|
||||
data = json.loads(line)
|
||||
future = executor.submit(process_document, data, s3_client, template, output_dir)
|
||||
futures.append(future)
|
||||
|
||||
for future in tqdm(as_completed(futures), total=len(futures)):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Generate HTML pages from a JSONL file with pre-signed S3 links.')
|
||||
parser.add_argument('jsonl_path', help='Path to the JSONL file (local or s3://)')
|
||||
parser.add_argument('--output_dir', default='dolma_previews', help='Directory to save HTML files')
|
||||
parser.add_argument('--template_path', default='dolmaviewer_template.html', help='Path to the Jinja2 template file')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.jsonl_path, args.output_dir, args.template_path)
|
||||
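For reference, the viewer above can also be driven directly from Python, assuming the package layout shown in the file header; the JSONL path is a placeholder, while output_dir and template_path match the argparse defaults:

    from pdelfin.viewer.dolmaviewer import main

    main("s3://my-bucket/dolma_output/documents.jsonl", "dolma_previews", "dolmaviewer_template.html")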
128
pdelfin/viewer/dolmaviewer_template.html
Normal file
@ -0,0 +1,128 @@
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>{{ id }}</title>
|
||||
<style>
|
||||
/* CSS styles */
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
background-color: #f0f0f0;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
}
|
||||
.document {
|
||||
background-color: #fff;
|
||||
padding: 40px;
|
||||
margin: 20px;
|
||||
width: 60%;
|
||||
box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1);
|
||||
line-height: 1.8;
|
||||
}
|
||||
.page-section {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
margin-bottom: 20px;
|
||||
transition: background-color 0.3s ease;
|
||||
}
|
||||
.page-section:hover {
|
||||
background-color: #f5f5f5;
|
||||
}
|
||||
.page-section .text {
|
||||
flex: 2;
|
||||
padding: 10px;
|
||||
text-align: justify;
|
||||
}
|
||||
.page-section .image {
|
||||
flex: 1;
|
||||
padding: 10px;
|
||||
}
|
||||
.page-section img {
|
||||
max-width: 100%;
|
||||
height: auto;
|
||||
border: 1px solid #ccc;
|
||||
}
|
||||
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse; /* Ensures that borders are collapsed to give a clean look */
|
||||
margin-bottom: 1.5em; /* Adds some space below the table */
|
||||
}
|
||||
|
||||
th, td {
|
||||
border: 1px solid #ddd; /* 1px solid border for table cells */
|
||||
padding: 12px 15px; /* Adds some padding for better spacing inside the cells */
|
||||
text-align: left; /* Aligns text to the left */
|
||||
vertical-align: top; /* Aligns content to the top of the cell */
|
||||
font-size: 14px; /* Adjusts font size for readability */
|
||||
}
|
||||
|
||||
th {
|
||||
background-color: #f4f4f4; /* Light background for table headers */
|
||||
font-weight: bold; /* Bolds header text */
|
||||
text-transform: uppercase; /* Makes header text uppercase */
|
||||
letter-spacing: 0.05em; /* Adds slight spacing between letters for readability */
|
||||
border-bottom: 2px solid #ccc; /* Slightly thicker bottom border for headers */
|
||||
}
|
||||
|
||||
tr:nth-child(even) {
|
||||
background-color: #f9f9f9; /* Alternates row background color */
|
||||
}
|
||||
|
||||
tr:hover {
|
||||
background-color: #f1f1f1; /* Highlights row on hover for better interaction */
|
||||
}
|
||||
|
||||
td img {
|
||||
max-width: 100%; /* Ensures any images in table cells scale properly */
|
||||
height: auto;
|
||||
display: block;
|
||||
}
|
||||
|
||||
table caption {
|
||||
caption-side: bottom; /* Position caption at the bottom of the table */
|
||||
text-align: right;
|
||||
font-size: 12px;
|
||||
color: #777;
|
||||
padding: 5px 0;
|
||||
}
|
||||
|
||||
</style>
|
||||
|
||||
<script type="text/javascript">
|
||||
window.MathJax = {
|
||||
tex: {
|
||||
inlineMath: [['$', '$'], ['\\(', '\\)']],
|
||||
displayMath: [['$$', '$$'], ['\\[', '\\]']]
|
||||
},
|
||||
options: {
|
||||
skipHtmlTags: ['script', 'noscript', 'style', 'textarea', 'pre'],
|
||||
processHtmlClass: 'mathjax-process' // Class name for areas where LaTeX should be processed
|
||||
}
|
||||
};
|
||||
</script>
|
||||
<script type="text/javascript" id="MathJax-script" async
|
||||
src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js">
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div class="document">
|
||||
{% for page in pages %}
|
||||
<div class="page-section" id="page-{{ page.page_num }}">
|
||||
<div class="text">{{ page.text|safe }}</div>
|
||||
{% if page.image %}
|
||||
<div class="image">
|
||||
<a href="{{ s3_link }}#page={{ page.page_num }}" target="_blank">
|
||||
<img src="data:image/webp;base64,{{ page.image }}" alt="Page {{ page.page_num }} Image">
|
||||
</a>
|
||||
</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
||||
@ -28,7 +28,8 @@ dependencies = [
    "Pillow",
    "ftfy",
    "bleach",
    "duckdb",
    "markdown2",
    "filelock",
]
license = {file = "LICENSE"}

BIN
tests/gnarly_pdfs/discoverworld_crazy_tables.pdf
Normal file
Binary file not shown.
2002
tests/gnarly_pdfs/tobacco_missed_tokens_pg1.pdf
Normal file
File diff suppressed because it is too large
@ -2,11 +2,12 @@ import unittest
import os
import json
import io
import glob

from pypdf import PdfReader

from pdelfin.prompts.anchor import _pdf_report, _linearize_pdf_report, get_anchor_text

from pdelfin.data.renderpdf import get_pdf_media_box_width_height

class AnchorTest(unittest.TestCase):
    def testExtractText(self):
@ -102,7 +103,14 @@ class AnchorTest(unittest.TestCase):
        print(len(anchor_text))
        self.assertLess(len(anchor_text), 4000)

    def testTobaccoPaperMissingParagraphs(self):
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "tobacco_missed_tokens_pg1.pdf")

        anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport")

        print(anchor_text)
        print(len(anchor_text))
        self.assertLess(len(anchor_text), 4000)


class BuildSilverTest(unittest.TestCase):
@ -124,4 +132,17 @@ class BuildSilverTest(unittest.TestCase):

        print(width, height)

        assert max(width, height) == 2048
        assert max(width, height) == 2048

class TestRenderPdf(unittest.TestCase):
    def testFastMediaBoxMatchesPyPdf(self):
        for file in glob.glob(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "*.pdf")):
            reader = PdfReader(file)
            print("checking", file)

            for page_num in range(1, len(reader.pages) + 1):
                w1, h1 = get_pdf_media_box_width_height(file, page_num)
                pypdfpage = reader.pages[page_num - 1]

                self.assertEqual(w1, pypdfpage.mediabox.width)
                self.assertEqual(h1, pypdfpage.mediabox.height)
@ -6,14 +6,13 @@ from functools import partial
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from pdelfin.train.dataloader import (
|
||||
build_batch_query_response_vision_dataset,
|
||||
extract_openai_batch_query,
|
||||
build_finetuning_dataset,
|
||||
extract_openai_batch_response,
|
||||
load_jsonl_into_ds,
|
||||
list_dataset_files
|
||||
)
|
||||
|
||||
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
|
||||
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
|
||||
|
||||
|
||||
class TestBatchQueryResponseDataset(unittest.TestCase):
|
||||
@ -24,52 +23,35 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
|
||||
print(ds)
|
||||
print(ds["train"])
|
||||
|
||||
def testCombinedQueryResponse(self):
|
||||
ds = build_batch_query_response_vision_dataset(
|
||||
query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
|
||||
def testFinetuningDS(self):
|
||||
ds = build_finetuning_dataset(
|
||||
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
|
||||
)
|
||||
|
||||
print(ds)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
||||
from pdelfin.train.dataprep import filter_by_max_seq_len
|
||||
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))
|
||||
|
||||
print(ds[0])
|
||||
|
||||
def testLocalDS(self):
|
||||
ds = build_batch_query_response_vision_dataset(
|
||||
query_glob_path="/root/openai_batch_data_v5_1_train/*.jsonl",
|
||||
response_glob_path="/root/openai_batch_data_v5_1_train_done/*.json",
|
||||
)
|
||||
|
||||
print(ds)
|
||||
|
||||
ds.to_parquet("/root/trainds_parquet/bigds.parquet")
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
||||
from pdelfin.train.dataprep import filter_by_max_seq_len
|
||||
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))
|
||||
ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
|
||||
|
||||
print(ds[0])
|
||||
|
||||
def testPlotSequenceLengthHistogram(self):
|
||||
import plotly.express as px
|
||||
|
||||
ds = build_batch_query_response_vision_dataset(
|
||||
query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
|
||||
ds = build_finetuning_dataset(
|
||||
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
|
||||
)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
||||
|
||||
ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
||||
|
||||
initial_len = len(ds)
|
||||
|
||||
from pdelfin.train.dataprep import filter_by_max_seq_len
|
||||
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=2200, max_response_len=2200))
|
||||
|
||||
formatted_dataset = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
|
||||
train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=30, shuffle=False)
|
||||
train_dataloader = DataLoader(ds, batch_size=1, num_workers=30, shuffle=False)
|
||||
|
||||
max_seen_len = 0
|
||||
steps = 0
|
||||
@ -98,43 +80,3 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
|
||||
)
|
||||
|
||||
fig.write_image("sequence_lengths_histogram.png")
|
||||
|
||||
def testExtractBatch(self):
|
||||
query_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl", first_n_files=3)
|
||||
query_data = query_data["train"]
|
||||
query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)
|
||||
|
||||
print(query_data)
|
||||
print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
|
||||
|
||||
def testExtractResponse(self):
|
||||
response_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", first_n_files=3)
|
||||
response_data = response_data["train"]
|
||||
|
||||
response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names)
|
||||
|
||||
print(response_data)
|
||||
print(response_data[0])
|
||||
|
||||
def testPyArrowDirectJson(self):
|
||||
query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl"
|
||||
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json"
|
||||
|
||||
all_files = list_dataset_files(query_glob_path)
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.json as paj
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.fs as fs
|
||||
|
||||
s3 = fs.S3FileSystem()
|
||||
|
||||
block_size = 200 * 1024**2
|
||||
|
||||
for file in all_files:
|
||||
with s3.open_input_stream(file.replace("s3://", "")) as f:
|
||||
table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size))
|
||||
|
||||
print(f"file {file} rows {table.num_rows}")
|
||||
print(table.schema)
|
||||
|
||||
|
||||