mirror of
				https://github.com/allenai/olmocr.git
				synced 2025-10-31 10:04:26 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			392 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			392 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
| #!/usr/bin/env python3
 | |
| # Takes a dataset location in olmocr-mix format, (ex. a nested directory structure folder/subfolder/document.md with a corresponding folder/subfolder/document.pdf)
 | |
| # Then, it will randomly shuffle these (with a fixed seed), and prompt chatgpt to clean up the transcription, and output a cleaned document
 | |
| # Uses structured output to get a good result, then writes things back in the same format in a new root folder, preserving the original folder structure
 | |
| 
 | |
| import argparse
 | |
| import json
 | |
| import os
 | |
| import random
 | |
| import sys
 | |
| from pathlib import Path
 | |
| from typing import List, Tuple, Any, Dict, Optional
 | |
| from dataclasses import dataclass
 | |
| from concurrent.futures import ThreadPoolExecutor, as_completed
 | |
| 
 | |
| from pypdf import PdfReader
 | |
| from olmocr.data.renderpdf import render_pdf_to_base64png
 | |
| from openai import OpenAI
 | |
| from pydantic import BaseModel, Field
 | |
| from tqdm import tqdm
 | |
| 
 | |
| 
 | |
| # Structured output model for ChatGPT response
 | |
class CleanedDocument(BaseModel):
    """Structured-output schema for the ChatGPT cleaning response.

    NOTE: the ``Field`` descriptions are serialized into the JSON schema sent
    to the API, so they double as per-field instructions to the model — edit
    them with the same care as the system prompt.
    """

    cleaned_text: str = Field(description="The cleaned and corrected version of the OCR transcription")
    confidence_score: float = Field(description="Confidence score from 0 to 1 indicating how confident the model is in the cleaning", ge=0.0, le=1.0)
    corrections_made: List[str] = Field(description="List of major corrections or improvements made to the text")
    is_page_all_blank: bool = Field(description="Document consists entirely of blank page, or only headers/footers that would otherwise be removed")
    primary_language: str = Field(default="en", description="Primary language of the document (ISO 639-1 code, e.g. 'en' for English, 'es' for Spanish)")
    is_rotation_valid: bool = Field(default=True, description="Whether the page orientation/rotation appears correct")
    rotation_correction: int = Field(default=0, description="Degrees of rotation needed to correct orientation (0, 90, 180, or 270)")
    is_table: bool = Field(default=False, description="Whether the page primarily contains a table")
    is_diagram: bool = Field(default=False, description="Whether the page primarily contains a diagram or figure")
 | |
| 
 | |
| 
 | |
@dataclass
class DocumentPair:
    """A markdown OCR transcription paired with its source PDF."""

    md_path: Path   # path to the OCR transcription (.md)
    pdf_path: Path  # path to the corresponding source PDF
    relative_path: Path  # Relative path from root for preserving structure
 | |
| 
 | |
| 
 | |
def parse_args():
    """Parse command-line arguments for the OCR-cleaning script.

    Returns:
        argparse.Namespace with attributes: input_dir, output_dir,
        openai_api_key, model, seed, batch_size, max_documents,
        skip_existing, verbose.
    """
    parser = argparse.ArgumentParser(
        description="Clean OCR transcriptions using ChatGPT with visual PDF context"
    )
    parser.add_argument(
        "input_dir",
        help="Input directory containing olmocr-mix format data (MD files with corresponding PDFs)"
    )
    parser.add_argument(
        "output_dir",
        help="Output directory for cleaned documents (preserves folder structure)"
    )
    parser.add_argument(
        "--openai-api-key",
        help="OpenAI API key (can also be set via OPENAI_API_KEY environment variable)",
        default=os.getenv("OPENAI_API_KEY")
    )
    parser.add_argument(
        "--model",
        default="gpt-4o-2024-08-06",
        # Fix: help text previously claimed the default was "gpt-4o-mini"
        # while the actual default is gpt-4o-2024-08-06. %(default)s keeps
        # the help string in sync with the real default value.
        help="OpenAI model to use (default: %(default)s)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for shuffling documents (default: %(default)s)"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=10,
        help="Number of documents to process in parallel (default: %(default)s)"
    )
    parser.add_argument(
        "--max-documents",
        type=int,
        help="Maximum number of documents to process (useful for testing)"
    )
    parser.add_argument(
        "--skip-existing",
        action="store_true",
        help="Skip documents that already have cleaned versions in the output directory"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose output"
    )
    return parser.parse_args()
 | |
