Merge branch 'main' of https://github.com/allenai/pdelfin into main

This commit is contained in:
Jake Poznanski 2024-10-16 11:38:33 -07:00
commit 202d81cece
18 changed files with 3265 additions and 1041 deletions

4
.gitignore vendored
View File

@ -2,7 +2,9 @@
wandb/
*histogram.png
*.json
dolma_previews/*
s2_previews/*
gnarly_previews/*
/*.html

View File

@ -1,320 +0,0 @@
import os
import hashlib
import boto3
import sqlite3
import json
import argparse
import glob
import tempfile
import posixpath
from dataclasses import dataclass
from pypdf import PdfReader
from tqdm import tqdm
from typing import Optional, List, Tuple, Dict
from urllib.parse import urlparse
from concurrent.futures import ProcessPoolExecutor, as_completed
from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text
# Global s3 client for the whole script, feel free to adjust params if you need it
s3 = boto3.client('s3')
class DatabaseManager:
def __init__(self, s3_workspace: str):
cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
os.makedirs(home_cache_dir, exist_ok=True)
self.db_path = os.path.join(home_cache_dir, 'index.db')
self.conn = sqlite3.connect(self.db_path)
self.cursor = self.conn.cursor()
self._initialize_tables()
def _initialize_tables(self):
self.cursor.execute("""
CREATE TABLE IF NOT EXISTS page_results (
s3_path TEXT,
page_num INTEGER,
start_index BIGINT,
length BIGINT,
finish_reason TEXT,
error TEXT
)
""")
self.cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_path ON page_results(s3_path)
""")
self.cursor.execute("""
CREATE TABLE IF NOT EXISTS pdfs (
s3_path TEXT PRIMARY KEY,
num_pages INTEGER,
status TEXT DEFAULT 'pending'
)
""")
self.cursor.execute("""
CREATE TABLE IF NOT EXISTS processed_files (
s3_path TEXT PRIMARY KEY,
etag TEXT
)
""")
# Generic metadata such as current round
self.cursor.execute("""
CREATE TABLE IF NOT EXISTS metadata (
key TEXT PRIMARY KEY,
value TEXT
)
""")
self.conn.commit()
def get_metadata(self, key: str) -> Optional[str]:
self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
result = self.cursor.fetchone()
return result[0] if result else None
def get_current_round(self):
round_value = self.get_metadata("round")
return int(round_value) if round_value else 0
def is_file_processed(self, s3_path, etag):
self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
result = self.cursor.fetchone()
return result is not None and result[0] == etag
def add_index_entries(self, index_entries: List['BatchInferenceLine']):
if index_entries:
self.cursor.executemany("""
INSERT INTO page_results (s3_path, page_num, start_index, length, finish_reason, error)
VALUES (?, ?, ?, ?, ?, ?)
""", [(entry.s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
self.conn.commit()
def update_processed_file(self, s3_path, etag):
self.cursor.execute("""
INSERT INTO processed_files (s3_path, etag)
VALUES (?, ?)
ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
""", (s3_path, etag))
self.conn.commit()
def pdf_exists(self, s3_path: str) -> bool:
self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
return self.cursor.fetchone() is not None
def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None:
try:
self.cursor.execute("""
INSERT INTO pdfs (s3_path, num_pages, status)
VALUES (?, ?, ?)
""", (s3_path, num_pages, status))
self.conn.commit()
except sqlite3.IntegrityError:
print(f"PDF with s3_path '{s3_path}' already exists.")
def get_pdf_status(self, s3_path: str) -> Optional[str]:
self.cursor.execute("SELECT status FROM pdfs WHERE s3_path = ?", (s3_path,))
result = self.cursor.fetchone()
return result[0] if result else None
def close(self):
self.conn.close()
@dataclass(frozen=True)
class BatchInferenceLine:
s3_path: str
page_num: int # 1 indexed!
start_index: int
length: int
finish_reason: str
error: Optional[str]
def parse_s3_path(s3_path):
if not s3_path.startswith('s3://'):
raise ValueError('s3_path must start with s3://')
path = s3_path[5:]
bucket, _, prefix = path.partition('/')
return bucket, prefix
def expand_s3_glob(s3_glob: str) -> Dict[str, str]:
parsed = urlparse(s3_glob)
bucket_name = parsed.netloc
prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
pattern = os.path.basename(parsed.path)
paginator = s3.get_paginator('list_objects_v2')
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
matched_files = {}
for page in page_iterator:
for obj in page.get('Contents', []):
key = obj['Key']
if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
matched_files[f"s3://{bucket_name}/{key}"] = obj['ETag'].strip('"')
return matched_files
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
return {
"custom_id": f"{pretty_pdf_path}-{page}",
"chat_messages": [
{
"role": "user",
"content": [
{"type": "text", "text": build_finetuning_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
],
}
],
"temperature": 0.8,
"max_tokens": 6000,
}
def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
bucket, key = parse_s3_path(s3_path)
# Build the range header if start_index and/or end_index are specified
range_header = None
if start_index is not None or end_index is not None:
range_value = f"bytes={start_index or 0}-"
if end_index is not None:
range_value += str(end_index)
range_header = {'Range': range_value}
if range_header:
obj = s3.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
else:
obj = s3.get_object(Bucket=bucket, Key=key)
return obj['Body'].read()
def parse_custom_id(custom_id: str) -> Tuple[str, int]:
s3_path = custom_id[:custom_id.rindex("-")]
page_num = int(custom_id[custom_id.rindex("-") + 1:])
return s3_path, page_num
def process_jsonl_content(s3_path) -> List[BatchInferenceLine]:
content = get_s3_bytes(s3_path).decode("utf-8")
start_index = 0
index_entries = []
lines = content.splitlines(keepends=True)
for line in lines:
line_length = len(line)
try:
data = json.loads(line)
s3_path, page_num = parse_custom_id(data["custom_id"])
assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"
index_entries.append(BatchInferenceLine(
s3_path=s3_path,
page_num=page_num,
start_index=start_index,
length=line_length,
finish_reason=data["outputs"][0]["finish_reason"],
error=data.get("completion_error", None)
))
except json.JSONDecodeError:
pass # Handle JSON decode errors if necessary
except Exception as e:
print(f"Error processing line: {e}")
start_index += line_length
return index_entries
def get_pdf_num_pages(s3_path: str) -> Optional[int]:
try:
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
tf.write(get_s3_bytes(s3_path))
tf.flush()
reader = PdfReader(tf.name)
return reader.get_num_pages()
except Exception as ex:
print(f"Warning, could not add {s3_path} due to {ex}")
return None
def get_pdf_batch_inference_lines(s3_path: str, pages: list[int]) -> list[dict]:
results = []
try:
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
tf.write(get_s3_bytes(s3_path))
tf.flush()
for page in pages:
results.append(build_page_query(tf.name, s3_path, page))
except Exception as ex:
print(f"Warning, could not get batch inferences lines for {s3_path} due to {ex}")
return results
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
parser.add_argument('--file_size_limit', type=int, default=250, help='Max file size in MB')
args = parser.parse_args()
db = DatabaseManager(args.workspace)
print(f"Loaded db at {db.db_path}")
print(f"Current round is {db.get_current_round()}\n")
# One shared executor to rule them all
executor = ProcessPoolExecutor()
# If you have new PDFs, add them to the list
if args.add_pdfs:
assert args.add_pdfs.startswith("s3://"), "PDFs must live on s3"
print(f"Querying all PDFs at {args.add_pdfs}")
all_pdfs = expand_s3_glob(args.add_pdfs)
print(f"Found {len(all_pdfs)} total pdf paths")
all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
print(f"Need to import {len(all_pdfs)} total new pdf paths")
future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
s3_path = future_to_path[future]
num_pages = future.result()
if num_pages and not db.pdf_exists(s3_path):
db.add_pdf(s3_path, num_pages, "pending")
print("\n")
# Now build an index of all the pages that were processed within the workspace so far
print("Indexing all batch inference sent to this workspace")
inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")
inference_output_paths = [
(s3_path, etag) for s3_path, etag in inference_output_paths.items()
if not db.is_file_processed(s3_path, etag)
]
print(f"Found {len(inference_output_paths)} new batch inference results to index")
future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
s3_path, etag = future_to_path[future]
inference_lines = future.result()
db.add_index_entries(inference_lines)
db.update_processed_file(s3_path, etag=etag)
# Now query each pdf, if you have all of the pages needed (all pages present, error is null and finish_reason is stop), then you assemble it into a dolma document and output it
# If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use

687
pdelfin/birrpipeline.py Normal file
View File

@ -0,0 +1,687 @@
import os
import hashlib
import boto3
import sqlite3
import json
import argparse
import glob
import tempfile
import datetime
import posixpath
import threading
import logging
import boto3.session
import urllib3.exceptions
from dataclasses import dataclass
from pypdf import PdfReader
from tqdm import tqdm
from functools import partial
from typing import Optional, List, Tuple, Dict, Callable, Any
from urllib.parse import urlparse
from concurrent.futures import ProcessPoolExecutor, as_completed
from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.s3_utils import parse_custom_id, expand_s3_glob, get_s3_bytes, put_s3_bytes
# Global s3 client for the whole script, feel free to adjust params if you need it
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')
# Quiet logs from pypdf and smart open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)
class DatabaseManager:
@dataclass(frozen=True)
class BatchInferenceRecord:
inference_s3_path: str
pdf_s3_path: str
page_num: int # 1 indexed!
round: int
start_index: int
length: int
finish_reason: str
error: Optional[str]
def is_usable(self):
return self.error is None and self.finish_reason == "stop"
@dataclass(frozen=True)
class PDFRecord:
s3_path: str
num_pages: int
status: str
def __init__(self, s3_workspace: str):
cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
os.makedirs(home_cache_dir, exist_ok=True)
self.db_path = os.path.join(home_cache_dir, 'index.db')
self.conn = sqlite3.connect(self.db_path)
self.cursor = self.conn.cursor()
self._initialize_tables()
def _initialize_tables(self):
self.cursor.execute("""
CREATE TABLE IF NOT EXISTS page_results (
inference_s3_path TEXT,
pdf_s3_path TEXT,
page_num INTEGER,
round INTEGER,
start_index BIGINT,
length BIGINT,
finish_reason TEXT,
error TEXT
)
""")
self.cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_path ON page_results(pdf_s3_path)
""")
self.cursor.execute("""
CREATE TABLE IF NOT EXISTS pdfs (
s3_path TEXT PRIMARY KEY,
num_pages INTEGER,
status TEXT DEFAULT 'pending'
)
""")
self.cursor.execute("""
CREATE TABLE IF NOT EXISTS processed_files (
s3_path TEXT PRIMARY KEY,
etag TEXT
)
""")
# Generic metadata such as current round
self.cursor.execute("""
CREATE TABLE IF NOT EXISTS metadata (
key TEXT PRIMARY KEY,
value TEXT
)
""")
self.conn.commit()
def get_metadata(self, key: str) -> Optional[str]:
self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
result = self.cursor.fetchone()
return result[0] if result else None
def set_metadata(self, key: str, value: str) -> None:
self.cursor.execute("""
INSERT INTO metadata (key, value)
VALUES (?, ?)
ON CONFLICT(key) DO UPDATE SET value=excluded.value
""", (key, value))
self.conn.commit()
def is_file_processed(self, s3_path, etag):
self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
result = self.cursor.fetchone()
return result is not None and result[0] == etag
def update_processed_file(self, s3_path, etag):
self.cursor.execute("""
INSERT INTO processed_files (s3_path, etag)
VALUES (?, ?)
ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
""", (s3_path, etag))
self.conn.commit()
def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
if index_entries:
self.cursor.executemany("""
INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.round, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
self.conn.commit()
def get_index_entries(self, pdf_s3_path: str) -> List[BatchInferenceRecord]:
self.cursor.execute("""
SELECT inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error
FROM page_results
WHERE pdf_s3_path = ?
ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
""", (pdf_s3_path,))
rows = self.cursor.fetchall()
return [
self.BatchInferenceRecord(
inference_s3_path=row[0],
pdf_s3_path=row[1],
page_num=row[2],
round=row[3],
start_index=row[4],
length=row[5],
finish_reason=row[6],
error=row[7]
)
for row in rows
]
def get_last_indexed_round(self) -> int:
self.cursor.execute("""
SELECT MAX(round)
FROM page_results
""")
result = self.cursor.fetchone()
return -1 if result[0] is None else result[0]
def pdf_exists(self, s3_path: str) -> bool:
self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
return self.cursor.fetchone() is not None
def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None:
try:
self.cursor.execute("""
INSERT INTO pdfs (s3_path, num_pages, status)
VALUES (?, ?, ?)
""", (s3_path, num_pages, status))
self.conn.commit()
except sqlite3.IntegrityError:
print(f"PDF with s3_path '{s3_path}' already exists.")
def update_pdf_status(self, s3_path: str, new_status: str) -> None:
self.cursor.execute("""
UPDATE pdfs
SET status = ?
WHERE s3_path = ?
""", (new_status, s3_path))
self.conn.commit()
def get_pdf(self, s3_path: str) -> Optional[PDFRecord]:
self.cursor.execute("""
SELECT s3_path, num_pages, status
FROM pdfs
WHERE s3_path = ?
""", (s3_path,))
row = self.cursor.fetchone()
if row:
return self.PDFRecord(
s3_path=row[0],
num_pages=row[1],
status=row[2]
)
return None
def get_pdfs_by_status(self, status: str) -> List[PDFRecord]:
self.cursor.execute("""
SELECT s3_path, num_pages, status
FROM pdfs
WHERE status == ?
ORDER BY s3_path DESC, num_pages DESC
""", (status, ))
rows = self.cursor.fetchall()
return [
self.PDFRecord(
s3_path=row[0],
num_pages=row[1],
status=row[2]
)
for row in rows
]
def close(self):
self.conn.close()
# Writes batches of lines out to a set of files, keeping each file below some maximum size
class BatchWriter:
def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
self.output_prefix = output_prefix
self.max_size = max_size_mb * 1024 * 1024 # Convert MB to bytes
self.batch = []
self.batch_size = 0
self.after_flush = after_flush
self.threads = []
parsed = urlparse(output_prefix)
self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')
if not self.is_s3:
os.makedirs(output_prefix, exist_ok=True)
def _compute_hash(self, content: str) -> str:
"""Compute a 20-character SHA1 hash of the given content."""
sha1 = hashlib.sha1()
sha1.update(content.encode('utf-8'))
return sha1.hexdigest()[:20]
def _get_output_path(self, hash_str: str) -> str:
"""Generate the full output path with hash in the filename."""
parsed = urlparse(self.output_prefix)
if self.is_s3:
bucket = parsed.netloc
key = parsed.path.lstrip('/')
if key and not key.endswith('/'):
key += '/'
full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
return f"s3://{bucket}/{full_key}"
else:
filename = f"output_{hash_str}.jsonl"
return os.path.join(self.output_prefix, filename)
def write_line(self, line: Optional[str]):
if line is None or not line.strip():
return
line_size = len(line.encode('utf-8')) + 1 # +1 for newline
if self.batch_size + line_size > self.max_size:
self._write_batch()
self.batch.append(line)
self.batch_size += line_size
def _write_batch(self):
if not self.batch:
return
batch_lines = self.batch.copy()
batch_content = "\n".join(batch_lines) + "\n"
hash_str = self._compute_hash(batch_content)
output_path = self._get_output_path(hash_str)
# Start a new thread to write the batch
thread = threading.Thread(
target=self._write_batch_to_file,
args=(batch_content, output_path, batch_lines)
)
thread.start()
self.threads.append(thread)
# Clear the batch and batch_size
self.batch = []
self.batch_size = 0
def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
if self.is_s3:
put_s3_bytes(workspace_s3, output_path, batch_content.encode("utf-8"))
else:
with open(output_path, 'w', encoding='utf-8') as f_out:
f_out.write(batch_content)
# After writing, call the after_flush callback if it is set
if self.after_flush:
self.after_flush(batch_lines)
def close(self):
self._write_batch()
# Wait for all threads to finish
for thread in self.threads:
thread.join()
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
return {
"custom_id": f"{pretty_pdf_path}-{page}",
"chat_messages": [
{
"role": "user",
"content": [
{"type": "text", "text": build_finetuning_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
],
}
],
}
def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)
start_index = 0
index_entries = []
lines = content_bytes.splitlines(keepends=True) # Split content into lines as bytes
for line in lines:
line_length = len(line) # Length in bytes
try:
# Decode the line for JSON processing
line_str = line.decode('utf-8')
data = json.loads(line_str)
pdf_s3_path, page_num = parse_custom_id(data["custom_id"])
if data.get("completion_error", None) is not None:
index_entries.append(DatabaseManager.BatchInferenceRecord(
inference_s3_path=inference_s3_path,
pdf_s3_path=pdf_s3_path,
page_num=page_num,
round=data["round"],
start_index=start_index, # Byte offset in the original file
length=line_length, # Length in bytes
finish_reason="completion_error",
error=data.get("completion_error", None)
))
else:
# Try to parse the actual model response JSON
assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"
try:
model_response_json = json.loads(data["outputs"][0]["text"])
index_entries.append(DatabaseManager.BatchInferenceRecord(
inference_s3_path=inference_s3_path,
pdf_s3_path=pdf_s3_path,
page_num=page_num,
round=data["round"],
start_index=start_index, # Byte offset in the original file
length=line_length, # Length in bytes
finish_reason=data["outputs"][0]["finish_reason"],
error=data.get("completion_error", None)
))
except json.JSONDecodeError:
index_entries.append(DatabaseManager.BatchInferenceRecord(
inference_s3_path=inference_s3_path,
pdf_s3_path=pdf_s3_path,
page_num=page_num,
round=data["round"],
start_index=start_index, # Byte offset in the original file
length=line_length, # Length in bytes
finish_reason=data["outputs"][0]["finish_reason"],
error="Could not parse model JSON output",
))
except json.JSONDecodeError:
print(f"Error with JSON Decoding of inference in {inference_s3_path}")
# TODO Maybe this needs to add an index error that this json is bad
except Exception as e:
print(f"Error processing line: {e}")
start_index += line_length # Increment by the number of bytes
return index_entries
def get_pdf_num_pages(s3_path: str) -> Optional[int]:
try:
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
tf.write(get_s3_bytes(pdf_s3, s3_path))
tf.flush()
reader = PdfReader(tf.name)
return reader.get_num_pages()
except Exception as ex:
print(f"Warning, could not add {s3_path} due to {ex}")
return None
def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_round: int) -> list[dict]:
db = DatabaseManager(s3_workspace)
existing_pages = db.get_index_entries(pdf.s3_path)
new_queries = []
# Shortcut out of downloading the actual PDF
if set(page.page_num for page in existing_pages if page.is_usable()) == set(range(1, pdf.num_pages + 1)):
return []
try:
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
tf.flush()
for target_page_num in range(1, pdf.num_pages + 1):
# Is there an existing page that has no error
if any(page.is_usable() and page.page_num == target_page_num for page in existing_pages):
continue
has_errored_previously = sum(page.page_num == target_page_num for page in existing_pages)
if has_errored_previously:
# Retry the page up to 3 times
for _ in range(3):
new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num), "round": cur_round})
# Optionally, you can implement more complex retry logic here
else:
new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num), "round": cur_round})
except Exception as ex:
print(f"Warning, could not get batch inferences lines for {pdf.s3_path} due to {ex}")
return new_queries
def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Optional[dict]:
db = DatabaseManager(s3_workspace)
existing_pages = db.get_index_entries(pdf.s3_path)
document_text = ""
last_page_start_index = 0
pdf_page_spans = []
for target_page_num in range(1, pdf.num_pages + 1):
target_pages = [page for page in existing_pages if page.is_usable() and page.page_num == target_page_num]
if len(target_pages) == 0:
return None
target_page = target_pages[0]
target_row = get_s3_bytes(workspace_s3, target_page.inference_s3_path,
start_index=target_page.start_index,
end_index=target_page.start_index+target_page.length - 1)
target_data = json.loads(target_row.decode("utf-8"))
target_output = json.loads(target_data["outputs"][0]["text"])
if target_output["natural_text"] is not None:
document_text += target_output["natural_text"] + "\n"
pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
last_page_start_index = len(document_text)
metadata = {
"Source-File": pdf.s3_path,
"pdf-total-pages": pdf.num_pages,
}
id_ = hashlib.sha1(document_text.encode()).hexdigest()
dolma_doc = {
"id": id_,
"text": document_text,
"source": "pdelfin",
"added": datetime.datetime.now().strftime("%Y-%m-%d"),
"created": datetime.datetime.now().strftime("%Y-%m-%d"),
"metadata": metadata,
"attributes": {
"pdf_page_numbers": pdf_page_spans
}
}
return dolma_doc
def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
db = DatabaseManager(s3_workspace)
for line in dolma_doc_lines:
db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")
def get_current_round(s3_workspace: str) -> int:
path = s3_workspace[5:]
bucket, _, prefix = path.partition('/')
inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')
paginator = workspace_s3.get_paginator('list_objects_v2')
page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')
round_numbers = []
for page in page_iterator:
for common_prefix in page.get('CommonPrefixes', []):
round_prefix = common_prefix.get('Prefix')
# Extract 'round_X' from the prefix
round_dir = posixpath.basename(posixpath.dirname(round_prefix))
if round_dir.startswith('round_'):
try:
round_num = int(round_dir[len('round_'):])
round_numbers.append(round_num)
except ValueError:
pass
if round_numbers:
current_round = max(round_numbers) + 1
else:
current_round = 0
return current_round
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
parser.add_argument('--add_pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
args = parser.parse_args()
if args.workspace_profile:
workspace_session = boto3.Session(profile_name=args.workspace_profile)
workspace_s3 = workspace_session.client("s3")
if args.pdf_profile:
pdf_session = boto3.Session(profile_name=args.pdf_profile)
pdf_s3 = pdf_session.client("s3")
db = DatabaseManager(args.workspace)
print(f"Loaded db at {db.db_path}")
current_round = get_current_round(args.workspace)
print(f"Current round is {current_round}\n")
# One shared executor to rule them all
executor = ProcessPoolExecutor()
# If you have new PDFs, step one is to add them to the list
if args.add_pdfs:
if args.add_pdfs.startswith("s3://"):
print(f"Querying all PDFs at {args.add_pdfs}")
all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
print(f"Found {len(all_pdfs):,} total pdf paths")
elif os.path.exists(args.add_pdfs):
with open(args.add_pdfs, "r") as f:
all_pdfs = [line.strip() for line in f.readlines() if len(line.strip()) > 0]
else:
raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
print(f"Need to import {len(all_pdfs):,} total new pdf paths")
future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
s3_path = future_to_path[future]
num_pages = future.result()
if num_pages and not db.pdf_exists(s3_path):
db.add_pdf(s3_path, num_pages, "pending")
print("\n")
# Now build an index of all the pages that were processed within the workspace so far
print("Indexing all batch inference sent to this workspace")
inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")
inference_output_paths = [
(s3_path, etag) for s3_path, etag in inference_output_paths.items()
if not db.is_file_processed(s3_path, etag)
]
print(f"Found {len(inference_output_paths):,} new batch inference results to index")
future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
s3_path, etag = future_to_path[future]
try:
inference_records = future.result()
db.add_index_entries(inference_records)
db.update_processed_file(s3_path, etag=etag)
except urllib3.exceptions.SSLError:
print(f"Cannot load inference file {s3_path} due to SSL error, will retry another time")
# Now query each pdf, if you have all of the pages needed (all pages present, error is null and finish_reason is stop), then you assemble it into a dolma document and output it
# If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use
if db.get_last_indexed_round() < current_round - 1:
print(f"WARNING: No new batch inference results found, you need to run batch inference on {args.workspace}/inference_inputs/round_{current_round - 1}")
potentially_done_pdfs = db.get_pdfs_by_status("pending")
else:
print(f"\nCreating batch inference files for new PDFs")
future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf, current_round): pdf for pdf in db.get_pdfs_by_status("pending")}
potentially_done_pdfs = []
lines_written = 0
new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
pdf = future_to_path[future]
inference_lines = future.result()
if len(inference_lines) == 0:
potentially_done_pdfs.append(pdf)
for line in inference_lines:
lines_written += 1
if line is not None:
new_inference_writer.write_line(json.dumps(line))
new_inference_writer.close()
if lines_written > 0:
print(f"Added {lines_written:,} new batch inference requests")
# Now, finally, assemble any potentially done docs into dolma documents
print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb, after_flush=partial(mark_pdfs_done, args.workspace))
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
pdf = future_to_path[future]
dolma_doc = future.result()
if dolma_doc is not None:
new_output_writer.write_line(json.dumps(dolma_doc))
new_output_writer.close()
print("\nFinal statistics:")
# Output the number of PDFs in each status "pending" and "completed"
pending_pdfs = db.get_pdfs_by_status("pending")
completed_pdfs = db.get_pdfs_by_status("completed")
print(f"Pending PDFs: {len(pending_pdfs):,} ({sum(doc.num_pages for doc in pending_pdfs):,} pages)")
print(f"Completed PDFs: {len(completed_pdfs):,} ({sum(doc.num_pages for doc in completed_pdfs):,} pages)")
# For each round, outputs a report of how many pages were processed, how many had errors, and a breakdown by (error, finish_reason)
total_rounds = db.get_last_indexed_round() + 1
for round_num in range(total_rounds):
db.cursor.execute("""
SELECT COUNT(*), error, finish_reason
FROM page_results
WHERE round = ?
GROUP BY error, finish_reason
""", (round_num,))
results = db.cursor.fetchall()
total_pages = sum(count for count, _, _ in results)
print(f"\nInference Round {round_num} - {total_pages:,} pages processed:")
for count, error, finish_reason in results:
error_str = error if error is not None else "None"
print(f" (error: {error_str}, finish_reason: {finish_reason}) -> {count:,} pages")
print("\nWork finished, waiting for all workers to finish cleaning up")
executor.shutdown(wait=True)
db.close()

View File

@ -2,14 +2,42 @@ import subprocess
import base64
import io
from pypdf import PdfReader
from PIL import Image
def render_pdf_to_base64png(local_pdf_path: str, page: int, target_longest_image_dim: int=2048):
pdf = PdfReader(local_pdf_path)
pdf_page = pdf.pages[page - 1]
longest_dim = max(pdf_page.mediabox.width, pdf_page.mediabox.height)
def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[float, float]:
"""
Get the MediaBox dimensions for a specific page in a PDF file using the pdfinfo command.
:param pdf_file: Path to the PDF file
:param page_num: The page number for which to extract MediaBox dimensions
:return: A dictionary containing MediaBox dimensions or None if not found
"""
# Construct the pdfinfo command to extract info for the specific page
command = ['pdfinfo', '-f', str(page_num), '-l', str(page_num), '-box', local_pdf_path]
# Run the command using subprocess
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# Check if there is any error in executing the command
if result.returncode != 0:
raise ValueError(f"Error running pdfinfo: {result.stderr}")
# Parse the output to find MediaBox
output = result.stdout
media_box = None
for line in output.splitlines():
if 'MediaBox' in line:
media_box = line.split(':')[1].strip().split()
media_box = [float(x) for x in media_box]
return abs(media_box[0] - media_box[2]), abs(media_box[3] - media_box[1])
raise ValueError("MediaBox not found in the PDF info.")
def render_pdf_to_base64png(local_pdf_path: str, page_num: int, target_longest_image_dim: int=2048):
longest_dim = max(get_pdf_media_box_width_height(local_pdf_path, page_num))
# Convert PDF page to PNG using pdftoppm
pdftoppm_result = subprocess.run(
@ -17,9 +45,9 @@ def render_pdf_to_base64png(local_pdf_path: str, page: int, target_longest_image
"pdftoppm",
"-png",
"-f",
str(page),
str(page_num),
"-l",
str(page),
str(page_num),
"-r",
str(target_longest_image_dim * 72 / longest_dim), # 72 pixels per point is the conversion factor
local_pdf_path,
@ -35,8 +63,67 @@ def render_pdf_to_base64png(local_pdf_path: str, page: int, target_longest_image
def render_pdf_to_base64webp(local_pdf_path: str, page: int, target_longest_image_dim: int=1024):
base64_png = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim)
png_image = Image.open(io.BytesIO(base64_png.encode("utf-8")))
png_image = Image.open(io.BytesIO(base64.b64decode(base64_png)))
webp_output = io.BytesIO()
png_image.save(webp_output, format="WEBP")
return base64.b64encode(webp_output.getvalue()).decode("utf-8")
return base64.b64encode(webp_output.getvalue()).decode("utf-8")
def get_png_dimensions_from_base64(base64_data) -> tuple[int, int]:
"""
Returns the (width, height) of a PNG image given its base64-encoded data,
without base64-decoding the entire data or loading the PNG itself
Should be really fast to support filtering
Parameters:
- base64_data (str): Base64-encoded PNG image data.
Returns:
- tuple: (width, height) of the image.
Raises:
- ValueError: If the data is not a valid PNG image or the required bytes are not found.
"""
# PNG signature is 8 bytes
png_signature_base64 = base64.b64encode(b'\x89PNG\r\n\x1a\n').decode('ascii')
if not base64_data.startswith(png_signature_base64[:8]):
raise ValueError('Not a valid PNG file')
# Positions in the binary data where width and height are stored
width_start = 16 # Byte position where width starts (0-based indexing)
width_end = 20 # Byte position where width ends (exclusive)
height_start = 20
height_end = 24
# Compute the byte range needed (from width_start to height_end)
start_byte = width_start
end_byte = height_end
# Calculate base64 character positions
# Each group of 3 bytes corresponds to 4 base64 characters
base64_start = (start_byte // 3) * 4
base64_end = ((end_byte + 2) // 3) * 4 # Add 2 to ensure we cover partial groups
# Extract the necessary base64 substring
base64_substring = base64_data[base64_start:base64_end]
# Decode only the necessary bytes
decoded_bytes = base64.b64decode(base64_substring)
# Compute the offset within the decoded bytes
offset = start_byte % 3
# Extract width and height bytes
width_bytes = decoded_bytes[offset:offset+4]
height_bytes = decoded_bytes[offset+4:offset+8]
if len(width_bytes) < 4 or len(height_bytes) < 4:
raise ValueError('Insufficient data to extract dimensions')
# Convert bytes to integers
width = int.from_bytes(width_bytes, 'big')
height = int.from_bytes(height_bytes, 'big')
return width, height

View File

@ -10,9 +10,11 @@
# coherency score best of these three
import subprocess
import re
import random
import ftfy
from dataclasses import dataclass
from typing import Literal, List
from functools import lru_cache
import pypdfium2 as pdfium
import pymupdf
@ -24,7 +26,7 @@ from pypdf.generic import RectangleObject
from pdelfin.prompts._adv_anchor import mult
def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"]) -> str:
def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"], target_length: int=4000) -> str:
assert page > 0, "Pages are 1-indexed in pdf-land"
if pdf_engine == "pdftotext":
@ -52,7 +54,7 @@ def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftote
return best_option
elif pdf_engine == "pdfreport":
return _linearize_pdf_report(_pdf_report(local_pdf_path, page))
return _linearize_pdf_report(_pdf_report(local_pdf_path, page), max_length=target_length)
else:
raise NotImplementedError("Unknown engine")
@ -119,10 +121,14 @@ class PageReport:
text_elements: List[TextElement]
image_elements: List[ImageElement]
@lru_cache(maxsize=5)
def _get_cached_pdf_reader(local_pdf_path: str) -> PdfReader:
# Cached, because you are going to often iterate through a whole pdf, so this will make it a lot faster on subsequent iterations
return PdfReader(local_pdf_path)
def _pdf_report(local_pdf_path: str, page: int) -> PageReport:
reader = PdfReader(local_pdf_path)
page = reader.pages[page - 1]
def _pdf_report(local_pdf_path: str, page_num: int) -> PageReport:
reader = _get_cached_pdf_reader(local_pdf_path)
page = reader.pages[page_num - 1]
resources = page.get("/Resources", {})
xobjects = resources.get("/XObject", {})
text_elements, image_elements = [], []
@ -330,7 +336,10 @@ def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
]
# Sort remaining elements by their positions (e.g., x-coordinate and then y-coordinate)
remaining_elements.sort(key=lambda x: (x[3][0], x[3][1]))
# remaining_elements.sort(key=lambda x: (x[3][0], x[3][1]))
# Shuffle remaining elements randomly
random.shuffle(remaining_elements)
# Add elements until reaching max_length
for elem_type, elem, s, position in remaining_elements:

View File

@ -1,275 +0,0 @@
# The way this script works is it gets a list of pdfs to process
# and an output/scratch folder location either locally or in s3 to work with
# On the first run, with an empty output folder, it will queue up each page in each pdf to go into a VLM
# Then, the user queues up that task in birr, and it outputs to a new subfolder in the same location
# Then, you run your script again, and it will see that you have some valid output files
# If so, then it will check those output files, and if it has a complete document, it will build a dolma doc for it, and that's considered done
# For any remaining pages that got errored out, or failed due to stop_reason not being "stop" (ex. over length)
# Then, it will queue up another set of tasks, hopefully much smaller, to send into batch inference again
# This process will keep going, until you run it with the --fallback option, at which point it will
# just use a basic text extraction on any remaining pages, and assemble the rest of the dolma docs
#
#
#
import os
import glob
import random
import argparse
import boto3
import json
import hashlib
from pypdf import PdfReader
from tqdm import tqdm
from typing import Generator, List
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from urllib.parse import urlparse
from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.filter import PdfFilter
import logging
import smart_open
import posixpath # Import posixpath for S3 path handling
logging.getLogger("pypdf").setLevel(logging.ERROR)
pdf_filter = PdfFilter()
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
return {
"custom_id": f"{pretty_pdf_path}-{page}",
"chat_messages": [
{
"role": "user",
"content": [
{"type": "text", "text": build_finetuning_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
],
}
],
"temperature": 0.1,
"max_tokens": 6000,
}
def fetch_s3_file(s3_url: str, local_path: str) -> str:
parsed = urlparse(s3_url)
bucket_name = parsed.netloc
key = parsed.path.lstrip('/')
s3 = boto3.client('s3')
s3.download_file(bucket_name, key, local_path)
return local_path
def process_pdf(pdf_path: str, no_filter: bool) -> List[dict]:
if pdf_path.startswith("s3://"):
local_pdf_path = os.path.join("/tmp", os.path.basename(pdf_path))
fetch_s3_file(pdf_path, local_pdf_path)
else:
local_pdf_path = pdf_path
if (not no_filter) and pdf_filter.filter_out_pdf(local_pdf_path):
print(f"Skipping {local_pdf_path} due to common filter")
return []
pretty_pdf_path = pdf_path
pdf = PdfReader(local_pdf_path)
num_pages = len(pdf.pages)
sample_pages = list(range(1, num_pages + 1))
result = []
for page in sample_pages:
try:
query = build_page_query(local_pdf_path, pretty_pdf_path, page)
result.append(query)
except Exception as e:
print(f"Error processing page {page} of {pdf_path}: {e}")
return result
def is_glob_pattern(path: str) -> bool:
return any(char in path for char in ['*', '?', '[', ']'])
def expand_s3_glob(s3_glob: str) -> list:
parsed = urlparse(s3_glob)
bucket_name = parsed.netloc
prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
pattern = os.path.basename(parsed.path)
s3 = boto3.client('s3')
paginator = s3.get_paginator('list_objects_v2')
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
matched_files = []
for page in page_iterator:
for obj in page.get('Contents', []):
key = obj['Key']
if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
matched_files.append(f"s3://{bucket_name}/{key}")
return matched_files
def compute_hash(content: str) -> str:
"""Compute a 20-character SHA1 hash of the given content."""
sha1 = hashlib.sha1()
sha1.update(content.encode('utf-8'))
return sha1.hexdigest()[:20]
def get_smart_open_write_path(output_path: str, hash_str: str) -> str:
"""Generate the full output path with hash in the filename."""
parsed = urlparse(output_path)
if parsed.scheme in ('s3', 's3a', 's3n'):
bucket = parsed.netloc
key = parsed.path.lstrip('/')
# Ensure the key is treated as a directory by appending a slash if not present
if key and not key.endswith('/'):
key += '/'
# Use posixpath to correctly join S3 paths
full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
return f"s3://{bucket}/{full_key}"
else:
dir_path = output_path
filename = f"output_{hash_str}.jsonl"
return os.path.join(dir_path, filename)
def main():
parser = argparse.ArgumentParser(
description="Given a bunch of PDFs, prepares a mise/birr workflow to run them through a conversion mechanism"
)
parser.add_argument(
"pdf_paths",
nargs='*',
help=(
"List of PDF paths to process. If a single argument contains glob patterns (e.g., *.pdf or s3://bucket/pdfs/*.pdf), "
"it will be expanded accordingly."
)
)
parser.add_argument(
"--path_list",
type=str,
help="Path to a file containing paths to PDFs, one per line."
)
parser.add_argument(
"--max_size_mb",
type=int,
default=250,
help="Max number of MBs of entries to put in each birr workitem"
)
parser.add_argument(
"--no_filter",
action="store_true",
help="Disables the basic spam/language filtering so that ALL pdfs listed are used"
)
parser.add_argument(
"--output",
type=str,
default="mise_batch_data",
help="Output destination (can be a local path or an S3 URI)"
)
args = parser.parse_args()
pdf_paths = []
# Load PDF paths from positional arguments or path_list
if args.pdf_paths:
for path in args.pdf_paths:
if is_glob_pattern(path):
glob_path = path
if glob_path.startswith("s3://"):
# Handle S3 globbing
expanded_paths = expand_s3_glob(glob_path)
pdf_paths.extend(expanded_paths)
else:
# Handle local filesystem globbing
expanded_paths = glob.glob(glob_path, recursive=True)
pdf_paths.extend(expanded_paths)
else:
pdf_paths.append(path)
if args.path_list:
with open(args.path_list, 'r') as f:
for line in f:
path = line.strip()
if path:
pdf_paths.append(path)
# Remove duplicates and shuffle
pdf_paths = list(set(pdf_paths))
random.shuffle(pdf_paths)
print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")
# Prepare for output
output_dir = args.output
max_file_size = args.max_size_mb * 1024 * 1024 # Convert MB to bytes
# Determine if output is S3
parsed_output = urlparse(output_dir)
is_s3 = parsed_output.scheme in ('s3', 's3a', 's3n')
# Initialize variables for batching
batch = []
batch_size = 0
pdfs_with_output = 0
# Function to write a batch
def write_batch(batch: List[dict]):
nonlocal output_dir
if not batch:
return
batch_content = "\n".join(json.dumps(entry) for entry in batch) + "\n"
hash_str = compute_hash(batch_content)
output_path_with_hash = get_smart_open_write_path(output_dir, hash_str)
with smart_open.open(output_path_with_hash, 'w') as f_out:
f_out.write(batch_content)
print(f"Wrote batch to {output_path_with_hash}")
# Using ProcessPoolExecutor to process files concurrently
with ProcessPoolExecutor() as executor:
futures = []
with tqdm(desc="Processing PDFs", leave=False, total=len(pdf_paths)) as pb:
for pdf_path in pdf_paths:
futures.append(executor.submit(process_pdf, pdf_path, args.no_filter))
for future in as_completed(futures):
try:
request_results = future.result() # Get the result from the process
if request_results:
pdfs_with_output += 1 # Increment if there's at least one result
for request_obj in request_results:
request_json = json.dumps(request_obj)
request_size = len(request_json.encode('utf-8')) + 1 # +1 for newline
# Check if adding this entry would exceed the max size
if batch_size + request_size > max_file_size:
# Write the current batch
write_batch(batch)
# Reset the batch
batch = []
batch_size = 0
# Add the entry to the batch
batch.append(request_obj)
batch_size += request_size
pb.update(1)
except Exception as e:
print(f"Error processing a PDF: {str(e)}")
# Write any remaining batch
write_batch(batch)
# Print the number of PDFs that resulted in at least one output
print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")
print(f"Now you should run these prompts through mise/birr")
if __name__ == "__main__":
main()

75
pdelfin/s3_utils.py Normal file
View File

@ -0,0 +1,75 @@
import os
import glob
import posixpath
from typing import Optional
from urllib.parse import urlparse
def parse_s3_path(s3_path: str) -> tuple[str, str]:
if not s3_path.startswith('s3://'):
raise ValueError('s3_path must start with s3://')
parsed = urlparse(s3_path)
bucket = parsed.netloc
key = parsed.path.lstrip('/')
return bucket, key
def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:
parsed = urlparse(s3_glob)
bucket_name = parsed.netloc
prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
pattern = os.path.basename(parsed.path)
paginator = s3_client.get_paginator('list_objects_v2')
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
matched_files = {}
for page in page_iterator:
for obj in page.get('Contents', []):
key = obj['Key']
if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
matched_files[f"s3://{bucket_name}/{key}"] = obj['ETag'].strip('"')
return matched_files
def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
bucket, key = parse_s3_path(s3_path)
# Build the range header if start_index and/or end_index are specified
range_header = None
if start_index is not None and end_index is not None:
# Range: bytes=start_index-end_index
range_value = f"bytes={start_index}-{end_index}"
range_header = {'Range': range_value}
elif start_index is not None and end_index is None:
# Range: bytes=start_index-
range_value = f"bytes={start_index}-"
range_header = {'Range': range_value}
elif start_index is None and end_index is not None:
# Range: bytes=-end_index (last end_index bytes)
range_value = f"bytes=-{end_index}"
range_header = {'Range': range_value}
if range_header:
obj = s3_client.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
else:
obj = s3_client.get_object(Bucket=bucket, Key=key)
return obj['Body'].read()
def put_s3_bytes(s3_client, s3_path: str, data: bytes):
bucket, key = parse_s3_path(s3_path)
s3_client.put_object(
Bucket=bucket,
Key=key,
Body=data,
ContentType='text/plain; charset=utf-8'
)
def parse_custom_id(custom_id: str) -> tuple[str, int]:
s3_path = custom_id[:custom_id.rindex("-")]
page_num = int(custom_id[custom_id.rindex("-") + 1:])
return s3_path, page_num

View File

@ -7,43 +7,33 @@ wandb:
project: pdelfin
entity: ai2-llm
# TODO This is not used
format:
instruction_template: "Original:"
response_template: "Rewritten:"
# Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30
chat_template: |
{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content']}}
{% if loop.last %}
{{ '<|im_end|>'}}
{% else %}
{{ '<|im_end|>\n' }}
{% endif %}
{% endfor %}
generate:
max_length: 4096
max_length: 8192
train_data:
seed: 1337
sources:
# These tend to be really big, so it's only practical to host them as parquets on weka, otherwise you may OOM or just never finish dataloading
- name: openai_batch_data_v5_1_train
parquet_path: /data/jakep/pdfdata/openai_batch_data_v5_1_parquet/*.parquet
- name: openai_batch_data_v5_1_train
parquet_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_parquet/*.parquet
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
- name: openai_batch_data_v5_1_iabooks_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
valid_data:
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
# These tend to be small, so you can load from s3 it's no big deal
- name: openai_batch_data_v5_1_eval
query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
- name: openai_batch_data_v5_1_iabooks_eval
query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_eval/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000

View File

@ -22,28 +22,6 @@ class ModelConfig:
model_revision: Optional[str] = field(help="The model revision to use for the model.", default=None)
@dataclass
class FormatConfig:
"""Configuration for formatting the text that is input to the model."""
new_line_symbol: str = field(
help="The symbol to use for new lines in the text; default is '\\n'.",
default="\n",
)
system_message: Optional[str] = field(
help="The system message to use for formatting the text; default is no system message.",
default=None,
)
instruction_template: str = field(
help="The template to use for formatting the input text", default="Original:"
)
response_template: str = field(help="The template to use for formatting the output text", default="Rewrite:")
chat_template: Optional[str] = field(
help="The template to use for formatting the chat text. If None, the default chat template will be used.",
default=None,
)
@dataclass
class GenerateConfig:
max_length: int = field(help="The maximum length of the generated text", default=4096)
@ -75,9 +53,9 @@ class AwsConfig:
@dataclass
class SourceConfig:
name: str = field(help="The name of the source")
parquet_path: Optional[str] = field(help="The s3/glob path to a bunch of parquet files for a preprocessed dataset.", default=None)
query_glob_path: Optional[str] = field(help="The s3 bucket pointing to the inputs sent to OpenAI to generate the silver data", default=None)
response_glob_path: Optional[str] = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai", default=None)
response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")
target_longest_image_dim: int = field(help="Dimensions to render the pdf page image to")
target_anchor_text_len: int = field(help="Maximum amount of anchor text (aka prompt hint)")
@dataclass
@ -141,7 +119,6 @@ class TrainConfig:
lora: Optional[LoraConfig] = field(default=None, help="The LoRA configuration")
aws: AwsConfig = field(default=AwsConfig(), help="Configuration for AWS S3")
wandb: WandbConfig = field(default=WandbConfig(), help="Configuration for Weights and Biases")
format: FormatConfig = field(default=FormatConfig(), help="Configuration for formatting the input/output text")
train_data: DataConfig = field(default=DataConfig(), help="Configuration for the training data")
valid_data: DataConfig = field(default=DataConfig(), help="Configuration for the validation data")
generate: GenerateConfig = field(default=GenerateConfig(), help="Configuration for text generation")
@ -158,5 +135,4 @@ class DemoConfig:
share: bool = field(default=False, help="Share the demo publicly.")
model: ModelConfig = field(default=ModelConfig())
format: FormatConfig = field(default=FormatConfig())
generate: GenerateConfig = field(default=GenerateConfig())

View File

@ -5,21 +5,27 @@ import re
import os
import base64
import glob
import pypdf, pypdf.errors
from functools import partial
from typing import Any, Dict, Optional
from logging import Logger
from filelock import FileLock
import boto3
from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
from .core.config import DataConfig, SourceConfig
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.s3_utils import parse_custom_id, get_s3_bytes, parse_s3_path
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Quiet logs from pypdf and smart open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)
def list_dataset_files(s3_glob_path: str):
"""
@ -67,142 +73,12 @@ def load_jsonl_into_ds(s3_glob_path: str, first_n_files: int = None) -> Dataset:
return dataset
def get_png_dimensions_from_base64(base64_data) -> tuple[int, int]:
"""
Returns the (width, height) of a PNG image given its base64-encoded data,
without base64-decoding the entire data or loading the PNG itself
Should be really fast to support filtering
Parameters:
- base64_data (str): Base64-encoded PNG image data.
Returns:
- tuple: (width, height) of the image.
Raises:
- ValueError: If the data is not a valid PNG image or the required bytes are not found.
"""
# PNG signature is 8 bytes
png_signature_base64 = base64.b64encode(b'\x89PNG\r\n\x1a\n').decode('ascii')
if not base64_data.startswith(png_signature_base64[:8]):
raise ValueError('Not a valid PNG file')
# Positions in the binary data where width and height are stored
width_start = 16 # Byte position where width starts (0-based indexing)
width_end = 20 # Byte position where width ends (exclusive)
height_start = 20
height_end = 24
# Compute the byte range needed (from width_start to height_end)
start_byte = width_start
end_byte = height_end
# Calculate base64 character positions
# Each group of 3 bytes corresponds to 4 base64 characters
base64_start = (start_byte // 3) * 4
base64_end = ((end_byte + 2) // 3) * 4 # Add 2 to ensure we cover partial groups
# Extract the necessary base64 substring
base64_substring = base64_data[base64_start:base64_end]
# Decode only the necessary bytes
decoded_bytes = base64.b64decode(base64_substring)
# Compute the offset within the decoded bytes
offset = start_byte % 3
# Extract width and height bytes
width_bytes = decoded_bytes[offset:offset+4]
height_bytes = decoded_bytes[offset+4:offset+8]
if len(width_bytes) < 4 or len(height_bytes) < 4:
raise ValueError('Insufficient data to extract dimensions')
# Convert bytes to integers
width = int.from_bytes(width_bytes, 'big')
height = int.from_bytes(height_bytes, 'big')
return width, height
def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
"""
Extracts necessary fields from a query entry passed to openai's batch API for vision LMs
"""
custom_id = query.get("custom_id", "")
body = query.get("body", {})
messages = body.get("messages", [])
input_prompt_text = ""
input_prompt_image_base64 = ""
for message in messages:
if message.get("role") != "user":
continue # We are only interested in user messages
contents = message.get("content", [])
for content_item in contents:
if content_item.get("type") == "text":
input_prompt_text = content_item.get("text", "")
elif content_item.get("type") == "image_url":
image_url = content_item.get("image_url", {}).get("url", "")
if image_url.startswith("data:image"):
# Extract base64 part from data URL
try:
base64_data = image_url.split(",", 1)[1]
input_prompt_image_base64 = base64_data
except IndexError:
input_prompt_image_base64 = ""
# This code builds the finetuning prompt from the original openai prompt by extracting the "pdf_report hint anchor text" from that
# and reusing it
# # At this point, the input_prompt_text is the raw text that was passed to the OpenAI model
# # to generate our silver data. But, we want to have a simplfied prompt for this here fine tune,
# # so we're going to extract out just the raw extracted prompt text
# pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
# # Use re.DOTALL to ensure that the dot matches newline characters
# match = re.search(pattern, input_prompt_text, re.DOTALL)
# if match:
# raw_page_text = match.group(1).strip()
# else:
# raw_page_text = ""
# This code builds the finetuning prompt by redownloading the PDF and extracting it's report one more time
s3_path = custom_id[:custom_id.rindex("-")]
page_num = int(custom_id[custom_id.rindex("-") + 1:])
s3_client = boto3.client(
's3',
aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY')
)
# Split the s3_path into bucket and key
bucket_name = s3_path.split('s3://')[1].split('/')[0]
s3_key = '/'.join(s3_path.split('s3://')[1].split('/')[1:])
with tempfile.NamedTemporaryFile(delete=False) as tf:
s3_client.download_fileobj(bucket_name, s3_key, tf)
raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
return {
"custom_id": custom_id,
"input_prompt_text": input_prompt_text,
"input_prompt_image_base64": input_prompt_image_base64,
"raw_page_text": raw_page_text,
}
def extract_openai_batch_response(example):
custom_id = example.get("custom_id", None)
# Parse the custom id into an s3 document path and page number (1indexed)
s3_path, page_num = parse_custom_id(custom_id)
response_body = example.get("response", {}).get("body", {})
choices = response_body.get("choices", [])
response = ""
@ -213,58 +89,71 @@ def extract_openai_batch_response(example):
response = message.get("content", "")
finish_reason = first_choice.get("finish_reason", "")
return {"custom_id": custom_id, "response": response, "finish_reason": finish_reason}
# TODO Maybe in the future we can parse the response (which is a structured JSON document itself)
# into its own columns
return {"s3_path": s3_path, "page_num": page_num, "response": response, "finish_reason": finish_reason}
def merge_query_response(query_example, response_data: Dataset, response_map: dict[str, int]):
custom_id = query_example["custom_id"]
def _cache_s3_file(s3_path: str, local_cache_dir: str):
"""
Downloads an S3 object to a local cache directory, ensuring no two writers corrupt the same file.
"""
bucket, key = parse_s3_path(s3_path)
if custom_id not in response_map:
return {
"response": None,
"finish_reason": None,
}
# Define the local file path
local_file_path = os.path.join(local_cache_dir, key.replace("/", "_"))
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
lock_file = f"{local_file_path}.lock"
response_row = response_data[response_map[custom_id]]
# Use a file lock to prevent concurrent writes
with FileLock(lock_file):
if not os.path.exists(local_file_path):
logger.info(f"Downloading {s3_path} to {local_file_path}")
s3_client = boto3.client(
's3',
aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY')
)
s3_client.download_file(bucket, key, local_file_path)
else:
logger.info(f"File {local_file_path} already exists, skipping download.")
return {"response": response_row["response"], "finish_reason": response_row["finish_reason"]}
return local_file_path
def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32) -> Dataset:
"""
Caches all S3 paths in the dataset to the local cache directory.
"""
# Define the download function to use in parallel processing
def cache_file(example):
s3_path = example["s3_path"]
if s3_path:
# Download the file and cache it locally
local_path = _cache_s3_file(s3_path, pdf_cache_location)
return {"local_pdf_path": local_path}
return {"local_pdf_path": None}
def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str, num_proc: int=32) -> Dataset:
logger.info("Loading query and response datasets")
query_data = load_jsonl_into_ds(query_glob_path)
# Map the caching function to the dataset (with parallelism if needed)
dataset = dataset.map(cache_file, num_proc=num_proc, load_from_cache_file=False)
return dataset
def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str]=None, num_proc: int=32) -> Dataset:
if pdf_cache_location is None:
pdf_cache_location = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin_pdfs')
logger.info("Loading fine tuning dataset from OpenAI style batch responses")
response_data = load_jsonl_into_ds(response_glob_path)
# Map the datasets down to the core fields that we're going to need to make them easier to process
logger.info("Mapping query data")
query_data = query_data["train"]
query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc)
logger.info("Mapping response data")
response_data = response_data["train"]
response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc)
# What we're going to do, is build an in-memory map for the response data from custom_id to row
# This will let us do quick lookups when we do a merge step, but it will not scale past a certain point
logger.info("Building custom_id to row map")
custom_id_to_response_row = {}
for row_id, entry in enumerate(response_data):
custom_id_to_response_row[entry["custom_id"]] = row_id
logger.info("Running merge map")
final_dataset = query_data.map(
partial(merge_query_response, response_data=response_data, response_map=custom_id_to_response_row),
num_proc=num_proc
)
# Don't include data where the model cut off due to a length issue, or moderation issue
final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)
logger.info("Filtering on finish_reason == stop")
final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)
# Pick things that have a reasonable image size only
def pick_image_sizes(x):
width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"])
return 1800 <= max(width, height) <= 2200
final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)
# Cache all the s3_paths that were accessed to a local storage location,
final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc)
return final_dataset

View File

@ -4,19 +4,15 @@ from PIL import Image
import base64
import torch # Make sure to import torch as it's used in the DataCollator
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.prompts import build_finetuning_prompt
def filter_by_max_seq_len(example, processor, max_prompt_len: int=2200, max_response_len: int=2200):
if len(processor.tokenizer.tokenize(example["input_prompt_text"])) > max_prompt_len:
return False
if len(processor.tokenizer.tokenize(example["response"])) > max_response_len:
return False
return True
from pdelfin.data.renderpdf import render_pdf_to_base64png
def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: int, target_anchor_text_len: int):
anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)
# Prepare messages
messages = [
{
@ -24,9 +20,9 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
"content": [
{
"type": "image",
"image": example["input_prompt_image_base64"] # Placeholder
"image": base64_page_image
},
{"type": "text", "text": build_finetuning_prompt(example["raw_page_text"])},
{"type": "text", "text": build_finetuning_prompt(anchor_text)},
],
}
]
@ -36,14 +32,7 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
)
# Decode image from base64
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
# Right now, we are going to downsample to 1024 on the longest dimension, because
# 2048 as we passed to OpenAI is too large for training
width, height = main_image.size
assert 1800 <= max(width, height) <= 2200, f"Image size {width}x{height} invalid"
main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))
# Process inputs using processor
inputs = processor(
@ -84,36 +73,30 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
labels_full = np.full_like(input_ids, fill_value=-100)
labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]
# TODO Maybe cap the max length
# Return as dict, including pixel_values
if add_batch_dim:
return {
"input_ids": input_ids[np.newaxis, ...],
"attention_mask": attention_mask[np.newaxis, ...],
"labels": labels_full[np.newaxis, ...],
"pixel_values": inputs.pixel_values[np.newaxis, ...],
"image_grid_thw": inputs["image_grid_thw"]
}
else:
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels_full,
"pixel_values": inputs.pixel_values,
"image_grid_thw": inputs["image_grid_thw"][0]
}
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels_full,
"pixel_values": inputs.pixel_values,
"image_grid_thw": inputs["image_grid_thw"][0]
}
def batch_prepare_data_for_qwen2_training(batch, processor):
def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: int, target_anchor_text_len: int):
# Process each example in the batch using the helper function
processed_examples = []
for i in range(len(batch["input_prompt_image_base64"])):
for i in range(len(batch["response"])):
example = {
"input_prompt_image_base64": batch["input_prompt_image_base64"][i],
"input_prompt_text": batch["input_prompt_text"][i],
"raw_page_text": batch["raw_page_text"][i],
"local_pdf_path": batch["local_pdf_path"][i],
"page_num": batch["page_num"][i],
"response": batch["response"][i]
}
processed_example = prepare_data_for_qwen2_training(example, processor)
processed_example = prepare_data_for_qwen2_training(example, processor,
target_longest_image_dim=target_longest_image_dim,
target_anchor_text_len=target_anchor_text_len)
processed_examples.append(processed_example)
return {
@ -124,96 +107,3 @@ def batch_prepare_data_for_qwen2_training(batch, processor):
"image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
}
def prepare_data_for_qwen2_inference(example, processor):
# Prepare messages
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": example["input_prompt_image_base64"] # Placeholder
},
{"type": "text", "text": example["input_prompt_text"]},
],
}
]
# Apply chat template to get the text
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Decode image from base64
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
# Right now, we are going to downsample to 1024 on the longest dimension, because
# 2048 as we passed to OpenAI is too large for training
width, height = main_image.size
if 1800 <= max(width, height) <= 2200:
main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
# Process inputs using processor
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="np",
)
input_ids = inputs["input_ids"][0]
# All columns will participate in attention fully
attention_mask = np.ones_like(input_ids)
# Return as dict, including pixel_values
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": inputs.pixel_values,
"image_grid_thw": inputs["image_grid_thw"][0]
}
def batch_prepare_data_for_qwen2_inference(batch, processor):
# Process each example in the batch using the helper function
processed_examples = []
for i in range(len(batch["input_prompt_image_base64"])):
example = {
"input_prompt_image_base64": batch["input_prompt_image_base64"][i],
"input_prompt_text": batch["input_prompt_text"][i],
"raw_page_text": batch["raw_page_text"][i],
}
processed_example = prepare_data_for_qwen2_inference(example, processor)
processed_examples.append(processed_example)
return {
"input_ids": [x["input_ids"] for x in processed_examples],
"attention_mask": [x["attention_mask"] for x in processed_examples],
"pixel_values": [x["pixel_values"] for x in processed_examples],
"image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
}
# Define a custom data collator
class DataCollatorForVisionLanguageModeling:
def __init__(self, processor):
self.processor = processor
def __call__(self, features):
input_ids = [f['input_ids'] for f in features]
attention_mask = [f['attention_mask'] for f in features]
labels = [f['labels'] for f in features]
pixel_values = [f['pixel_values'] for f in features]
# Pad input_ids, attention_mask, labels
batch = self.processor.pad(
{"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels},
return_tensors="pt",
padding=True,
)
# Stack pixel_values
batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])
return batch

View File

@ -0,0 +1,120 @@
import os
import json
import html
import argparse
import boto3
import tempfile
from botocore.exceptions import NoCredentialsError, PartialCredentialsError
from jinja2 import Template
import smart_open
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import markdown2
from pdelfin.s3_utils import get_s3_bytes
from pdelfin.data.renderpdf import render_pdf_to_base64webp
def read_jsonl(path):
with smart_open.smart_open(path, 'r', encoding='utf-8') as f:
for line in f:
yield line.strip()
def parse_s3_path(path):
# s3://bucket_name/key_name
path = path[5:] # Remove 's3://'
bucket_name, key_name = path.split('/', 1)
return bucket_name, key_name
def generate_presigned_url(s3_client, bucket_name, key_name):
try:
response = s3_client.generate_presigned_url('get_object',
Params={'Bucket': bucket_name, 'Key': key_name},
ExpiresIn=3600) # Link expires in 1 hour
return response
except (NoCredentialsError, PartialCredentialsError):
print("Error: AWS credentials not found or incomplete.")
return None
def process_document(data, s3_client, template, output_dir):
id_ = data.get('id')
text = data.get('text', '')
attributes = data.get('attributes', {})
pdf_page_numbers = attributes.get('pdf_page_numbers', [])
metadata = data.get('metadata', {})
source_file = metadata.get('Source-File')
# Generate base64 image of the corresponding PDF page
if source_file and source_file.startswith('s3://'):
local_pdf = tempfile.NamedTemporaryFile("wb+", suffix=".pdf")
local_pdf.write(get_s3_bytes(s3_client, source_file))
local_pdf.flush()
else:
raise ValueError("Expecting s3 files only")
pages = []
for span in pdf_page_numbers:
start_index, end_index, page_num = span
page_text = text[start_index:end_index]
# Detect and convert Markdown to HTML
page_text = html.escape(page_text, quote=True).replace('&lt;br&gt;', '<br>')
page_text = markdown2.markdown(page_text, extras=["tables"])
base64_image = render_pdf_to_base64webp(local_pdf.name, page_num)
pages.append({'page_num': page_num, 'text': page_text, 'image': base64_image})
local_pdf.close()
# Generate pre-signed URL if source_file is an S3 path
s3_link = None
if source_file and source_file.startswith('s3://'):
bucket_name, key_name = parse_s3_path(source_file)
s3_link = generate_presigned_url(s3_client, bucket_name, key_name)
# Render the HTML using the Jinja template
html_content = template.render(id=id_, pages=pages, s3_link=s3_link)
# Write the HTML content to a file
filename = f'{id_}.html'
filepath = os.path.join(output_dir, filename)
with open(filepath, 'w', encoding='utf-8') as f:
f.write(html_content)
def main(jsonl_path, output_dir, template_path):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Load the Jinja template
with open(os.path.join(os.path.dirname(__file__), template_path), 'r', encoding='utf-8') as template_file:
template_content = template_file.read()
template = Template(template_content)
# Initialize S3 client for generating presigned URLs
workspace_session = boto3.Session(profile_name="s2")
s3_client = workspace_session.client("s3")
# Create ThreadPoolExecutor
with ThreadPoolExecutor() as executor:
futures = []
for line in read_jsonl(jsonl_path):
if not line:
continue
data = json.loads(line)
future = executor.submit(process_document, data, s3_client, template, output_dir)
futures.append(future)
for future in tqdm(as_completed(futures), total=len(futures)):
try:
future.result()
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Generate HTML pages from a JSONL file with pre-signed S3 links.')
parser.add_argument('jsonl_path', help='Path to the JSONL file (local or s3://)')
parser.add_argument('--output_dir', default='dolma_previews', help='Directory to save HTML files')
parser.add_argument('--template_path', default='dolmaviewer_template.html', help='Path to the Jinja2 template file')
args = parser.parse_args()
main(args.jsonl_path, args.output_dir, args.template_path)

View File

@ -0,0 +1,128 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>{{ id }}</title>
<style>
/* CSS styles */
body {
font-family: Arial, sans-serif;
background-color: #f0f0f0;
margin: 0;
padding: 0;
display: flex;
justify-content: center;
}
.document {
background-color: #fff;
padding: 40px;
margin: 20px;
width: 60%;
box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1);
line-height: 1.8;
}
.page-section {
display: flex;
flex-direction: row;
margin-bottom: 20px;
transition: background-color 0.3s ease;
}
.page-section:hover {
background-color: #f5f5f5;
}
.page-section .text {
flex: 2;
padding: 10px;
text-align: justify;
}
.page-section .image {
flex: 1;
padding: 10px;
}
.page-section img {
max-width: 100%;
height: auto;
border: 1px solid #ccc;
}
table {
width: 100%;
border-collapse: collapse; /* Ensures that borders are collapsed to give a clean look */
margin-bottom: 1.5em; /* Adds some space below the table */
}
th, td {
border: 1px solid #ddd; /* 1px solid border for table cells */
padding: 12px 15px; /* Adds some padding for better spacing inside the cells */
text-align: left; /* Aligns text to the left */
vertical-align: top; /* Aligns content to the top of the cell */
font-size: 14px; /* Adjusts font size for readability */
}
th {
background-color: #f4f4f4; /* Light background for table headers */
font-weight: bold; /* Bolds header text */
text-transform: uppercase; /* Makes header text uppercase */
letter-spacing: 0.05em; /* Adds slight spacing between letters for readability */
border-bottom: 2px solid #ccc; /* Slightly thicker bottom border for headers */
}
tr:nth-child(even) {
background-color: #f9f9f9; /* Alternates row background color */
}
tr:hover {
background-color: #f1f1f1; /* Highlights row on hover for better interaction */
}
td img {
max-width: 100%; /* Ensures any images in table cells scale properly */
height: auto;
display: block;
}
table caption {
caption-side: bottom; /* Position caption at the bottom of the table */
text-align: right;
font-size: 12px;
color: #777;
padding: 5px 0;
}
</style>
<script type="text/javascript">
window.MathJax = {
tex: {
inlineMath: [['$', '$'], ['\\(', '\\)']],
displayMath: [['$$', '$$'], ['\\[', '\\]']]
},
options: {
skipHtmlTags: ['script', 'noscript', 'style', 'textarea', 'pre'],
processHtmlClass: 'mathjax-process' // Class name for areas where LaTeX should be processed
}
};
</script>
<script type="text/javascript" id="MathJax-script" async
src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js">
</script>
</head>
<body>
<div class="document">
{% for page in pages %}
<div class="page-section" id="page-{{ page.page_num }}">
<div class="text">{{ page.text|safe }}</div>
{% if page.image %}
<div class="image">
<a href="{{ s3_link }}#page={{ page.page_num }}" target="_blank">
<img src="data:image/webp;base64,{{ page.image }}" alt="Page {{ page.page_num }} Image">
</a>
</div>
{% endif %}
</div>
{% endfor %}
</div>
</body>
</html>

