#!/usr/bin/env python3
"""Clean OCR transcriptions in an olmocr-mix dataset using ChatGPT.

Takes a dataset location in olmocr-mix format (a nested directory structure
``folder/subfolder/document.md`` with a corresponding
``folder/subfolder/document.pdf``).  The documents are randomly shuffled with
a fixed seed, then ChatGPT is prompted to clean up each transcription, using
the rendered PDF page as visual context.  Structured output is used to get a
reliable result, and the cleaned documents are written back in the same
format under a new root folder, preserving the original folder structure.
"""

import argparse
import json
import os
import random
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Tuple

from openai import OpenAI
from pydantic import BaseModel, Field
from pypdf import PdfReader
from tqdm import tqdm

from olmocr.data.renderpdf import render_pdf_to_base64png


class CleanedDocument(BaseModel):
    """Structured-output schema that ChatGPT must return for one document."""

    cleaned_text: str = Field(description="The cleaned and corrected version of the OCR transcription")
    confidence_score: float = Field(
        description="Confidence score from 0 to 1 indicating how confident the model is in the cleaning",
        ge=0.0,
        le=1.0,
    )
    corrections_made: List[str] = Field(description="List of major corrections or improvements made to the text")


@dataclass
class DocumentPair:
    """One (markdown, pdf) file pair discovered under the input root."""

    md_path: Path
    pdf_path: Path
    # Relative path from the input root; used to mirror the directory
    # structure in the output directory.
    relative_path: Path


def parse_args():
    """Parse and return the command-line arguments for this script."""
    parser = argparse.ArgumentParser(
        description="Clean OCR transcriptions using ChatGPT with visual PDF context"
    )
    parser.add_argument(
        "input_dir",
        help="Input directory containing olmocr-mix format data (MD files with corresponding PDFs)"
    )
    parser.add_argument(
        "output_dir",
        help="Output directory for cleaned documents (preserves folder structure)"
    )
    parser.add_argument(
        "--openai-api-key",
        help="OpenAI API key (can also be set via OPENAI_API_KEY environment variable)",
        default=os.getenv("OPENAI_API_KEY")
    )
    parser.add_argument(
        "--model",
        default="gpt-4o-mini",
        help="OpenAI model to use (default: gpt-4o-mini)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for shuffling documents (default: 42)"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=10,
        help="Number of documents to process in parallel (default: 10)"
    )
    parser.add_argument(
        "--max-documents",
        type=int,
        help="Maximum number of documents to process (useful for testing)"
    )
    parser.add_argument(
        "--skip-existing",
        action="store_true",
        help="Skip documents that already have cleaned versions in the output directory"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose output"
    )
    return parser.parse_args()


def check_single_page_pdf(pdf_path: Path) -> bool:
    """Check if a PDF has exactly one page.

    Returns False (rather than raising) on unreadable/corrupt PDFs so the
    caller can simply skip them.
    """
    try:
        with open(pdf_path, 'rb') as pdf_file:
            pdf_reader = PdfReader(pdf_file)
            return len(pdf_reader.pages) == 1
    except Exception as e:
        print(f"Error checking PDF {pdf_path}: {e}")
        return False


def find_document_pairs(input_dir: Path, verbose: bool = False) -> List[DocumentPair]:
    """Find all MD files with corresponding single-page PDF files.

    Walks ``input_dir`` recursively; MD files without a sibling PDF and MD
    files whose PDF has more than one page are skipped (and counted).
    """
    pairs = []
    skipped_no_pdf = 0
    skipped_multi_page = 0

    for md_path in input_dir.rglob("*.md"):
        # Check for corresponding PDF
        pdf_path = md_path.with_suffix(".pdf")
        if not pdf_path.exists():
            if verbose:
                print(f"Warning: No PDF found for {md_path}")
            skipped_no_pdf += 1
            continue

        # Check if PDF has exactly one page
        if not check_single_page_pdf(pdf_path):
            if verbose:
                print(f"Warning: Skipping multi-page PDF {pdf_path}")
            skipped_multi_page += 1
            continue

        relative_path = md_path.relative_to(input_dir)
        pairs.append(DocumentPair(md_path, pdf_path, relative_path))

    if skipped_no_pdf > 0 or skipped_multi_page > 0:
        print(f"Skipped {skipped_no_pdf} files without PDFs and {skipped_multi_page} multi-page PDFs")

    return pairs


def render_single_page_pdf(pdf_path: Path) -> str:
    """Render a single-page PDF to a base64-encoded PNG image.

    Raises:
        RuntimeError: if the PDF cannot be rendered.
    """
    try:
        # Use render_pdf_to_base64png with target_longest_image_dim=2048
        base64_png = render_pdf_to_base64png(
            str(pdf_path),
            1,  # Always page 1 since we validated it's a single-page PDF
            target_longest_image_dim=2048
        )
        return base64_png
    except Exception as e:
        raise RuntimeError(f"Could not render PDF {pdf_path}: {e}")