| 
 | |
| 
 | |
def check_single_page_pdf(pdf_path: Path) -> bool:
    """Return True iff the PDF at *pdf_path* contains exactly one page.

    Any failure (missing file, unreadable/corrupt PDF) is reported to
    stdout and treated as "not a single-page PDF".
    """
    try:
        with open(pdf_path, 'rb') as handle:
            page_count = len(PdfReader(handle).pages)
        return page_count == 1
    except Exception as exc:
        print(f"Error checking PDF {pdf_path}: {exc}")
        return False
 | |
| 
 | |
| 
 | |
def find_document_pairs(input_dir: Path, verbose: bool = False) -> List[DocumentPair]:
    """Recursively pair every .md file under *input_dir* with its sibling .pdf.

    Markdown files lacking a same-named PDF are skipped (and counted; a
    summary line is printed if any were skipped).
    """
    found: List[DocumentPair] = []
    missing_pdf_count = 0

    for markdown_file in input_dir.rglob("*.md"):
        # The PDF must sit next to the .md with the same stem.
        candidate_pdf = markdown_file.with_suffix(".pdf")
        if candidate_pdf.exists():
            found.append(
                DocumentPair(markdown_file, candidate_pdf, markdown_file.relative_to(input_dir))
            )
        else:
            if verbose:
                print(f"Warning: No PDF found for {markdown_file}")
            missing_pdf_count += 1

    if missing_pdf_count > 0:
        print(f"Skipped {missing_pdf_count} files without PDFs")

    return found
 | |
| 
 | |
| 
 | |
def render_single_page_pdf(pdf_path: Path) -> str:
    """Render page 1 of a single-page PDF to a base64-encoded PNG string.

    Raises:
        RuntimeError: If rendering fails for any reason (wraps the cause).
    """
    try:
        # Page index is always 1: callers validate single-page PDFs first.
        # target_longest_image_dim=2048 caps the render resolution.
        return render_pdf_to_base64png(str(pdf_path), 1, target_longest_image_dim=2048)
    except Exception as exc:
        raise RuntimeError(f"Could not render PDF {pdf_path}: {exc}")
 | |
| 
 | |
| 
 | |
def clean_document_with_chatgpt(
    client: OpenAI,
    model: str,
    md_content: str,
    pdf_image: str,
    verbose: bool = False
) -> CleanedDocument:
    """Use ChatGPT to clean the OCR transcription with PDF context.

    Args:
        client: Initialized OpenAI client.
        model: Chat model name to use for the completion.
        md_content: Raw OCR transcription (markdown) to be cleaned.
        pdf_image: Base64-encoded PNG render of the source PDF page.
        verbose: Currently unused inside this function; accepted for
            signature compatibility with callers.

    Returns:
        The parsed CleanedDocument structured response.

    Raises:
        ValueError: If the API response contains no parsed result.
        Exception: Any API/client error is printed and re-raised.
    """

    # Prepare the messages: a system prompt with the cleaning rules, then a
    # user turn combining the transcription text and the page image.
    messages: List[Dict[str, Any]] = [
        {
            "role": "system",
            "content": (
                "You are an expert at cleaning and correcting OCR transcriptions. "
                "You will be given an OCR transcription and an image of the original PDF page. "
                "Your task is to:\n"
                "1. Correct formatting issues.\n"
                "2. Preserve the exact spelling of words from the original document.\n"
                "3. Remove any original transcriber's marks and notes, usually indicated by [ and ] symbols.\n"
                "4. Fix word breaks and line breaks\n"
                "5. Ensure mathematical formulas and special characters are correct\n"
                "6. If there are any figures or charts, label them with the following markdown syntax \n"
                "7. Maintain the semantic structure of the document\n"
                "8. Remove any headers or footers that are not semantically relevant to the main document contents, ex page numbers, document classifications, etc.\n"
                "9. Convert tables into HTML format. Keep the syntax simple, but use <th> for header rows, and use rowspan and colspans appropriately. Don't use <br> inside of table cells, just split that into new rows as needed. Do NOT use LaTeX or Markdown table syntax.\n"
                "10. If the page is blank, you are allowed to return 'null' for the text.\n"
                "Return a cleaned version that accurately represents the original document."
            )
        }
    ]

    # Add the content with the PDF image (multimodal user message:
    # text part + base64 data-URL image part).
    content: List[Dict[str, Any]] = [
        {
            "type": "text",
            "text": f"Please clean the following OCR transcription based on the provided PDF page image:\n\n{md_content}"
        },
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{pdf_image}"
            }
        }
    ]

    messages.append({
        "role": "user",
        "content": content
    })

    # Make the API call with structured output: beta .parse() enforces the
    # CleanedDocument schema on the model's response.
    try:
        response = client.beta.chat.completions.parse(
            model=model,
            messages=messages,  # type: ignore
            response_format=CleanedDocument,
            temperature=0.2,  # Lower temperature for more consistent cleaning
            max_tokens=16384
        )

        parsed_result = response.choices[0].message.parsed
        if parsed_result is None:
            # Refusals / parse failures surface as parsed=None.
            raise ValueError("ChatGPT returned no parsed result")
        return parsed_result
    except Exception as e:
        print(f"Error calling ChatGPT: {e}")
        raise
 | |
