diff --git a/olmocr/train/grpo_train.py b/olmocr/train/grpo_train.py
index ce53891..cb5e227 100644
--- a/olmocr/train/grpo_train.py
+++ b/olmocr/train/grpo_train.py
@@ -62,11 +62,58 @@ class OlmOCRBenchDataset(Dataset):
         if not os.path.exists(self.pdf_folder):
             raise ValueError(f"PDFs folder not found at {self.pdf_folder}")

+        # Set claude_original folder path
+        self.claude_original_folder = os.path.join(bench_data_folder, "claude_original")
+        if os.path.exists(self.claude_original_folder):
+            logger.info(f"Found claude_original folder at {self.claude_original_folder}")
+        else:
+            logger.warning(f"No claude_original folder found at {self.claude_original_folder}")
+
         # Load unique PDFs from JSONL files
         self.samples = self._load_unique_pdfs_from_jsonl()

         logger.info(f"Created dataset with {len(self.samples)} unique PDF samples")

+    def _load_claude_original(self, pdf_name: str, page: int) -> Optional[str]:
+        """Load the claude_original markdown file for a given PDF and page."""
+        if not os.path.exists(self.claude_original_folder):
+            return None
+
+        # Extract the base PDF name and construct the expected filename
+        # pdf_name like "s2pdf/pdf_00017_page2.pdf" -> construct the markdown filename
+        pdf_base = os.path.basename(pdf_name).replace(".pdf", "")
+
+        # Handle case where page is already in the filename
+        if "_page" in pdf_base:
+            pdf_base_parts = pdf_base.split("_page")
+            pdf_base_name = pdf_base_parts[0]
+            # Use the page from the filename if it exists
+            page_from_name = int(pdf_base_parts[1]) if len(pdf_base_parts) > 1 and pdf_base_parts[1].isdigit() else page
+        else:
+            pdf_base_name = pdf_base
+            page_from_name = page
+
+        # Extract folder structure from pdf_name (e.g., "s2pdf/" or "arxiv_math/")
+        pdf_dir = os.path.dirname(pdf_name)
+
+        # Construct the expected claude_original filename
+        # Format: pdf_00017_page2_pg1_repeat1.md
+        claude_filename = f"{pdf_base_name}_page{page_from_name}_pg1_repeat1.md"
+
+        # Build the full path to the claude_original file
+        claude_file_path = os.path.join(self.claude_original_folder, pdf_dir, claude_filename)
+
+        if os.path.exists(claude_file_path):
+            try:
+                with open(claude_file_path, 'r', encoding='utf-8') as f:
+                    return f.read()
+            except Exception as e:
+                logger.warning(f"Failed to read claude_original file {claude_file_path}: {e}")
+        else:
+            logger.debug(f"Claude original file not found: {claude_file_path}")
+
+        return None
+
     def _load_unique_pdfs_from_jsonl(self) -> List[Dict[str, Any]]:
         """Load unique PDFs from JSONL files in the bench_data folder, tracking all test cases per PDF."""
         jsonl_files = sorted(glob.glob(os.path.join(self.bench_data_folder, "*.jsonl")))
@@ -97,13 +144,15 @@ class OlmOCRBenchDataset(Dataset):
                 if pdf_page_key not in pdf_data:
                     # First time seeing this PDF+page
                     pdf_path = os.path.join(self.pdf_folder, pdf_name)
+                    claude_original = self._load_claude_original(pdf_name, page)
                     pdf_data[pdf_page_key] = {
                         "pdf_path": pdf_path,
                         "pdf_name": pdf_name,
                         "page": page,
                         "jsonl_file": jsonl_file,
                         "test_ids": [test_id],
-                        "entries": [entry]
+                        "entries": [entry],
+                        "claude_original": claude_original
                     }
                 else:
                     # Add test case to existing PDF+page
@@ -167,6 +216,7 @@ class OlmOCRBenchDataset(Dataset):
                 "jsonl_file": jsonl_file,
                 "test_ids": test_ids,
                 "image": image,  # Include the PIL image for processing later
+                "claude_original": sample.get("claude_original"),  # Include claude_original if available
             }

         except Exception as e:
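The filename convention that `_load_claude_original` expects is easiest to see end to end. Below is a minimal standalone sketch of the same mapping; the helper name and the sample path are illustrative, not part of the patch:

```python
import os

def claude_original_path(claude_folder: str, pdf_name: str, page: int) -> str:
    """Standalone mirror of the _load_claude_original lookup above."""
    # "s2pdf/pdf_00017_page2.pdf" -> base "pdf_00017_page2"
    base = os.path.basename(pdf_name).replace(".pdf", "")
    # Prefer the page number embedded in the filename over the argument
    if "_page" in base:
        name, _, suffix = base.partition("_page")
        page = int(suffix) if suffix.isdigit() else page
    else:
        name = base
    # Preserve the source subfolder ("s2pdf", "arxiv_math", ...) and append
    # the fixed "_pg1_repeat1" suffix used for the claude_original dumps
    return os.path.join(claude_folder, os.path.dirname(pdf_name),
                        f"{name}_page{page}_pg1_repeat1.md")

# Prints: claude_original/s2pdf/pdf_00017_page2_pg1_repeat1.md
print(claude_original_path("claude_original", "s2pdf/pdf_00017_page2.pdf", 1))
```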
@@ -276,6 +326,61 @@ def evaluate_single_completion(args: Tuple[int, Any, str, str, List[str]]) -> Tu
         return i, None


+def bench_edit_distance_reward(prompts, completions: list[str] | list[list[dict]], claude_original: list[Optional[str]], **kwargs):
+    """
+    Reward function based on edit distance similarity to claude_original files.
+
+    Calculates the normalized edit distance between each completion and its corresponding
+    claude_original reference. Returns 1.0 for perfect match, lower for more distance.
+
+    Args:
+        prompts: List of prompts
+        completions: List of generated completions (model outputs)
+        claude_original: List of claude_original reference texts (one per completion)
+        **kwargs: Additional arguments
+
+    Returns:
+        List of reward scores between 0 and 1, where 1.0 is perfect match
+    """
+    logger.info(f"Running bench edit distance reward function for {len(completions)} completions")
+
+    rewards = []
+
+    for i, completion in enumerate(completions):
+        # Extract text from completion
+        if isinstance(completion, list):
+            comp_text = completion[0]["content"] if completion else ""
+        elif isinstance(completion, str):
+            comp_text = completion
+        else:
+            comp_text = ""
+
+        # Get the corresponding claude_original reference
+        reference = claude_original[i] if i < len(claude_original) else None
+
+        if reference is None:
+            logger.warning(f"No claude_original reference for completion {i}")
+            rewards.append(0.0)
+            continue
+
+        # Calculate edit distance
+        dist = distance.Levenshtein.distance(comp_text, reference)
+
+        # Calculate maximum possible distance (length of longer string)
+        max_dist = max(len(comp_text), len(reference))
+
+        # Calculate similarity (1.0 = perfect match, 0.0 = completely different)
+        if max_dist == 0:
+            similarity = 1.0  # Both empty strings
+        else:
+            similarity = 1.0 - (dist / max_dist)
+
+        rewards.append(max(0.0, similarity))  # Ensure non-negative
+
+    logger.info(f"Bench edit distance rewards range: [{min(rewards) if rewards else 0:.3f}, {max(rewards) if rewards else 0:.3f}]")
+    return rewards
+
+
 def medoid_reward(prompts, completions: list[str] | list[list[dict]], **kwargs):
     """
     Reward function based on edit distance to the medoid completion.
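The `distance.Levenshtein.distance(...)` call matches the rapidfuzz API, so the new function presumably relies on a `from rapidfuzz import distance` import already present at the top of the module (the existing `medoid_reward` is likewise edit-distance based). The normalization is the classic length-normalized Levenshtein similarity; a self-contained sketch of the same arithmetic, assuming rapidfuzz is installed:

```python
from rapidfuzz import distance

a, b = "kitten", "sitting"

# Same arithmetic as bench_edit_distance_reward: 1.0 means identical
# strings, 0.0 means every character position differs
d = distance.Levenshtein.distance(a, b)        # 3 edits
max_dist = max(len(a), len(b))                 # 7, the worst case
similarity = 1.0 - d / max_dist if max_dist else 1.0
print(similarity)                              # 0.571...

# For default (unit) edit weights this should agree with rapidfuzz's
# built-in normalization
assert abs(similarity - distance.Levenshtein.normalized_similarity(a, b)) < 1e-9
```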
@@ -513,15 +618,27 @@ def main():
     )
     parser.add_argument(
         "--reward_bench",
-        action="store_true",
-        default=False,
-        help="Use bench-based reward function (test pass rate)"
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use bench-based reward function with optional weight (default: 1.0)"
     )
     parser.add_argument(
         "--reward_medoid",
-        action="store_true",
-        default=False,
-        help="Use medoid-based reward function (edit distance similarity)"
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use medoid-based reward function with optional weight (default: 1.0)"
+    )
+    parser.add_argument(
+        "--reward_bench_edit_distance",
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use bench edit distance reward with optional weight (default: 1.0)"
     )

     args = parser.parse_args()
@@ -632,22 +749,40 @@
             log_completions=True,
         )

-    # Build list of reward functions based on command-line arguments
+    # Build list of reward functions and weights based on command-line arguments
     reward_funcs = []
+    reward_weights = []
+    reward_names = []

-    if args.reward_bench:
+    if args.reward_bench is not None:
         reward_funcs.append(olmocr_bench_reward)
-        logger.info("Added bench-based reward function")
+        reward_weights.append(args.reward_bench)
+        reward_names.append("bench")
+        logger.info(f"Added bench-based reward function with weight {args.reward_bench}")

-    if args.reward_medoid:
+    if args.reward_medoid is not None:
         reward_funcs.append(medoid_reward)
-        logger.info("Added medoid-based reward function")
+        reward_weights.append(args.reward_medoid)
+        reward_names.append("medoid")
+        logger.info(f"Added medoid-based reward function with weight {args.reward_medoid}")
+
+    if args.reward_bench_edit_distance is not None:
+        reward_funcs.append(bench_edit_distance_reward)
+        reward_weights.append(args.reward_bench_edit_distance)
+        reward_names.append("bench_edit_distance")
+        logger.info(f"Added bench edit distance reward function with weight {args.reward_bench_edit_distance}")

     if not reward_funcs:
-        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid")
+        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance")
         return

-    logger.info(f"Using {len(reward_funcs)} reward function(s)")
+    # Log summary of reward configuration
+    logger.info(f"\n" + "="*50)
+    logger.info(f"Reward Configuration:")
+    logger.info(f"Using {len(reward_funcs)} reward function(s):")
+    for name, weight in zip(reward_names, reward_weights):
+        logger.info(f"  - {name}: weight={weight}")
+    logger.info("="*50 + "\n")

     # Initialize GRPO trainer
     logger.info("Initializing GRPO trainer")
@@ -658,6 +793,7 @@
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         reward_funcs=reward_funcs,
+        reward_weights=reward_weights,
     )

     # Start training
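The switch from `action="store_true"` to `nargs='?'` with `const` gives each reward flag three states: omitted (the reward stays disabled via `default=None`), bare (weight `const=1.0`), or an explicit weight. A quick standalone demonstration of that argparse pattern:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--reward_bench", nargs='?', const=1.0, type=float, default=None)

print(parser.parse_args([]).reward_bench)                          # None -> reward disabled
print(parser.parse_args(["--reward_bench"]).reward_bench)         # 1.0  -> const, default weight
print(parser.parse_args(["--reward_bench", "0.25"]).reward_bench) # 0.25 -> explicit weight
```

One point worth double-checking against the pinned TRL version: upstream TRL exposes `reward_weights` as a `GRPOConfig` field rather than a `GRPOTrainer.__init__` keyword, so the weights may need to be passed through the config object instead of the trainer constructor as in the last hunk.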