Cleaned up things

This commit is contained in:
Jake Poznanski 2025-09-03 20:23:01 +00:00
parent b689a8e5f8
commit bade86fe91
2 changed files with 72 additions and 14 deletions

View File

@ -17,6 +17,7 @@ from rapidfuzz import distance
import sys
import torch
import torch.distributed as dist
import numpy as np
import wandb
from torch.utils.data import Dataset, DataLoader
@ -45,6 +46,27 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
def get_rank():
"""Get the rank of the current process in distributed training."""
# Check environment variables for rank information
rank = 0
# Try different environment variables that might contain rank
if "LOCAL_RANK" in os.environ:
rank = int(os.environ["LOCAL_RANK"])
elif "RANK" in os.environ:
rank = int(os.environ["RANK"])
elif dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
return rank
def is_main_process():
"""Check if this is the main process (rank 0)."""
return get_rank() == 0
class OlmOCRBenchDataset(Dataset):
"""Dataset for loading PDF pages from Olmocr-bench format JSONL files."""
@ -622,6 +644,14 @@ def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], comp
def main():
# Log rank information early
rank = get_rank()
if "LOCAL_RANK" in os.environ:
logger.info(f"LOCAL_RANK environment variable: {os.environ['LOCAL_RANK']}")
if "RANK" in os.environ:
logger.info(f"RANK environment variable: {os.environ['RANK']}")
logger.info(f"Current process rank: {rank}, is_main_process: {is_main_process()}")
parser = argparse.ArgumentParser(description="GRPO training for OlmOCR")
parser.add_argument(
"--train_bench_data_folder",
@ -779,14 +809,18 @@ def main():
# Set up output directory
os.makedirs(args.output_dir, exist_ok=True)
# Initialize wandb if enabled
wandb.init(
project=args.wandb_project,
name=args.wandb_run_name,
config=vars(args)
)
logger.info(f"Initialized wandb project: {args.wandb_project}")
report_to = ["wandb"]
# Initialize wandb only on the main process (rank 0)
if is_main_process():
wandb.init(
project=args.wandb_project,
name=args.wandb_run_name,
config=vars(args)
)
logger.info(f"Initialized wandb project: {args.wandb_project} (rank {get_rank()})")
report_to = ["wandb"]
else:
logger.info(f"Skipping wandb initialization on rank {get_rank()}")
report_to = [] # No reporting for non-main processes
# Verify train bench_data_folder exists
@ -946,8 +980,9 @@ def main():
logger.info("Training completed successfully!")
# Close wandb
wandb.finish()
# Close wandb only on main process
if is_main_process():
wandb.finish()
except Exception as e:
logger.error(f"Training failed: {e}")

View File

@ -25,7 +25,12 @@ class CleanedDocument(BaseModel):
cleaned_text: str = Field(description="The cleaned and corrected version of the OCR transcription")
confidence_score: float = Field(description="Confidence score from 0 to 1 indicating how confident the model is in the cleaning", ge=0.0, le=1.0)
corrections_made: List[str] = Field(description="List of major corrections or improvements made to the text")
is_page_all_blank: bool = Field(description="Document consistents entire of blank page, or only headers/footers that would otherwise be removed")
is_page_all_blank: bool = Field(description="Document consists entirely of blank page, or only headers/footers that would otherwise be removed")
primary_language: str = Field(default="en", description="Primary language of the document (ISO 639-1 code, e.g. 'en' for English, 'es' for Spanish)")
is_rotation_valid: bool = Field(default=True, description="Whether the page orientation/rotation appears correct")
rotation_correction: int = Field(default=0, description="Degrees of rotation needed to correct orientation (0, 90, 180, or 270)")
is_table: bool = Field(default=False, description="Whether the page primarily contains a table")
is_diagram: bool = Field(default=False, description="Whether the page primarily contains a diagram or figure")
@dataclass
@ -238,11 +243,23 @@ def process_document(
# Create output directory if needed
output_path.parent.mkdir(parents=True, exist_ok=True)
# Write cleaned text
# Prepare front matter
front_matter = f"""---
primary_language: {cleaned_result.primary_language}
is_rotation_valid: {str(cleaned_result.is_rotation_valid)}
rotation_correction: {cleaned_result.rotation_correction}
is_table: {str(cleaned_result.is_table)}
is_diagram: {str(cleaned_result.is_diagram)}
---"""
# Write cleaned text with front matter
if cleaned_result.is_page_all_blank:
output_path.write_text("", encoding='utf-8')
# For blank pages, write only the front matter, ending exactly after ---
output_path.write_text(front_matter, encoding='utf-8')
else:
output_path.write_text(cleaned_result.cleaned_text, encoding='utf-8')
# Add front matter and cleaned text with a newline separator
full_content = front_matter + "\n" + cleaned_result.cleaned_text
output_path.write_text(full_content, encoding='utf-8')
# Create soft link for the original MD file as .md.orig
orig_md_link_path = output_path.with_suffix('.md.orig')
@ -263,6 +280,12 @@ def process_document(
'original_pdf': str(doc_pair.pdf_path),
'confidence_score': cleaned_result.confidence_score,
'corrections_made': cleaned_result.corrections_made,
'is_page_all_blank': cleaned_result.is_page_all_blank,
'primary_language': cleaned_result.primary_language,
'is_rotation_valid': cleaned_result.is_rotation_valid,
'rotation_correction': cleaned_result.rotation_correction,
'is_table': cleaned_result.is_table,
'is_diagram': cleaned_result.is_diagram,
'model': model,
'pages_rendered': 1
}