| 
 | |
| 
 | |
def process_document(
    doc_pair: DocumentPair,
    client: OpenAI,
    model: str,
    output_dir: Path,
    skip_existing: bool,
    verbose: bool
) -> Tuple[bool, str]:
    """Process a single document pair end-to-end.

    Reads the markdown transcription, renders its single-page PDF, asks
    ChatGPT for a cleaned version, and writes under *output_dir*
    (preserving the input folder structure): the cleaned .md with a YAML
    front-matter header, symlinks to the original .md and .pdf, and a
    JSON metadata sidecar.

    Returns:
        (success, message): success is False for multi-page PDFs and for
        any processing error; the message describes the outcome and is
        inspected by main() to classify skips vs. failures.
    """

    # Check if output already exists
    output_path = output_dir / doc_pair.relative_path
    if skip_existing and output_path.exists():
        return True, f"Skipped (already exists): {doc_pair.relative_path}"

    try:
        # Check if PDF has exactly one page; multi-page PDFs are skipped
        # (main() classifies this via the "multi-page" substring below).
        if not check_single_page_pdf(doc_pair.pdf_path):
            return False, f"Skipped multi-page PDF: {doc_pair.pdf_path}"

        # Read the markdown content
        md_content = doc_pair.md_path.read_text(encoding='utf-8')

        # Render the single PDF page
        pdf_image = render_single_page_pdf(doc_pair.pdf_path)

        # Clean with ChatGPT
        cleaned_result = clean_document_with_chatgpt(
            client, model, md_content, pdf_image, verbose
        )

        # Create output directory if needed
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Prepare front matter: YAML-style header prepended to the output .md
        front_matter = f"""---
primary_language: {cleaned_result.primary_language}
is_rotation_valid: {str(cleaned_result.is_rotation_valid)}
rotation_correction: {cleaned_result.rotation_correction}
is_table: {str(cleaned_result.is_table)}
is_diagram: {str(cleaned_result.is_diagram)}
---"""

        # Write cleaned text with front matter
        if cleaned_result.is_page_all_blank:
            # For blank pages, write only the front matter, ending exactly after ---
            output_path.write_text(front_matter, encoding='utf-8')
        else:
            # Add front matter and cleaned text with a newline separator
            full_content = front_matter + "\n" + cleaned_result.cleaned_text
            output_path.write_text(full_content, encoding='utf-8')

        # Create soft link for the original MD file as .md.orig
        # (with_suffix replaces the trailing ".md", yielding "<name>.md.orig";
        # is_symlink() also catches broken symlinks that exists() misses).
        orig_md_link_path = output_path.with_suffix('.md.orig')
        if orig_md_link_path.exists() or orig_md_link_path.is_symlink():
            orig_md_link_path.unlink()
        orig_md_link_path.symlink_to(doc_pair.md_path.absolute())

        # Create soft link for the PDF file
        pdf_link_path = output_dir / doc_pair.relative_path.with_suffix('.pdf')
        if pdf_link_path.exists() or pdf_link_path.is_symlink():
            pdf_link_path.unlink()
        pdf_link_path.symlink_to(doc_pair.pdf_path.absolute())

        # Also write metadata ("<name>.json" sidecar) recording provenance
        # plus the model's self-reported quality signals.
        metadata_path = output_path.with_suffix('.json')
        metadata = {
            'original_md': str(doc_pair.md_path),
            'original_pdf': str(doc_pair.pdf_path),
            'confidence_score': cleaned_result.confidence_score,
            'corrections_made': cleaned_result.corrections_made,
            'is_page_all_blank': cleaned_result.is_page_all_blank,
            'primary_language': cleaned_result.primary_language,
            'is_rotation_valid': cleaned_result.is_rotation_valid,
            'rotation_correction': cleaned_result.rotation_correction,
            'is_table': cleaned_result.is_table,
            'is_diagram': cleaned_result.is_diagram,
            'model': model,
            'pages_rendered': 1
        }
        metadata_path.write_text(json.dumps(metadata, indent=2), encoding='utf-8')

        return True, f"Processed: {doc_pair.relative_path} (confidence: {cleaned_result.confidence_score:.2f})"

    except Exception as e:
        # Any per-document failure is contained here so one bad document
        # doesn't abort the whole batch.
        return False, f"Error processing {doc_pair.relative_path}: {e}"
 | |