View File

@ -28,7 +28,8 @@ dependencies = [
"Pillow",
"ftfy",
"bleach",
"duckdb",
"markdown2",
"filelock",
]
license = {file = "LICENSE"}

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@ -2,11 +2,12 @@ import unittest
import os
import json
import io
import glob
from pypdf import PdfReader
from pdelfin.prompts.anchor import _pdf_report, _linearize_pdf_report, get_anchor_text
from pdelfin.data.renderpdf import get_pdf_media_box_width_height
class AnchorTest(unittest.TestCase):
def testExtractText(self):
@ -102,7 +103,14 @@ class AnchorTest(unittest.TestCase):
print(len(anchor_text))
self.assertLess(len(anchor_text), 4000)
def testTobaccoPaperMissingParagraphs(self):
local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "tobacco_missed_tokens_pg1.pdf")
anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport")
print(anchor_text)
print(len(anchor_text))
self.assertLess(len(anchor_text), 4000)
class BuildSilverTest(unittest.TestCase):
@ -124,4 +132,17 @@ class BuildSilverTest(unittest.TestCase):
print(width, height)
assert max(width, height) == 2048
assert max(width, height) == 2048
class TestRenderPdf(unittest.TestCase):
def testFastMediaBoxMatchesPyPdf(self):
for file in glob.glob(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "*.pdf")):
reader = PdfReader(file)
print("checking", file)
for page_num in range(1, len(reader.pages) + 1):
w1, h1 = get_pdf_media_box_width_height(file, page_num)
pypdfpage = reader.pages[page_num - 1]
self.assertEqual(w1, pypdfpage.mediabox.width)
self.assertEqual(h1, pypdfpage.mediabox.height)

