Cleaned up things

This commit is contained in:
Jake Poznanski 2025-09-03 20:23:01 +00:00
parent b689a8e5f8
commit bade86fe91
2 changed files with 72 additions and 14 deletions

View File

@ -17,6 +17,7 @@ from rapidfuzz import distance
import sys import sys
import torch import torch
import torch.distributed as dist
import numpy as np import numpy as np
import wandb import wandb
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
@ -45,6 +46,27 @@ logging.basicConfig(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_rank():
    """Return the global rank of the current process in distributed training.

    The rank is resolved from the first source that is available:

    1. The ``RANK`` environment variable (global rank across all nodes).
    2. The ``LOCAL_RANK`` environment variable (rank within the current node).
    3. ``torch.distributed.get_rank()`` if a process group is initialized.
    4. ``0`` when none of the above is available (non-distributed run).

    Returns:
        int: The resolved rank.

    Raises:
        ValueError: If the selected environment variable is not an integer.
    """
    # RANK must win over LOCAL_RANK: LOCAL_RANK restarts at 0 on every node,
    # so on a multi-node job it would report rank 0 once per node and callers
    # that gate "main process only" work on rank 0 would run it on each node.
    for env_var in ("RANK", "LOCAL_RANK"):
        raw_value = os.environ.get(env_var)
        if raw_value is not None:
            return int(raw_value)

    # No launcher-provided environment; ask the process group if one exists.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()

    return 0
def is_main_process():
    """Return True when the current process has rank 0, False otherwise."""
    current_rank = get_rank()
    return current_rank == 0
class OlmOCRBenchDataset(Dataset): class OlmOCRBenchDataset(Dataset):
"""Dataset for loading PDF pages from Olmocr-bench format JSONL files.""" """Dataset for loading PDF pages from Olmocr-bench format JSONL files."""
@ -622,6 +644,14 @@ def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], comp
def main(): def main():
# Log rank information early
rank = get_rank()
if "LOCAL_RANK" in os.environ:
logger.info(f"LOCAL_RANK environment variable: {os.environ['LOCAL_RANK']}")
if "RANK" in os.environ:
logger.info(f"RANK environment variable: {os.environ['RANK']}")
logger.info(f"Current process rank: {rank}, is_main_process: {is_main_process()}")
parser = argparse.ArgumentParser(description="GRPO training for OlmOCR") parser = argparse.ArgumentParser(description="GRPO training for OlmOCR")
parser.add_argument( parser.add_argument(
"--train_bench_data_folder", "--train_bench_data_folder",
@ -779,14 +809,18 @@ def main():
# Set up output directory # Set up output directory
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
# Initialize wandb if enabled # Initialize wandb only on the main process (rank 0)
wandb.init( if is_main_process():
project=args.wandb_project, wandb.init(
name=args.wandb_run_name, project=args.wandb_project,
config=vars(args) name=args.wandb_run_name,
) config=vars(args)
logger.info(f"Initialized wandb project: {args.wandb_project}") )
report_to = ["wandb"] logger.info(f"Initialized wandb project: {args.wandb_project} (rank {get_rank()})")
report_to = ["wandb"]
else:
logger.info(f"Skipping wandb initialization on rank {get_rank()}")
report_to = [] # No reporting for non-main processes
# Verify train bench_data_folder exists # Verify train bench_data_folder exists
@ -946,8 +980,9 @@ def main():
logger.info("Training completed successfully!") logger.info("Training completed successfully!")
# Close wandb # Close wandb only on main process
wandb.finish() if is_main_process():
wandb.finish()
except Exception as e: except Exception as e:
logger.error(f"Training failed: {e}") logger.error(f"Training failed: {e}")

View File

@ -25,7 +25,12 @@ class CleanedDocument(BaseModel):
cleaned_text: str = Field(description="The cleaned and corrected version of the OCR transcription") cleaned_text: str = Field(description="The cleaned and corrected version of the OCR transcription")
confidence_score: float = Field(description="Confidence score from 0 to 1 indicating how confident the model is in the cleaning", ge=0.0, le=1.0) confidence_score: float = Field(description="Confidence score from 0 to 1 indicating how confident the model is in the cleaning", ge=0.0, le=1.0)
corrections_made: List[str] = Field(description="List of major corrections or improvements made to the text") corrections_made: List[str] = Field(description="List of major corrections or improvements made to the text")
is_page_all_blank: bool = Field(description="Document consistents entire of blank page, or only headers/footers that would otherwise be removed") is_page_all_blank: bool = Field(description="Document consists entirely of blank page, or only headers/footers that would otherwise be removed")
primary_language: str = Field(default="en", description="Primary language of the document (ISO 639-1 code, e.g. 'en' for English, 'es' for Spanish)")
is_rotation_valid: bool = Field(default=True, description="Whether the page orientation/rotation appears correct")
rotation_correction: int = Field(default=0, description="Degrees of rotation needed to correct orientation (0, 90, 180, or 270)")
is_table: bool = Field(default=False, description="Whether the page primarily contains a table")
is_diagram: bool = Field(default=False, description="Whether the page primarily contains a diagram or figure")
@dataclass @dataclass
@ -238,11 +243,23 @@ def process_document(
# Create output directory if needed # Create output directory if needed
output_path.parent.mkdir(parents=True, exist_ok=True) output_path.parent.mkdir(parents=True, exist_ok=True)
# Write cleaned text # Prepare front matter
front_matter = f"""---
primary_language: {cleaned_result.primary_language}
is_rotation_valid: {str(cleaned_result.is_rotation_valid)}
rotation_correction: {cleaned_result.rotation_correction}
is_table: {str(cleaned_result.is_table)}
is_diagram: {str(cleaned_result.is_diagram)}
---"""
# Write cleaned text with front matter
if cleaned_result.is_page_all_blank: if cleaned_result.is_page_all_blank:
output_path.write_text("", encoding='utf-8') # For blank pages, write only the front matter, ending exactly after ---
output_path.write_text(front_matter, encoding='utf-8')
else: else:
output_path.write_text(cleaned_result.cleaned_text, encoding='utf-8') # Add front matter and cleaned text with a newline separator
full_content = front_matter + "\n" + cleaned_result.cleaned_text
output_path.write_text(full_content, encoding='utf-8')
# Create soft link for the original MD file as .md.orig # Create soft link for the original MD file as .md.orig
orig_md_link_path = output_path.with_suffix('.md.orig') orig_md_link_path = output_path.with_suffix('.md.orig')
@ -263,6 +280,12 @@ def process_document(
'original_pdf': str(doc_pair.pdf_path), 'original_pdf': str(doc_pair.pdf_path),
'confidence_score': cleaned_result.confidence_score, 'confidence_score': cleaned_result.confidence_score,
'corrections_made': cleaned_result.corrections_made, 'corrections_made': cleaned_result.corrections_made,
'is_page_all_blank': cleaned_result.is_page_all_blank,
'primary_language': cleaned_result.primary_language,
'is_rotation_valid': cleaned_result.is_rotation_valid,
'rotation_correction': cleaned_result.rotation_correction,
'is_table': cleaned_result.is_table,
'is_diagram': cleaned_result.is_diagram,
'model': model, 'model': model,
'pages_rendered': 1 'pages_rendered': 1
} }