| 
 | |
| 
 | |
def main():
    """CLI entry point: discover, shuffle, and clean document pairs in parallel."""
    args = parse_args()

    # Validate API key
    if not args.openai_api_key:
        print("Error: OpenAI API key is required. Set via --openai-api-key or OPENAI_API_KEY environment variable.")
        sys.exit(1)

    # Initialize OpenAI client
    client = OpenAI(api_key=args.openai_api_key)

    # Set up paths
    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)

    if not input_dir.exists():
        print(f"Error: Input directory {input_dir} does not exist.")
        sys.exit(1)

    output_dir.mkdir(parents=True, exist_ok=True)

    # Find all document pairs
    print(f"Scanning {input_dir} for document pairs...")
    doc_pairs = find_document_pairs(input_dir, args.verbose)
    print(f"Found {len(doc_pairs)} document pairs (will check page count during processing).")

    if not doc_pairs:
        print("No document pairs found.")
        return

    # Shuffle with fixed seed so the document selection is reproducible
    # across runs (e.g. when combined with --max-documents).
    random.seed(args.seed)
    random.shuffle(doc_pairs)

    # Limit if requested (NOTE: --max-documents 0 is falsy here and is
    # treated as "no limit").
    if args.max_documents:
        doc_pairs = doc_pairs[:args.max_documents]
        print(f"Processing first {args.max_documents} documents after shuffling.")

    # Process documents in batches
    successful = 0
    failed = 0
    skipped_multi_page = 0

    # Threads (not processes) suffice: the work is I/O-bound API calls.
    with ThreadPoolExecutor(max_workers=args.batch_size) as executor:
        futures = []

        for doc_pair in doc_pairs:
            future = executor.submit(
                process_document,
                doc_pair,
                client,
                args.model,
                output_dir,
                args.skip_existing,
                args.verbose
            )
            futures.append(future)

        # Process results with progress bar, in completion order.
        with tqdm(total=len(futures), desc="Processing documents") as pbar:
            for future in as_completed(futures):
                success, message = future.result()
                if success:
                    successful += 1
                else:
                    # Distinguish multi-page skips from real failures by
                    # substring-matching process_document's status message.
                    if "multi-page" in message.lower():
                        skipped_multi_page += 1
                    else:
                        failed += 1

                if args.verbose:
                    # tqdm.write keeps the progress bar intact while printing.
                    tqdm.write(message)

                pbar.update(1)
                pbar.set_postfix({
                    'successful': successful,
                    'skipped': skipped_multi_page,
                    'failed': failed
                })

    # Print summary
    print(f"\nProcessing complete:")
    print(f"  Successful: {successful}")
    print(f"  Skipped (multi-page): {skipped_multi_page}")
    print(f"  Failed (other errors): {failed}")
    print(f"  Output directory: {output_dir}")
 | |
| 
 | |
# Only run the pipeline when executed as a script (not on import).
if __name__ == "__main__":
    main()
 | 
