Saving extra metadata that will be useful for finetuning

Jake Poznanski 2025-08-04 20:01:30 +00:00
parent 7c098955a9
commit d6591c04a1
2 changed files with 413 additions and 1 deletion


@@ -449,7 +449,14 @@ def build_dolma_document(pdf_orig_path, page_results):
         "added": datetime.datetime.now().strftime("%Y-%m-%d"),
         "created": datetime.datetime.now().strftime("%Y-%m-%d"),
         "metadata": metadata,
-        "attributes": {"pdf_page_numbers": pdf_page_spans},
+        "attributes": {
+            "pdf_page_numbers": pdf_page_spans,
+            "primary_language": [p.response.primary_language for p in page_results],
+            "is_rotation_valid": [p.response.is_rotation_valid for p in page_results],
+            "rotation_correction": [p.response.rotation_correction for p in page_results],
+            "is_table": [p.response.is_table for p in page_results],
+            "is_diagram": [p.response.is_diagram for p in page_results],
+        },
     }
     return dolma_doc
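For reference, a minimal sketch of the attributes block a two-page document would now carry, with one list entry per page in page order (all values below are illustrative, not taken from this commit):

attributes = {
    "pdf_page_numbers": [[0, 1532, 1], [1532, 2998, 2]],  # [start_char, end_char, page_num] per page
    "primary_language": ["en", "en"],
    "is_rotation_valid": [True, True],
    "rotation_correction": [0, 0],
    "is_table": [False, True],
    "is_diagram": [False, False],
}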

olmocr/train/prepare_workspace.py (Executable file, 405 additions)

@@ -0,0 +1,405 @@
#!/usr/bin/env python3
"""
Prepare a workspace generated by olmocr/pipeline.py for fine-tuning.
This script reads JSONL files from the workspace's results directory (local or S3),
extracts each referenced page from the source PDFs using the recorded page boundaries,
and writes a single-page PDF plus a corresponding markdown file for every page.
Usage:
python prepare_workspace.py workspace_path output_dir [--max-examples N]
"""
import argparse
import json
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse
import boto3
from pypdf import PdfReader, PdfWriter
from tqdm import tqdm
from olmocr.s3_utils import parse_s3_path
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def fetch_s3_file(s3_url: str, local_path: str) -> str:
"""Download a file from an S3 URI (s3://bucket/key) to local_path."""
parsed = urlparse(s3_url)
bucket_name = parsed.netloc
key = parsed.path.lstrip("/")
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(local_path), exist_ok=True)
s3 = boto3.client("s3")
s3.download_file(bucket_name, key, local_path)
return local_path
def list_s3_result_files(s3_client, workspace_path: str) -> List[str]:
"""List all JSONL files in the S3 workspace results directory."""
bucket, prefix = parse_s3_path(workspace_path)
results_prefix = os.path.join(prefix, "results").rstrip("/") + "/"
all_files = []
paginator = s3_client.get_paginator("list_objects_v2")
for page in paginator.paginate(Bucket=bucket, Prefix=results_prefix):
if "Contents" in page:
all_files.extend([
f"s3://{bucket}/{obj['Key']}"
for obj in page["Contents"]
if obj["Key"].endswith(".jsonl")
])
logger.info(f"Found {len(all_files)} JSONL files in S3 workspace")
return all_files
def download_s3_file(s3_client, s3_path: str) -> str:
"""Download an S3 file and return its contents as a string."""
bucket, key = parse_s3_path(s3_path)
response = s3_client.get_object(Bucket=bucket, Key=key)
return response['Body'].read().decode('utf-8')
def load_jsonl_files(results_dir: Path) -> List[Path]:
"""Load all JSONL files from the workspace results directory."""
jsonl_files = list(results_dir.glob("*.jsonl"))
if not jsonl_files:
logger.error(f"No JSONL files found in {results_dir}")
return []
logger.info(f"Found {len(jsonl_files)} JSONL files in {results_dir}")
return jsonl_files
def parse_jsonl_entry(entry: Dict) -> Optional[Dict]:
"""Parse a single JSONL entry and extract relevant information."""
try:
text = entry.get("text", "")
metadata = entry.get("metadata", {})
attributes = entry.get("attributes", {})
source_file = metadata.get("Source-File", "")
if not source_file:
logger.warning("Entry missing Source-File in metadata")
return None
pdf_page_numbers = attributes.get("pdf_page_numbers", [])
if not pdf_page_numbers:
logger.warning(f"Entry for {source_file} missing pdf_page_numbers")
return None
return {
"id": entry.get("id", ""),
"text": text,
"source_file": source_file,
"metadata": metadata,
"pdf_page_numbers": pdf_page_numbers
}
except Exception as e:
logger.error(f"Error parsing JSONL entry: {e}")
return None
def extract_page_text(text: str, page_boundaries: List[List[int]]) -> Dict[int, str]:
"""
Extract text for each page based on character boundaries.
Args:
text: Full document text
page_boundaries: List of [start_char, end_char, page_num] for each page
Returns:
Dictionary mapping page number to extracted text
"""
page_texts = {}
for start_char, end_char, page_num in page_boundaries:
page_text = text[start_char:end_char]
page_texts[page_num] = page_text
return page_texts
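# Illustrative example (hypothetical values): given text of 2,998 characters and
# page_boundaries=[[0, 1532, 1], [1532, 2998, 2]], this returns
# {1: text[0:1532], 2: text[1532:2998]}, i.e. one text slice per 1-based page number.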
def extract_pdf_page(pdf_path: str, page_num: int, output_path: str) -> bool:
"""
Extract a single page from a PDF and save it to output_path.
Args:
pdf_path: Path to the source PDF
page_num: 1-based page number to extract
output_path: Path where the single-page PDF will be saved
Returns:
True if successful, False otherwise
"""
try:
reader = PdfReader(pdf_path)
# Check if page number is valid
if page_num < 1 or page_num > len(reader.pages):
logger.error(f"Page {page_num} out of range for {pdf_path} (has {len(reader.pages)} pages)")
return False
writer = PdfWriter()
# PyPDF uses 0-based indexing
writer.add_page(reader.pages[page_num - 1])
# Create output directory if it doesn't exist
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "wb") as f:
writer.write(f)
return True
except Exception as e:
logger.error(f"Error extracting page {page_num} from {pdf_path}: {e}")
return False
def process_document(entry_data: Dict, output_dir: Path, cache_dir: Path) -> Tuple[int, int]:
"""
Process a single document: extract pages and create markdown files.
Returns:
Tuple of (successful_pages, failed_pages)
"""
successful = 0
failed = 0
source_file = entry_data["source_file"]
doc_id = entry_data["id"]
full_text = entry_data["text"]
pdf_page_numbers = entry_data["pdf_page_numbers"]
# Extract page texts
page_texts = extract_page_text(full_text, pdf_page_numbers)
# Download PDF if it's from S3
if source_file.startswith("s3://"):
# Create a cache path based on the S3 key
parsed = urlparse(source_file)
cache_path = cache_dir / parsed.netloc / parsed.path.lstrip("/")
local_pdf_path = str(cache_path)
if not cache_path.exists():
try:
logger.info(f"Downloading {source_file} to cache")
fetch_s3_file(source_file, local_pdf_path)
except Exception as e:
logger.error(f"Failed to download {source_file}: {e}")
return 0, len(page_texts)
else:
logger.debug(f"Using cached PDF: {cache_path}")
else:
local_pdf_path = source_file
# Create output subdirectory based on document ID (first 4 characters)
if len(doc_id) >= 4:
subdir = doc_id[:4]
doc_dir = output_dir / subdir
else:
doc_dir = output_dir / "misc"
doc_dir.mkdir(parents=True, exist_ok=True)
# Process each page
for page_num, page_text in page_texts.items():
try:
# Create filenames
base_name = f"{doc_id}_page{page_num}"
md_path = doc_dir / f"{base_name}.md"
pdf_path = doc_dir / f"{base_name}.pdf"
# Write markdown file
with open(md_path, "w", encoding="utf-8") as f:
# Write YAML front matter
f.write("---\n")
f.write(f"page_number: {page_num}\n")
f.write(f"source_file: {source_file}\n")
f.write(f"document_id: {doc_id}\n")
for k, v in entry_data["metadata"].items():
if k != "Source-File": # Already included as source_file
f.write(f"{k}: {v}\n")
f.write("---\n\n")
# Write page text
f.write(page_text)
# Extract PDF page
if extract_pdf_page(local_pdf_path, page_num, str(pdf_path)):
successful += 1
logger.debug(f"Created {md_path} and {pdf_path}")
else:
failed += 1
# Remove the markdown file if PDF extraction failed
os.remove(md_path)
except Exception as e:
logger.error(f"Error processing page {page_num} of document {doc_id}: {e}")
failed += 1
return successful, failed
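# Sketch of a markdown file produced above (hypothetical values; any further metadata
# keys on the Dolma document, except Source-File, are appended as extra "key: value" lines):
#
#   ---
#   page_number: 3
#   source_file: s3://example-bucket/docs/report.pdf
#   document_id: a1b2c3d4e5f6
#   ---
#
#   (extracted text of page 3 follows)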
def process_workspace(workspace_path: str, output_dir: Path, max_examples: Optional[int] = None) -> None:
"""
Process all JSONL files in the workspace and create training data.
Args:
workspace_path: Path to the workspace directory (local or S3)
output_dir: Path to the output directory for training data
max_examples: Maximum number of documents to process (None for all)
"""
# Create output and cache directories
output_dir.mkdir(parents=True, exist_ok=True)
cache_dir = output_dir / ".pdf_cache"
cache_dir.mkdir(exist_ok=True)
# Initialize S3 client if workspace is on S3
s3_client = None
if workspace_path.startswith("s3://"):
s3_client = boto3.client("s3")
# Parse all entries
all_entries = []
if workspace_path.startswith("s3://"):
# S3 workspace
jsonl_files = list_s3_result_files(s3_client, workspace_path)
if not jsonl_files:
logger.error("No JSONL files found in S3 workspace")
sys.exit(1)
for s3_file in jsonl_files:
logger.info(f"Reading {s3_file}...")
try:
content = download_s3_file(s3_client, s3_file)
for line in content.splitlines():
line = line.strip()
if not line:
continue
try:
entry = json.loads(line)
parsed_entry = parse_jsonl_entry(entry)
if parsed_entry:
all_entries.append(parsed_entry)
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in {s3_file}: {e}")
except Exception as e:
logger.error(f"Error reading {s3_file}: {e}")
else:
# Local workspace
workspace_path_obj = Path(workspace_path)
results_dir = workspace_path_obj / "results"
if not results_dir.exists():
logger.error(f"Results directory not found: {results_dir}")
sys.exit(1)
jsonl_files = load_jsonl_files(results_dir)
if not jsonl_files:
sys.exit(1)
for jsonl_file in jsonl_files:
logger.info(f"Reading {jsonl_file.name}...")
with open(jsonl_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
entry = json.loads(line)
parsed_entry = parse_jsonl_entry(entry)
if parsed_entry:
all_entries.append(parsed_entry)
except json.JSONDecodeError as e:
logger.error(f"JSON decode error: {e}")
logger.info(f"Found {len(all_entries)} valid documents to process")
# Limit entries if max_examples is set
if max_examples and len(all_entries) > max_examples:
all_entries = all_entries[:max_examples]
logger.info(f"Limited to {max_examples} documents")
# Process documents with progress bar
total_successful = 0
total_failed = 0
with tqdm(total=len(all_entries), desc="Processing documents") as pbar:
for entry_data in all_entries:
successful, failed = process_document(entry_data, output_dir, cache_dir)
total_successful += successful
total_failed += failed
pbar.update(1)
pbar.set_postfix({
"pages_ok": total_successful,
"pages_failed": total_failed
})
# Print summary
logger.info("\nProcessing complete!")
logger.info(f"Successfully processed: {total_successful} pages")
logger.info(f"Failed: {total_failed} pages")
logger.info(f"Output directory: {output_dir.absolute()}")
def main():
parser = argparse.ArgumentParser(
description="Prepare workspace data for fine-tuning by extracting individual pages"
)
parser.add_argument(
"workspace_path",
type=str,
help="Path to the workspace directory containing results folder"
)
parser.add_argument(
"output_dir",
type=str,
help="Output directory for processed training data"
)
parser.add_argument(
"--max-examples",
type=int,
default=None,
help="Maximum number of documents to process (default: all)"
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug logging"
)
args = parser.parse_args()
if args.debug:
logger.setLevel(logging.DEBUG)
workspace_path = args.workspace_path
output_dir = Path(args.output_dir)
# Check if workspace exists
if workspace_path.startswith("s3://"):
# For S3, we'll check existence when listing files
logger.info(f"Using S3 workspace: {workspace_path}")
else:
workspace_path_obj = Path(workspace_path)
if not workspace_path_obj.exists():
logger.error(f"Workspace path does not exist: {workspace_path}")
sys.exit(1)
process_workspace(workspace_path, output_dir, args.max_examples)
if __name__ == "__main__":
main()
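As a usage sketch (the paths and document id below are hypothetical), a local run might look like:

    python olmocr/train/prepare_workspace.py ./workspace ./train_data --max-examples 100

This populates ./train_data with one single-page PDF and one markdown file per extracted page, grouped into subdirectories named after the first four characters of each document id (or "misc" for shorter ids), alongside a .pdf_cache/ directory used when source PDFs are downloaded from S3:

    train_data/
        .pdf_cache/
        a1b2/
            a1b2c3d4e5f6_page1.md
            a1b2c3d4e5f6_page1.pdf
            a1b2c3d4e5f6_page2.md
            a1b2c3d4e5f6_page2.pdf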