def clean_document_with_chatgpt(
    client: OpenAI,
    model: str,
    md_content: str,
    pdf_image: str,
    verbose: bool = False
) -> CleanedDocument:
    """Use ChatGPT to clean the OCR transcription with PDF context.

    Args:
        client: Initialized OpenAI client.
        model: Model name to use for the completion.
        md_content: Raw OCR transcription (markdown) to clean.
        pdf_image: Base64-encoded PNG of the PDF page.
        verbose: Accepted for interface compatibility; currently unused here.

    Raises:
        ValueError: if the model returns no parsed structured result.
    """
    # Prepare the messages
    messages: List[Dict[str, Any]] = [
        {
            "role": "system",
            "content": (
                "You are an expert at cleaning and correcting OCR transcriptions. "
                "You will be given an OCR transcription and an image of the original PDF page. "
                "Your task is to:\n"
                "1. Fix OCR errors and typos\n"
                "2. Correct formatting issues\n"
                "3. Restore proper punctuation and capitalization\n"
                "4. Fix word breaks and line breaks\n"
                "5. Ensure mathematical formulas and special characters are correct\n"
                "6. Maintain the semantic structure of the document\n"
                "Return a cleaned version that accurately represents the original document."
            )
        }
    ]

    # Add the content with the PDF image
    content: List[Dict[str, Any]] = [
        {
            "type": "text",
            "text": f"Please clean the following OCR transcription based on the provided PDF page image:\n\n{md_content}"
        },
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{pdf_image}"
            }
        }
    ]

    messages.append({
        "role": "user",
        "content": content
    })

    # Make the API call with structured output
    try:
        response = client.beta.chat.completions.parse(
            model=model,
            messages=messages,  # type: ignore
            response_format=CleanedDocument,
            temperature=0.2,  # Lower temperature for more consistent cleaning
            max_tokens=16384
        )
        parsed_result = response.choices[0].message.parsed
        if parsed_result is None:
            raise ValueError("ChatGPT returned no parsed result")
        return parsed_result
    except Exception as e:
        print(f"Error calling ChatGPT: {e}")
        raise


def process_document(
    doc_pair: DocumentPair,
    client: OpenAI,
    model: str,
    output_dir: Path,
    skip_existing: bool,
    verbose: bool
) -> Tuple[bool, str]:
    """Process a single document pair.

    Reads the markdown, renders the PDF page, cleans the text with ChatGPT,
    and writes both the cleaned text and a sidecar ``.json`` metadata file.

    Returns:
        (success, message) — never raises; failures are reported via the
        message so the thread-pool driver can tally them.
    """
    # Check if output already exists
    output_path = output_dir / doc_pair.relative_path
    if skip_existing and output_path.exists():
        return True, f"Skipped (already exists): {doc_pair.relative_path}"

    try:
        # Read the markdown content
        md_content = doc_pair.md_path.read_text(encoding='utf-8')

        # Render the single PDF page
        pdf_image = render_single_page_pdf(doc_pair.pdf_path)

        # Clean with ChatGPT
        cleaned_result = clean_document_with_chatgpt(
            client, model, md_content, pdf_image, verbose
        )

        # Create output directory if needed
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Write cleaned text
        output_path.write_text(cleaned_result.cleaned_text, encoding='utf-8')

        # Also write metadata
        metadata_path = output_path.with_suffix('.json')
        metadata = {
            'original_md': str(doc_pair.md_path),
            'original_pdf': str(doc_pair.pdf_path),
            'confidence_score': cleaned_result.confidence_score,
            'corrections_made': cleaned_result.corrections_made,
            'model': model,
            'pages_rendered': 1
        }
        metadata_path.write_text(json.dumps(metadata, indent=2), encoding='utf-8')

        return True, f"Processed: {doc_pair.relative_path} (confidence: {cleaned_result.confidence_score:.2f})"

    except Exception as e:
        return False, f"Error processing {doc_pair.relative_path}: {e}"


def main():
    """Entry point: discover, shuffle, and clean all document pairs."""
    args = parse_args()

    # Validate API key
    if not args.openai_api_key:
        print("Error: OpenAI API key is required. Set via --openai-api-key or OPENAI_API_KEY environment variable.")
        sys.exit(1)

    # Initialize OpenAI client
    client = OpenAI(api_key=args.openai_api_key)

    # Set up paths
    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)

    if not input_dir.exists():
        print(f"Error: Input directory {input_dir} does not exist.")
        sys.exit(1)

    output_dir.mkdir(parents=True, exist_ok=True)

    # Find all document pairs (single-page PDFs only)
    print(f"Scanning {input_dir} for single-page document pairs...")
    doc_pairs = find_document_pairs(input_dir, args.verbose)
    print(f"Found {len(doc_pairs)} valid single-page document pairs.")

    if not doc_pairs:
        print("No document pairs found.")
        return

    # Shuffle with fixed seed
    random.seed(args.seed)
    random.shuffle(doc_pairs)

    # Limit if requested.  Explicit None check so "--max-documents 0" means
    # "process nothing" rather than being silently ignored by truthiness.
    if args.max_documents is not None:
        doc_pairs = doc_pairs[:args.max_documents]
        print(f"Processing first {args.max_documents} documents after shuffling.")

    # Process documents in batches
    successful = 0
    failed = 0

    with ThreadPoolExecutor(max_workers=args.batch_size) as executor:
        futures = []
        for doc_pair in doc_pairs:
            future = executor.submit(
                process_document,
                doc_pair,
                client,
                args.model,
                output_dir,
                args.skip_existing,
                args.verbose
            )
            futures.append(future)

        # Process results with progress bar
        with tqdm(total=len(futures), desc="Processing documents") as pbar:
            for future in as_completed(futures):
                success, message = future.result()
                if success:
                    successful += 1
                    if args.verbose:
                        tqdm.write(message)
                else:
                    failed += 1
                    # Always surface failures, even without --verbose, so
                    # errors are not silently reduced to a counter.
                    tqdm.write(message)
                pbar.update(1)
                pbar.set_postfix({
                    'successful': successful,
                    'failed': failed
                })

    # Print summary
    print(f"\nProcessing complete:")
    print(f"  Successful: {successful}")
    print(f"  Failed: {failed}")
    print(f"  Output directory: {output_dir}")


if __name__ == "__main__":
    main()