mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-19 03:59:09 +00:00
Cleaned up things
This commit is contained in:
parent
b689a8e5f8
commit
bade86fe91
@ -17,6 +17,7 @@ from rapidfuzz import distance
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import wandb
|
import wandb
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
@ -45,6 +46,27 @@ logging.basicConfig(
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
"""Get the rank of the current process in distributed training."""
|
||||||
|
# Check environment variables for rank information
|
||||||
|
rank = 0
|
||||||
|
|
||||||
|
# Try different environment variables that might contain rank
|
||||||
|
if "LOCAL_RANK" in os.environ:
|
||||||
|
rank = int(os.environ["LOCAL_RANK"])
|
||||||
|
elif "RANK" in os.environ:
|
||||||
|
rank = int(os.environ["RANK"])
|
||||||
|
elif dist.is_available() and dist.is_initialized():
|
||||||
|
rank = dist.get_rank()
|
||||||
|
|
||||||
|
return rank
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
"""Check if this is the main process (rank 0)."""
|
||||||
|
return get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
class OlmOCRBenchDataset(Dataset):
|
class OlmOCRBenchDataset(Dataset):
|
||||||
"""Dataset for loading PDF pages from Olmocr-bench format JSONL files."""
|
"""Dataset for loading PDF pages from Olmocr-bench format JSONL files."""
|
||||||
|
|
||||||
@ -622,6 +644,14 @@ def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], comp
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# Log rank information early
|
||||||
|
rank = get_rank()
|
||||||
|
if "LOCAL_RANK" in os.environ:
|
||||||
|
logger.info(f"LOCAL_RANK environment variable: {os.environ['LOCAL_RANK']}")
|
||||||
|
if "RANK" in os.environ:
|
||||||
|
logger.info(f"RANK environment variable: {os.environ['RANK']}")
|
||||||
|
logger.info(f"Current process rank: {rank}, is_main_process: {is_main_process()}")
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="GRPO training for OlmOCR")
|
parser = argparse.ArgumentParser(description="GRPO training for OlmOCR")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train_bench_data_folder",
|
"--train_bench_data_folder",
|
||||||
@ -779,14 +809,18 @@ def main():
|
|||||||
# Set up output directory
|
# Set up output directory
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
# Initialize wandb if enabled
|
# Initialize wandb only on the main process (rank 0)
|
||||||
wandb.init(
|
if is_main_process():
|
||||||
project=args.wandb_project,
|
wandb.init(
|
||||||
name=args.wandb_run_name,
|
project=args.wandb_project,
|
||||||
config=vars(args)
|
name=args.wandb_run_name,
|
||||||
)
|
config=vars(args)
|
||||||
logger.info(f"Initialized wandb project: {args.wandb_project}")
|
)
|
||||||
report_to = ["wandb"]
|
logger.info(f"Initialized wandb project: {args.wandb_project} (rank {get_rank()})")
|
||||||
|
report_to = ["wandb"]
|
||||||
|
else:
|
||||||
|
logger.info(f"Skipping wandb initialization on rank {get_rank()}")
|
||||||
|
report_to = [] # No reporting for non-main processes
|
||||||
|
|
||||||
|
|
||||||
# Verify train bench_data_folder exists
|
# Verify train bench_data_folder exists
|
||||||
@ -946,8 +980,9 @@ def main():
|
|||||||
|
|
||||||
logger.info("Training completed successfully!")
|
logger.info("Training completed successfully!")
|
||||||
|
|
||||||
# Close wandb
|
# Close wandb only on main process
|
||||||
wandb.finish()
|
if is_main_process():
|
||||||
|
wandb.finish()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Training failed: {e}")
|
logger.error(f"Training failed: {e}")
|
||||||
|
@ -25,7 +25,12 @@ class CleanedDocument(BaseModel):
|
|||||||
cleaned_text: str = Field(description="The cleaned and corrected version of the OCR transcription")
|
cleaned_text: str = Field(description="The cleaned and corrected version of the OCR transcription")
|
||||||
confidence_score: float = Field(description="Confidence score from 0 to 1 indicating how confident the model is in the cleaning", ge=0.0, le=1.0)
|
confidence_score: float = Field(description="Confidence score from 0 to 1 indicating how confident the model is in the cleaning", ge=0.0, le=1.0)
|
||||||
corrections_made: List[str] = Field(description="List of major corrections or improvements made to the text")
|
corrections_made: List[str] = Field(description="List of major corrections or improvements made to the text")
|
||||||
is_page_all_blank: bool = Field(description="Document consistents entire of blank page, or only headers/footers that would otherwise be removed")
|
is_page_all_blank: bool = Field(description="Document consists entirely of blank page, or only headers/footers that would otherwise be removed")
|
||||||
|
primary_language: str = Field(default="en", description="Primary language of the document (ISO 639-1 code, e.g. 'en' for English, 'es' for Spanish)")
|
||||||
|
is_rotation_valid: bool = Field(default=True, description="Whether the page orientation/rotation appears correct")
|
||||||
|
rotation_correction: int = Field(default=0, description="Degrees of rotation needed to correct orientation (0, 90, 180, or 270)")
|
||||||
|
is_table: bool = Field(default=False, description="Whether the page primarily contains a table")
|
||||||
|
is_diagram: bool = Field(default=False, description="Whether the page primarily contains a diagram or figure")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -238,11 +243,23 @@ def process_document(
|
|||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Write cleaned text
|
# Prepare front matter
|
||||||
|
front_matter = f"""---
|
||||||
|
primary_language: {cleaned_result.primary_language}
|
||||||
|
is_rotation_valid: {str(cleaned_result.is_rotation_valid)}
|
||||||
|
rotation_correction: {cleaned_result.rotation_correction}
|
||||||
|
is_table: {str(cleaned_result.is_table)}
|
||||||
|
is_diagram: {str(cleaned_result.is_diagram)}
|
||||||
|
---"""
|
||||||
|
|
||||||
|
# Write cleaned text with front matter
|
||||||
if cleaned_result.is_page_all_blank:
|
if cleaned_result.is_page_all_blank:
|
||||||
output_path.write_text("", encoding='utf-8')
|
# For blank pages, write only the front matter, ending exactly after ---
|
||||||
|
output_path.write_text(front_matter, encoding='utf-8')
|
||||||
else:
|
else:
|
||||||
output_path.write_text(cleaned_result.cleaned_text, encoding='utf-8')
|
# Add front matter and cleaned text with a newline separator
|
||||||
|
full_content = front_matter + "\n" + cleaned_result.cleaned_text
|
||||||
|
output_path.write_text(full_content, encoding='utf-8')
|
||||||
|
|
||||||
# Create soft link for the original MD file as .md.orig
|
# Create soft link for the original MD file as .md.orig
|
||||||
orig_md_link_path = output_path.with_suffix('.md.orig')
|
orig_md_link_path = output_path.with_suffix('.md.orig')
|
||||||
@ -263,6 +280,12 @@ def process_document(
|
|||||||
'original_pdf': str(doc_pair.pdf_path),
|
'original_pdf': str(doc_pair.pdf_path),
|
||||||
'confidence_score': cleaned_result.confidence_score,
|
'confidence_score': cleaned_result.confidence_score,
|
||||||
'corrections_made': cleaned_result.corrections_made,
|
'corrections_made': cleaned_result.corrections_made,
|
||||||
|
'is_page_all_blank': cleaned_result.is_page_all_blank,
|
||||||
|
'primary_language': cleaned_result.primary_language,
|
||||||
|
'is_rotation_valid': cleaned_result.is_rotation_valid,
|
||||||
|
'rotation_correction': cleaned_result.rotation_correction,
|
||||||
|
'is_table': cleaned_result.is_table,
|
||||||
|
'is_diagram': cleaned_result.is_diagram,
|
||||||
'model': model,
|
'model': model,
|
||||||
'pages_rendered': 1
|
'pages_rendered': 1
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user