diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py
index 7c8045c..52794ee 100644
--- a/olmocr/pipeline.py
+++ b/olmocr/pipeline.py
@@ -449,7 +449,14 @@ def build_dolma_document(pdf_orig_path, page_results):
         "added": datetime.datetime.now().strftime("%Y-%m-%d"),
         "created": datetime.datetime.now().strftime("%Y-%m-%d"),
         "metadata": metadata,
-        "attributes": {"pdf_page_numbers": pdf_page_spans},
+        "attributes": {
+            "pdf_page_numbers": pdf_page_spans,
+            "primary_language": [p.response.primary_language for p in page_results],
+            "is_rotation_valid": [p.response.is_rotation_valid for p in page_results],
+            "rotation_correction": [p.response.rotation_correction for p in page_results],
+            "is_table": [p.response.is_table for p in page_results],
+            "is_diagram": [p.response.is_diagram for p in page_results],
+        },
     }
 
     return dolma_doc
diff --git a/olmocr/train/prepare_workspace.py b/olmocr/train/prepare_workspace.py
new file mode 100755
index 0000000..d73e32e
--- /dev/null
+++ b/olmocr/train/prepare_workspace.py
@@ -0,0 +1,405 @@
+#!/usr/bin/env python3
+"""
+Prepare workspace generated by olmocr/pipeline.py for fine-tuning.
+
+This script reads JSONL files from workspace/results, extracts individual pages
+from PDFs based on page boundaries, and creates corresponding markdown files.
+
+Usage:
+    python prepare_workspace.py workspace_path output_dir [--max-examples N]
+"""
+
+import argparse
+import json
+import logging
+import os
+import sys
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+from urllib.parse import urlparse
+
+import boto3
+from pypdf import PdfReader, PdfWriter
+from tqdm import tqdm
+
+from olmocr.s3_utils import parse_s3_path
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+
+def fetch_s3_file(s3_url: str, local_path: str) -> str:
+    """Download a file from an S3 URI (s3://bucket/key) to local_path."""
+    parsed = urlparse(s3_url)
+    bucket_name = parsed.netloc
+    key = parsed.path.lstrip("/")
+
+    # Create directory if it doesn't exist
+    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+
+    s3 = boto3.client("s3")
+    s3.download_file(bucket_name, key, local_path)
+    return local_path
+
+
+def list_s3_result_files(s3_client, workspace_path: str) -> List[str]:
+    """List all JSONL files in the S3 workspace results directory."""
+    bucket, prefix = parse_s3_path(workspace_path)
+    results_prefix = os.path.join(prefix, "results").rstrip("/") + "/"
+
+    all_files = []
+    paginator = s3_client.get_paginator("list_objects_v2")
+    for page in paginator.paginate(Bucket=bucket, Prefix=results_prefix):
+        if "Contents" in page:
+            all_files.extend([
+                f"s3://{bucket}/{obj['Key']}"
+                for obj in page["Contents"]
+                if obj["Key"].endswith(".jsonl")
+            ])
+
+    logger.info(f"Found {len(all_files)} JSONL files in S3 workspace")
+    return all_files
+
+
+def download_s3_file(s3_client, s3_path: str) -> str:
+    """Download an S3 file and return its contents as a string."""
+    bucket, key = parse_s3_path(s3_path)
+    response = s3_client.get_object(Bucket=bucket, Key=key)
+    return response['Body'].read().decode('utf-8')
+
+
+def load_jsonl_files(results_dir: Path) -> List[Path]:
+    """Load all JSONL files from the workspace results directory."""
+    jsonl_files = list(results_dir.glob("*.jsonl"))
+    if not jsonl_files:
+        logger.error(f"No JSONL files found in {results_dir}")
+        return []
+
+    logger.info(f"Found {len(jsonl_files)} JSONL files in {results_dir}")
+    return jsonl_files
+
+
+def parse_jsonl_entry(entry: Dict) -> Optional[Dict]:
+    """Parse a single JSONL entry and extract relevant information."""
+    try:
+        text = entry.get("text", "")
+        metadata = entry.get("metadata", {})
+        attributes = entry.get("attributes", {})
+
+        source_file = metadata.get("Source-File", "")
+        if not source_file:
+            logger.warning("Entry missing Source-File in metadata")
+            return None
+
+        pdf_page_numbers = attributes.get("pdf_page_numbers", [])
+        if not pdf_page_numbers:
+            logger.warning(f"Entry for {source_file} missing pdf_page_numbers")
+            return None
+
+        return {
+            "id": entry.get("id", ""),
+            "text": text,
+            "source_file": source_file,
+            "metadata": metadata,
+            "pdf_page_numbers": pdf_page_numbers
+        }
+    except Exception as e:
+        logger.error(f"Error parsing JSONL entry: {e}")
+        return None
+
+
+def extract_page_text(text: str, page_boundaries: List[List[int]]) -> Dict[int, str]:
+    """
+    Extract text for each page based on character boundaries.
+
+    Args:
+        text: Full document text
+        page_boundaries: List of [start_char, end_char, page_num] for each page
+
+    Returns:
+        Dictionary mapping page number to extracted text
+    """
+    page_texts = {}
+
+    for start_char, end_char, page_num in page_boundaries:
+        page_text = text[start_char:end_char]
+        page_texts[page_num] = page_text
+
+    return page_texts
+
+
+def extract_pdf_page(pdf_path: str, page_num: int, output_path: str) -> bool:
+    """
+    Extract a single page from a PDF and save it to output_path.
+
+    Args:
+        pdf_path: Path to the source PDF
+        page_num: 1-based page number to extract
+        output_path: Path where the single-page PDF will be saved
+
+    Returns:
+        True if successful, False otherwise
+    """
+    try:
+        reader = PdfReader(pdf_path)
+
+        # Check if page number is valid
+        if page_num < 1 or page_num > len(reader.pages):
+            logger.error(f"Page {page_num} out of range for {pdf_path} (has {len(reader.pages)} pages)")
+            return False
+
+        writer = PdfWriter()
+        # PyPDF uses 0-based indexing
+        writer.add_page(reader.pages[page_num - 1])
+
+        # Create output directory if it doesn't exist
+        os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+        with open(output_path, "wb") as f:
+            writer.write(f)
+
+        return True
+    except Exception as e:
+        logger.error(f"Error extracting page {page_num} from {pdf_path}: {e}")
+        return False
+
+
+def process_document(entry_data: Dict, output_dir: Path, cache_dir: Path) -> Tuple[int, int]:
+    """
+    Process a single document: extract pages and create markdown files.
+
+    Returns:
+        Tuple of (successful_pages, failed_pages)
+    """
+    successful = 0
+    failed = 0
+
+    source_file = entry_data["source_file"]
+    doc_id = entry_data["id"]
+    full_text = entry_data["text"]
+    pdf_page_numbers = entry_data["pdf_page_numbers"]
+
+    # Extract page texts
+    page_texts = extract_page_text(full_text, pdf_page_numbers)
+
+    # Download PDF if it's from S3
+    if source_file.startswith("s3://"):
+        # Create a cache path based on the S3 key
+        parsed = urlparse(source_file)
+        cache_path = cache_dir / parsed.netloc / parsed.path.lstrip("/")
+        local_pdf_path = str(cache_path)
+
+        if not cache_path.exists():
+            try:
+                logger.info(f"Downloading {source_file} to cache")
+                fetch_s3_file(source_file, local_pdf_path)
+            except Exception as e:
+                logger.error(f"Failed to download {source_file}: {e}")
+                return 0, len(page_texts)
+        else:
+            logger.debug(f"Using cached PDF: {cache_path}")
+    else:
+        local_pdf_path = source_file
+
+    # Create output subdirectory based on document ID (first 4 characters)
+    if len(doc_id) >= 4:
+        subdir = doc_id[:4]
+        doc_dir = output_dir / subdir
+    else:
+        doc_dir = output_dir / "misc"
+
+    doc_dir.mkdir(parents=True, exist_ok=True)
+
+    # Process each page
+    for page_num, page_text in page_texts.items():
+        try:
+            # Create filenames
+            base_name = f"{doc_id}_page{page_num}"
+            md_path = doc_dir / f"{base_name}.md"
+            pdf_path = doc_dir / f"{base_name}.pdf"
+
+            # Write markdown file
+            with open(md_path, "w", encoding="utf-8") as f:
+                # Write YAML front matter
+                f.write("---\n")
+                f.write(f"page_number: {page_num}\n")
+                f.write(f"source_file: {source_file}\n")
+                f.write(f"document_id: {doc_id}\n")
+                for k, v in entry_data["metadata"].items():
+                    if k != "Source-File":  # Already included as source_file
+                        f.write(f"{k}: {v}\n")
+                f.write("---\n\n")
+
+                # Write page text
+                f.write(page_text)
+
+            # Extract PDF page
+            if extract_pdf_page(local_pdf_path, page_num, str(pdf_path)):
+                successful += 1
+                logger.debug(f"Created {md_path} and {pdf_path}")
+            else:
+                failed += 1
+                # Remove the markdown file if PDF extraction failed
+                os.remove(md_path)
+
+        except Exception as e:
+            logger.error(f"Error processing page {page_num} of document {doc_id}: {e}")
+            failed += 1
+
+    return successful, failed
+
+
+def process_workspace(workspace_path: str, output_dir: Path, max_examples: Optional[int] = None) -> None:
+    """
+    Process all JSONL files in the workspace and create training data.
+
+    Args:
+        workspace_path: Path to the workspace directory (local or S3)
+        output_dir: Path to the output directory for training data
+        max_examples: Maximum number of documents to process (None for all)
+    """
+    # Create output and cache directories
+    output_dir.mkdir(parents=True, exist_ok=True)
+    cache_dir = output_dir / ".pdf_cache"
+    cache_dir.mkdir(exist_ok=True)
+
+    # Initialize S3 client if workspace is on S3
+    s3_client = None
+    if workspace_path.startswith("s3://"):
+        s3_client = boto3.client("s3")
+
+    # Parse all entries
+    all_entries = []
+
+    if workspace_path.startswith("s3://"):
+        # S3 workspace
+        jsonl_files = list_s3_result_files(s3_client, workspace_path)
+        if not jsonl_files:
+            logger.error("No JSONL files found in S3 workspace")
+            sys.exit(1)
+
+        for s3_file in jsonl_files:
+            logger.info(f"Reading {s3_file}...")
+            try:
+                content = download_s3_file(s3_client, s3_file)
+                for line in content.splitlines():
+                    line = line.strip()
+                    if not line:
+                        continue
+
+                    try:
+                        entry = json.loads(line)
+                        parsed_entry = parse_jsonl_entry(entry)
+                        if parsed_entry:
+                            all_entries.append(parsed_entry)
+                    except json.JSONDecodeError as e:
+                        logger.error(f"JSON decode error in {s3_file}: {e}")
+            except Exception as e:
+                logger.error(f"Error reading {s3_file}: {e}")
+    else:
+        # Local workspace
+        workspace_path_obj = Path(workspace_path)
+        results_dir = workspace_path_obj / "results"
+        if not results_dir.exists():
+            logger.error(f"Results directory not found: {results_dir}")
+            sys.exit(1)
+
+        jsonl_files = load_jsonl_files(results_dir)
+        if not jsonl_files:
+            sys.exit(1)
+
+        for jsonl_file in jsonl_files:
+            logger.info(f"Reading {jsonl_file.name}...")
+            with open(jsonl_file, "r", encoding="utf-8") as f:
+                for line in f:
+                    line = line.strip()
+                    if not line:
+                        continue
+
+                    try:
+                        entry = json.loads(line)
+                        parsed_entry = parse_jsonl_entry(entry)
+                        if parsed_entry:
+                            all_entries.append(parsed_entry)
+                    except json.JSONDecodeError as e:
+                        logger.error(f"JSON decode error: {e}")
+
+    logger.info(f"Found {len(all_entries)} valid documents to process")
+
+    # Limit entries if max_examples is set
+    if max_examples and len(all_entries) > max_examples:
+        all_entries = all_entries[:max_examples]
+        logger.info(f"Limited to {max_examples} documents")
+
+    # Process documents with progress bar
+    total_successful = 0
+    total_failed = 0
+
+    with tqdm(total=len(all_entries), desc="Processing documents") as pbar:
+        for entry_data in all_entries:
+            successful, failed = process_document(entry_data, output_dir, cache_dir)
+            total_successful += successful
+            total_failed += failed
+            pbar.update(1)
+            pbar.set_postfix({
+                "pages_ok": total_successful,
+                "pages_failed": total_failed
+            })
+
+    # Print summary
+    logger.info("\nProcessing complete!")
+    logger.info(f"Successfully processed: {total_successful} pages")
+    logger.info(f"Failed: {total_failed} pages")
+    logger.info(f"Output directory: {output_dir.absolute()}")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Prepare workspace data for fine-tuning by extracting individual pages"
+    )
+    parser.add_argument(
+        "workspace_path",
+        type=str,
+        help="Path to the workspace directory containing results folder"
+    )
+    parser.add_argument(
+        "output_dir",
+        type=str,
+        help="Output directory for processed training data"
+    )
+    parser.add_argument(
+        "--max-examples",
+        type=int,
+        default=None,
+        help="Maximum number of documents to process (default: all)"
+    )
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        help="Enable debug logging"
+    )
+
+    args = parser.parse_args()
+
+    if args.debug:
+        logger.setLevel(logging.DEBUG)
+
+    workspace_path = args.workspace_path
+    output_dir = Path(args.output_dir)
+
+    # Check if workspace exists
+    if workspace_path.startswith("s3://"):
+        # For S3, we'll check existence when listing files
+        logger.info(f"Using S3 workspace: {workspace_path}")
+    else:
+        workspace_path_obj = Path(workspace_path)
+        if not workspace_path_obj.exists():
+            logger.error(f"Workspace path does not exist: {workspace_path}")
+            sys.exit(1)
+
+    process_workspace(workspace_path, output_dir, args.max_examples)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file