mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-18 19:49:19 +00:00
Cleaned up things
This commit is contained in:
parent
b689a8e5f8
commit
bade86fe91
@ -17,6 +17,7 @@ from rapidfuzz import distance
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import numpy as np
|
||||
import wandb
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
@ -45,6 +46,27 @@ logging.basicConfig(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_rank():
    """Get the global rank of the current process in distributed training.

    Rank sources are consulted in this order, and the first one found wins:

    1. The ``RANK`` environment variable (global rank across all nodes).
    2. The ``LOCAL_RANK`` environment variable (rank within the current
       node), used only when ``RANK`` is not set.
    3. ``torch.distributed``, when it is available and already initialized.

    Returns:
        int: The rank of this process, or 0 when no rank information is
        available (e.g. a plain single-process run).

    Raises:
        ValueError: If the selected environment variable is not an integer.
    """
    # RANK must take precedence over LOCAL_RANK: on multi-node jobs every
    # node has a process with LOCAL_RANK == 0, so checking LOCAL_RANK first
    # would make one process per node believe it is the main process.
    for env_var in ("RANK", "LOCAL_RANK"):
        if env_var in os.environ:
            return int(os.environ[env_var])

    # No launcher-provided environment; ask the process group if one exists.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()

    return 0
|
||||
|
||||
|
||||
def is_main_process():
    """Report whether the current process is the main one.

    Returns:
        bool: True when ``get_rank()`` yields 0, False otherwise.
    """
    current_rank = get_rank()
    return current_rank == 0
|
||||
|
||||
|
||||
class OlmOCRBenchDataset(Dataset):
|
||||
"""Dataset for loading PDF pages from Olmocr-bench format JSONL files."""
|
||||
|
||||
@ -622,6 +644,14 @@ def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], comp
|
||||
|
||||
|
||||
def main():
|
||||
# Log rank information early
|
||||
rank = get_rank()
|
||||
if "LOCAL_RANK" in os.environ:
|
||||
logger.info(f"LOCAL_RANK environment variable: {os.environ['LOCAL_RANK']}")
|
||||
if "RANK" in os.environ:
|
||||
logger.info(f"RANK environment variable: {os.environ['RANK']}")
|
||||
logger.info(f"Current process rank: {rank}, is_main_process: {is_main_process()}")
|
||||
|
||||
parser = argparse.ArgumentParser(description="GRPO training for OlmOCR")
|
||||
parser.add_argument(
|
||||
"--train_bench_data_folder",
|
||||
@ -779,14 +809,18 @@ def main():
|
||||
# Set up output directory
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Initialize wandb if enabled
|
||||
wandb.init(
|
||||
project=args.wandb_project,
|
||||
name=args.wandb_run_name,
|
||||
config=vars(args)
|
||||
)
|
||||
logger.info(f"Initialized wandb project: {args.wandb_project}")
|
||||
report_to = ["wandb"]
|
||||
# Initialize wandb only on the main process (rank 0)
|
||||
if is_main_process():
|
||||
wandb.init(
|
||||
project=args.wandb_project,
|
||||
name=args.wandb_run_name,
|
||||
config=vars(args)
|
||||
)
|
||||
logger.info(f"Initialized wandb project: {args.wandb_project} (rank {get_rank()})")
|
||||
report_to = ["wandb"]
|
||||
else:
|
||||
logger.info(f"Skipping wandb initialization on rank {get_rank()}")
|
||||
report_to = [] # No reporting for non-main processes
|
||||
|
||||
|
||||
# Verify train bench_data_folder exists
|
||||
@ -946,8 +980,9 @@ def main():
|
||||
|
||||
logger.info("Training completed successfully!")
|
||||
|
||||
# Close wandb
|
||||
wandb.finish()
|
||||
# Close wandb only on main process
|
||||
if is_main_process():
|
||||
wandb.finish()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed: {e}")
|
||||
|
@ -25,7 +25,12 @@ class CleanedDocument(BaseModel):
|
||||
cleaned_text: str = Field(description="The cleaned and corrected version of the OCR transcription")
|
||||
confidence_score: float = Field(description="Confidence score from 0 to 1 indicating how confident the model is in the cleaning", ge=0.0, le=1.0)
|
||||
corrections_made: List[str] = Field(description="List of major corrections or improvements made to the text")
|
||||
is_page_all_blank: bool = Field(description="Document consistents entire of blank page, or only headers/footers that would otherwise be removed")
|
||||
is_page_all_blank: bool = Field(description="Document consists entirely of blank page, or only headers/footers that would otherwise be removed")
|
||||
primary_language: str = Field(default="en", description="Primary language of the document (ISO 639-1 code, e.g. 'en' for English, 'es' for Spanish)")
|
||||
is_rotation_valid: bool = Field(default=True, description="Whether the page orientation/rotation appears correct")
|
||||
rotation_correction: int = Field(default=0, description="Degrees of rotation needed to correct orientation (0, 90, 180, or 270)")
|
||||
is_table: bool = Field(default=False, description="Whether the page primarily contains a table")
|
||||
is_diagram: bool = Field(default=False, description="Whether the page primarily contains a diagram or figure")
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -238,11 +243,23 @@ def process_document(
|
||||
# Create output directory if needed
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write cleaned text
|
||||
# Prepare front matter
|
||||
front_matter = f"""---
|
||||
primary_language: {cleaned_result.primary_language}
|
||||
is_rotation_valid: {str(cleaned_result.is_rotation_valid)}
|
||||
rotation_correction: {cleaned_result.rotation_correction}
|
||||
is_table: {str(cleaned_result.is_table)}
|
||||
is_diagram: {str(cleaned_result.is_diagram)}
|
||||
---"""
|
||||
|
||||
# Write cleaned text with front matter
|
||||
if cleaned_result.is_page_all_blank:
|
||||
output_path.write_text("", encoding='utf-8')
|
||||
# For blank pages, write only the front matter, ending exactly after ---
|
||||
output_path.write_text(front_matter, encoding='utf-8')
|
||||
else:
|
||||
output_path.write_text(cleaned_result.cleaned_text, encoding='utf-8')
|
||||
# Add front matter and cleaned text with a newline separator
|
||||
full_content = front_matter + "\n" + cleaned_result.cleaned_text
|
||||
output_path.write_text(full_content, encoding='utf-8')
|
||||
|
||||
# Create soft link for the original MD file as .md.orig
|
||||
orig_md_link_path = output_path.with_suffix('.md.orig')
|
||||
@ -263,6 +280,12 @@ def process_document(
|
||||
'original_pdf': str(doc_pair.pdf_path),
|
||||
'confidence_score': cleaned_result.confidence_score,
|
||||
'corrections_made': cleaned_result.corrections_made,
|
||||
'is_page_all_blank': cleaned_result.is_page_all_blank,
|
||||
'primary_language': cleaned_result.primary_language,
|
||||
'is_rotation_valid': cleaned_result.is_rotation_valid,
|
||||
'rotation_correction': cleaned_result.rotation_correction,
|
||||
'is_table': cleaned_result.is_table,
|
||||
'is_diagram': cleaned_result.is_diagram,
|
||||
'model': model,
|
||||
'pages_rendered': 1
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user