View File

@ -6,14 +6,13 @@ from functools import partial
from transformers import AutoProcessor
from pdelfin.train.dataloader import (
build_batch_query_response_vision_dataset,
extract_openai_batch_query,
build_finetuning_dataset,
extract_openai_batch_response,
load_jsonl_into_ds,
list_dataset_files
)
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
class TestBatchQueryResponseDataset(unittest.TestCase):
@ -24,52 +23,35 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
print(ds)
print(ds["train"])
def testCombinedQueryResponse(self):
ds = build_batch_query_response_vision_dataset(
query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
def testFinetuningDS(self):
ds = build_finetuning_dataset(
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
)
print(ds)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
from pdelfin.train.dataprep import filter_by_max_seq_len
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))
print(ds[0])
def testLocalDS(self):
ds = build_batch_query_response_vision_dataset(
query_glob_path="/root/openai_batch_data_v5_1_train/*.jsonl",
response_glob_path="/root/openai_batch_data_v5_1_train_done/*.json",
)
print(ds)
ds.to_parquet("/root/trainds_parquet/bigds.parquet")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
from pdelfin.train.dataprep import filter_by_max_seq_len
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))
ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
print(ds[0])
def testPlotSequenceLengthHistogram(self):
import plotly.express as px
ds = build_batch_query_response_vision_dataset(
query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
ds = build_finetuning_dataset(
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
initial_len = len(ds)
from pdelfin.train.dataprep import filter_by_max_seq_len
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=2200, max_response_len=2200))
formatted_dataset = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=30, shuffle=False)
train_dataloader = DataLoader(ds, batch_size=1, num_workers=30, shuffle=False)
max_seen_len = 0
steps = 0
@ -98,43 +80,3 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
)
fig.write_image("sequence_lengths_histogram.png")
def testExtractBatch(self):
query_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl", first_n_files=3)
query_data = query_data["train"]
query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)
print(query_data)
print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
def testExtractResponse(self):
response_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", first_n_files=3)
response_data = response_data["train"]
response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names)
print(response_data)
print(response_data[0])
def testPyArrowDirectJson(self):
query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl"
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json"
all_files = list_dataset_files(query_glob_path)
import pyarrow as pa
import pyarrow.json as paj
import pyarrow.compute as pc
import pyarrow.fs as fs
s3 = fs.S3FileSystem()
block_size = 200 * 1024**2
for file in all_files:
with s3.open_input_stream(file.replace("s3://", "")) as f:
table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size))
print(f"file {file} rows {table.num_rows}")
print(table.schema)