Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-02 19:13:53 +00:00
Saving extra metadata that will be useful for finetuning
This commit is contained in:
parent 7c098955a9
commit d6591c04a1
@@ -449,7 +449,14 @@ def build_dolma_document(pdf_orig_path, page_results):
         "added": datetime.datetime.now().strftime("%Y-%m-%d"),
         "created": datetime.datetime.now().strftime("%Y-%m-%d"),
         "metadata": metadata,
-        "attributes": {"pdf_page_numbers": pdf_page_spans},
+        "attributes": {
+            "pdf_page_numbers": pdf_page_spans,
+            "primary_language": [p.response.primary_language for p in page_results],
+            "is_rotation_valid": [p.response.is_rotation_valid for p in page_results],
+            "rotation_correction": [p.response.rotation_correction for p in page_results],
+            "is_table": [p.response.is_table for p in page_results],
+            "is_diagram": [p.response.is_diagram for p in page_results],
+        },
     }
     return dolma_doc
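For reference, a sketch of what the expanded attributes block might look like on an output Dolma document for a two-page PDF (field names come from the diff above; all values are illustrative, not taken from real output):

# Hypothetical example of the new per-page attributes on a Dolma document.
# Each pdf_page_numbers entry is [start_char, end_char, page_num]; the other
# lists hold one value per entry in page_results.
example_attributes = {
    "pdf_page_numbers": [[0, 1432, 1], [1432, 2875, 2]],
    "primary_language": ["en", "en"],
    "is_rotation_valid": [True, True],
    "rotation_correction": [0, 0],
    "is_table": [False, True],
    "is_diagram": [False, False],
}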
405
olmocr/train/prepare_workspace.py
Executable file
@@ -0,0 +1,405 @@
#!/usr/bin/env python3
"""
Prepare workspace generated by olmocr/pipeline.py for fine-tuning.

This script reads JSONL files from workspace/results, extracts individual pages
from PDFs based on page boundaries, and creates corresponding markdown files.

Usage:
    python prepare_workspace.py workspace_path output_dir [--max-examples N]
"""
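# Illustrative output layout (hypothetical document id "abcd1234", two pages),
# based on how process_document below names directories and files:
#
#   output_dir/
#       abcd/
#           abcd1234_page1.md    # page text with YAML front matter
#           abcd1234_page1.pdf   # single page extracted from the source PDF
#           abcd1234_page2.md
#           abcd1234_page2.pdf
#       .pdf_cache/              # local cache for PDFs downloaded from S3
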
import argparse
import json
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse

import boto3
from pypdf import PdfReader, PdfWriter
from tqdm import tqdm

from olmocr.s3_utils import parse_s3_path

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def fetch_s3_file(s3_url: str, local_path: str) -> str:
    """Download a file from an S3 URI (s3://bucket/key) to local_path."""
    parsed = urlparse(s3_url)
    bucket_name = parsed.netloc
    key = parsed.path.lstrip("/")

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(local_path), exist_ok=True)

    s3 = boto3.client("s3")
    s3.download_file(bucket_name, key, local_path)
    return local_path


def list_s3_result_files(s3_client, workspace_path: str) -> List[str]:
    """List all JSONL files in the S3 workspace results directory."""
    bucket, prefix = parse_s3_path(workspace_path)
    results_prefix = os.path.join(prefix, "results").rstrip("/") + "/"

    all_files = []
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=results_prefix):
        if "Contents" in page:
            all_files.extend([
                f"s3://{bucket}/{obj['Key']}"
                for obj in page["Contents"]
                if obj["Key"].endswith(".jsonl")
            ])

    logger.info(f"Found {len(all_files)} JSONL files in S3 workspace")
    return all_files


def download_s3_file(s3_client, s3_path: str) -> str:
    """Download an S3 file and return its contents as a string."""
    bucket, key = parse_s3_path(s3_path)
    response = s3_client.get_object(Bucket=bucket, Key=key)
    return response['Body'].read().decode('utf-8')


def load_jsonl_files(results_dir: Path) -> List[Path]:
    """Load all JSONL files from the workspace results directory."""
    jsonl_files = list(results_dir.glob("*.jsonl"))
    if not jsonl_files:
        logger.error(f"No JSONL files found in {results_dir}")
        return []

    logger.info(f"Found {len(jsonl_files)} JSONL files in {results_dir}")
    return jsonl_files


def parse_jsonl_entry(entry: Dict) -> Optional[Dict]:
    """Parse a single JSONL entry and extract relevant information."""
    try:
        text = entry.get("text", "")
        metadata = entry.get("metadata", {})
        attributes = entry.get("attributes", {})

        source_file = metadata.get("Source-File", "")
        if not source_file:
            logger.warning("Entry missing Source-File in metadata")
            return None

        pdf_page_numbers = attributes.get("pdf_page_numbers", [])
        if not pdf_page_numbers:
            logger.warning(f"Entry for {source_file} missing pdf_page_numbers")
            return None

        return {
            "id": entry.get("id", ""),
            "text": text,
            "source_file": source_file,
            "metadata": metadata,
            "pdf_page_numbers": pdf_page_numbers
        }
    except Exception as e:
        logger.error(f"Error parsing JSONL entry: {e}")
        return None


def extract_page_text(text: str, page_boundaries: List[List[int]]) -> Dict[int, str]:
    """
    Extract text for each page based on character boundaries.

    Args:
        text: Full document text
        page_boundaries: List of [start_char, end_char, page_num] for each page

    Returns:
        Dictionary mapping page number to extracted text
    """
    page_texts = {}

    for start_char, end_char, page_num in page_boundaries:
        page_text = text[start_char:end_char]
        page_texts[page_num] = page_text

    return page_texts


def extract_pdf_page(pdf_path: str, page_num: int, output_path: str) -> bool:
    """
    Extract a single page from a PDF and save it to output_path.

    Args:
        pdf_path: Path to the source PDF
        page_num: 1-based page number to extract
        output_path: Path where the single-page PDF will be saved

    Returns:
        True if successful, False otherwise
    """
    try:
        reader = PdfReader(pdf_path)

        # Check if page number is valid
        if page_num < 1 or page_num > len(reader.pages):
            logger.error(f"Page {page_num} out of range for {pdf_path} (has {len(reader.pages)} pages)")
            return False

        writer = PdfWriter()
        # PyPDF uses 0-based indexing
        writer.add_page(reader.pages[page_num - 1])

        # Create output directory if it doesn't exist
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        with open(output_path, "wb") as f:
            writer.write(f)

        return True
    except Exception as e:
        logger.error(f"Error extracting page {page_num} from {pdf_path}: {e}")
        return False


def process_document(entry_data: Dict, output_dir: Path, cache_dir: Path) -> Tuple[int, int]:
    """
    Process a single document: extract pages and create markdown files.

    Returns:
        Tuple of (successful_pages, failed_pages)
    """
    successful = 0
    failed = 0

    source_file = entry_data["source_file"]
    doc_id = entry_data["id"]
    full_text = entry_data["text"]
    pdf_page_numbers = entry_data["pdf_page_numbers"]

    # Extract page texts
    page_texts = extract_page_text(full_text, pdf_page_numbers)

    # Download PDF if it's from S3
    if source_file.startswith("s3://"):
        # Create a cache path based on the S3 key
        parsed = urlparse(source_file)
        cache_path = cache_dir / parsed.netloc / parsed.path.lstrip("/")
        local_pdf_path = str(cache_path)

        if not cache_path.exists():
            try:
                logger.info(f"Downloading {source_file} to cache")
                fetch_s3_file(source_file, local_pdf_path)
            except Exception as e:
                logger.error(f"Failed to download {source_file}: {e}")
                return 0, len(page_texts)
        else:
            logger.debug(f"Using cached PDF: {cache_path}")
    else:
        local_pdf_path = source_file

    # Create output subdirectory based on document ID (first 4 characters)
    if len(doc_id) >= 4:
        subdir = doc_id[:4]
        doc_dir = output_dir / subdir
    else:
        doc_dir = output_dir / "misc"

    doc_dir.mkdir(parents=True, exist_ok=True)

    # Process each page
    for page_num, page_text in page_texts.items():
        try:
            # Create filenames
            base_name = f"{doc_id}_page{page_num}"
            md_path = doc_dir / f"{base_name}.md"
            pdf_path = doc_dir / f"{base_name}.pdf"

            # Write markdown file
            with open(md_path, "w", encoding="utf-8") as f:
                # Write YAML front matter
                f.write("---\n")
                f.write(f"page_number: {page_num}\n")
                f.write(f"source_file: {source_file}\n")
                f.write(f"document_id: {doc_id}\n")
                for k, v in entry_data["metadata"].items():
                    if k != "Source-File":  # Already included as source_file
                        f.write(f"{k}: {v}\n")
                f.write("---\n\n")

                # Write page text
                f.write(page_text)

            # Extract PDF page
            if extract_pdf_page(local_pdf_path, page_num, str(pdf_path)):
                successful += 1
                logger.debug(f"Created {md_path} and {pdf_path}")
            else:
                failed += 1
                # Remove the markdown file if PDF extraction failed
                os.remove(md_path)

        except Exception as e:
            logger.error(f"Error processing page {page_num} of document {doc_id}: {e}")
            failed += 1

    return successful, failed


def process_workspace(workspace_path: str, output_dir: Path, max_examples: Optional[int] = None) -> None:
    """
    Process all JSONL files in the workspace and create training data.

    Args:
        workspace_path: Path to the workspace directory (local or S3)
        output_dir: Path to the output directory for training data
        max_examples: Maximum number of documents to process (None for all)
    """
    # Create output and cache directories
    output_dir.mkdir(parents=True, exist_ok=True)
    cache_dir = output_dir / ".pdf_cache"
    cache_dir.mkdir(exist_ok=True)

    # Initialize S3 client if workspace is on S3
    s3_client = None
    if workspace_path.startswith("s3://"):
        s3_client = boto3.client("s3")

    # Parse all entries
    all_entries = []

    if workspace_path.startswith("s3://"):
        # S3 workspace
        jsonl_files = list_s3_result_files(s3_client, workspace_path)
        if not jsonl_files:
            logger.error("No JSONL files found in S3 workspace")
            sys.exit(1)

        for s3_file in jsonl_files:
            logger.info(f"Reading {s3_file}...")
            try:
                content = download_s3_file(s3_client, s3_file)
                for line in content.splitlines():
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        entry = json.loads(line)
                        parsed_entry = parse_jsonl_entry(entry)
                        if parsed_entry:
                            all_entries.append(parsed_entry)
                    except json.JSONDecodeError as e:
                        logger.error(f"JSON decode error in {s3_file}: {e}")
            except Exception as e:
                logger.error(f"Error reading {s3_file}: {e}")
    else:
        # Local workspace
        workspace_path_obj = Path(workspace_path)
        results_dir = workspace_path_obj / "results"
        if not results_dir.exists():
            logger.error(f"Results directory not found: {results_dir}")
            sys.exit(1)

        jsonl_files = load_jsonl_files(results_dir)
        if not jsonl_files:
            sys.exit(1)

        for jsonl_file in jsonl_files:
            logger.info(f"Reading {jsonl_file.name}...")
            with open(jsonl_file, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        entry = json.loads(line)
                        parsed_entry = parse_jsonl_entry(entry)
                        if parsed_entry:
                            all_entries.append(parsed_entry)
                    except json.JSONDecodeError as e:
                        logger.error(f"JSON decode error: {e}")

    logger.info(f"Found {len(all_entries)} valid documents to process")

    # Limit entries if max_examples is set
    if max_examples and len(all_entries) > max_examples:
        all_entries = all_entries[:max_examples]
        logger.info(f"Limited to {max_examples} documents")

    # Process documents with progress bar
    total_successful = 0
    total_failed = 0

    with tqdm(total=len(all_entries), desc="Processing documents") as pbar:
        for entry_data in all_entries:
            successful, failed = process_document(entry_data, output_dir, cache_dir)
            total_successful += successful
            total_failed += failed
            pbar.update(1)
            pbar.set_postfix({
                "pages_ok": total_successful,
                "pages_failed": total_failed
            })

    # Print summary
    logger.info("\nProcessing complete!")
    logger.info(f"Successfully processed: {total_successful} pages")
    logger.info(f"Failed: {total_failed} pages")
    logger.info(f"Output directory: {output_dir.absolute()}")


def main():
    parser = argparse.ArgumentParser(
        description="Prepare workspace data for fine-tuning by extracting individual pages"
    )
    parser.add_argument(
        "workspace_path",
        type=str,
        help="Path to the workspace directory containing results folder"
    )
    parser.add_argument(
        "output_dir",
        type=str,
        help="Output directory for processed training data"
    )
    parser.add_argument(
        "--max-examples",
        type=int,
        default=None,
        help="Maximum number of documents to process (default: all)"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug logging"
    )

    args = parser.parse_args()

    if args.debug:
        logger.setLevel(logging.DEBUG)

    workspace_path = args.workspace_path
    output_dir = Path(args.output_dir)

    # Check if workspace exists
    if workspace_path.startswith("s3://"):
        # For S3, we'll check existence when listing files
        logger.info(f"Using S3 workspace: {workspace_path}")
    else:
        workspace_path_obj = Path(workspace_path)
        if not workspace_path_obj.exists():
            logger.error(f"Workspace path does not exist: {workspace_path}")
            sys.exit(1)

    process_workspace(workspace_path, output_dir, args.max_examples)


if __name__ == "__main__":
    main()
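A hedged usage sketch for the new script, following the Usage line in its docstring (the workspace and output paths below are placeholders, not from the repository):

# Local workspace produced by olmocr/pipeline.py
python olmocr/train/prepare_workspace.py ./localworkspace ./train_data --max-examples 100

# S3 workspace; requires AWS credentials configured for boto3
python olmocr/train/prepare_workspace.py s3://my-bucket/my-workspace ./train_data --debug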