diff --git a/olmocr/train/grpo_train.py b/olmocr/train/grpo_train.py index 43d7b3b..ea59445 100644 --- a/olmocr/train/grpo_train.py +++ b/olmocr/train/grpo_train.py @@ -17,6 +17,7 @@ from rapidfuzz import distance import sys import torch +import torch.distributed as dist import numpy as np import wandb from torch.utils.data import Dataset, DataLoader @@ -45,6 +46,27 @@ logging.basicConfig( logger = logging.getLogger(__name__) +def get_rank(): + """Get the rank of the current process in distributed training.""" + # Check environment variables for rank information + rank = 0 + + # Try different environment variables that might contain rank + if "LOCAL_RANK" in os.environ: + rank = int(os.environ["LOCAL_RANK"]) + elif "RANK" in os.environ: + rank = int(os.environ["RANK"]) + elif dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + + return rank + + +def is_main_process(): + """Check if this is the main process (rank 0).""" + return get_rank() == 0 + + class OlmOCRBenchDataset(Dataset): """Dataset for loading PDF pages from Olmocr-bench format JSONL files.""" @@ -622,6 +644,14 @@ def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], comp def main(): + # Log rank information early + rank = get_rank() + if "LOCAL_RANK" in os.environ: + logger.info(f"LOCAL_RANK environment variable: {os.environ['LOCAL_RANK']}") + if "RANK" in os.environ: + logger.info(f"RANK environment variable: {os.environ['RANK']}") + logger.info(f"Current process rank: {rank}, is_main_process: {is_main_process()}") + parser = argparse.ArgumentParser(description="GRPO training for OlmOCR") parser.add_argument( "--train_bench_data_folder", @@ -779,14 +809,18 @@ def main(): # Set up output directory os.makedirs(args.output_dir, exist_ok=True) - # Initialize wandb if enabled - wandb.init( - project=args.wandb_project, - name=args.wandb_run_name, - config=vars(args) - ) - logger.info(f"Initialized wandb project: {args.wandb_project}") - report_to = ["wandb"] + # Initialize wandb only on the main process (rank 0) + if is_main_process(): + wandb.init( + project=args.wandb_project, + name=args.wandb_run_name, + config=vars(args) + ) + logger.info(f"Initialized wandb project: {args.wandb_project} (rank {get_rank()})") + report_to = ["wandb"] + else: + logger.info(f"Skipping wandb initialization on rank {get_rank()}") + report_to = [] # No reporting for non-main processes # Verify train bench_data_folder exists @@ -946,8 +980,9 @@ def main(): logger.info("Training completed successfully!") - # Close wandb - wandb.finish() + # Close wandb only on main process + if is_main_process(): + wandb.finish() except Exception as e: logger.error(f"Training failed: {e}") diff --git a/scripts/clean_olmocrmix.py b/scripts/clean_olmocrmix.py index f2ab222..cfe6c76 100755 --- a/scripts/clean_olmocrmix.py +++ b/scripts/clean_olmocrmix.py @@ -25,7 +25,12 @@ class CleanedDocument(BaseModel): cleaned_text: str = Field(description="The cleaned and corrected version of the OCR transcription") confidence_score: float = Field(description="Confidence score from 0 to 1 indicating how confident the model is in the cleaning", ge=0.0, le=1.0) corrections_made: List[str] = Field(description="List of major corrections or improvements made to the text") - is_page_all_blank: bool = Field(description="Document consistents entire of blank page, or only headers/footers that would otherwise be removed") + is_page_all_blank: bool = Field(description="Document consists entirely of blank page, or only headers/footers that would otherwise be removed") + primary_language: str = Field(default="en", description="Primary language of the document (ISO 639-1 code, e.g. 'en' for English, 'es' for Spanish)") + is_rotation_valid: bool = Field(default=True, description="Whether the page orientation/rotation appears correct") + rotation_correction: int = Field(default=0, description="Degrees of rotation needed to correct orientation (0, 90, 180, or 270)") + is_table: bool = Field(default=False, description="Whether the page primarily contains a table") + is_diagram: bool = Field(default=False, description="Whether the page primarily contains a diagram or figure") @dataclass @@ -238,11 +243,23 @@ def process_document( # Create output directory if needed output_path.parent.mkdir(parents=True, exist_ok=True) - # Write cleaned text + # Prepare front matter + front_matter = f"""--- +primary_language: {cleaned_result.primary_language} +is_rotation_valid: {str(cleaned_result.is_rotation_valid)} +rotation_correction: {cleaned_result.rotation_correction} +is_table: {str(cleaned_result.is_table)} +is_diagram: {str(cleaned_result.is_diagram)} +---""" + + # Write cleaned text with front matter if cleaned_result.is_page_all_blank: - output_path.write_text("", encoding='utf-8') + # For blank pages, write only the front matter, ending exactly after --- + output_path.write_text(front_matter, encoding='utf-8') else: - output_path.write_text(cleaned_result.cleaned_text, encoding='utf-8') + # Add front matter and cleaned text with a newline separator + full_content = front_matter + "\n" + cleaned_result.cleaned_text + output_path.write_text(full_content, encoding='utf-8') # Create soft link for the original MD file as .md.orig orig_md_link_path = output_path.with_suffix('.md.orig') @@ -263,6 +280,12 @@ def process_document( 'original_pdf': str(doc_pair.pdf_path), 'confidence_score': cleaned_result.confidence_score, 'corrections_made': cleaned_result.corrections_made, + 'is_page_all_blank': cleaned_result.is_page_all_blank, + 'primary_language': cleaned_result.primary_language, + 'is_rotation_valid': cleaned_result.is_rotation_valid, + 'rotation_correction': cleaned_result.rotation_correction, + 'is_table': cleaned_result.is_table, + 'is_diagram': cleaned_result.is_diagram, 'model': model, 'pages_rendered': 1 }