Adding additional rewards and weights

Jake Poznanski 2025-08-26 21:49:35 +00:00
parent 9c520498dd
commit c86b413d3e


@@ -62,11 +62,58 @@ class OlmOCRBenchDataset(Dataset):
         if not os.path.exists(self.pdf_folder):
             raise ValueError(f"PDFs folder not found at {self.pdf_folder}")
 
+        # Set claude_original folder path
+        self.claude_original_folder = os.path.join(bench_data_folder, "claude_original")
+        if os.path.exists(self.claude_original_folder):
+            logger.info(f"Found claude_original folder at {self.claude_original_folder}")
+        else:
+            logger.warning(f"No claude_original folder found at {self.claude_original_folder}")
+
         # Load unique PDFs from JSONL files
         self.samples = self._load_unique_pdfs_from_jsonl()
         logger.info(f"Created dataset with {len(self.samples)} unique PDF samples")
 
+    def _load_claude_original(self, pdf_name: str, page: int) -> Optional[str]:
+        """Load the claude_original markdown file for a given PDF and page."""
+        if not os.path.exists(self.claude_original_folder):
+            return None
+
+        # Extract the base PDF name and construct the expected filename
+        # pdf_name like "s2pdf/pdf_00017_page2.pdf" -> construct the markdown filename
+        pdf_base = os.path.basename(pdf_name).replace(".pdf", "")
+
+        # Handle case where page is already in the filename
+        if "_page" in pdf_base:
+            pdf_base_parts = pdf_base.split("_page")
+            pdf_base_name = pdf_base_parts[0]
+            # Use the page from the filename if it exists
+            page_from_name = int(pdf_base_parts[1]) if len(pdf_base_parts) > 1 and pdf_base_parts[1].isdigit() else page
+        else:
+            pdf_base_name = pdf_base
+            page_from_name = page
+
+        # Extract folder structure from pdf_name (e.g., "s2pdf/" or "arxiv_math/")
+        pdf_dir = os.path.dirname(pdf_name)
+
+        # Construct the expected claude_original filename
+        # Format: pdf_00017_page2_pg1_repeat1.md
+        claude_filename = f"{pdf_base_name}_page{page_from_name}_pg1_repeat1.md"
+
+        # Build the full path to the claude_original file
+        claude_file_path = os.path.join(self.claude_original_folder, pdf_dir, claude_filename)
+
+        if os.path.exists(claude_file_path):
+            try:
+                with open(claude_file_path, 'r', encoding='utf-8') as f:
+                    return f.read()
+            except Exception as e:
+                logger.warning(f"Failed to read claude_original file {claude_file_path}: {e}")
+        else:
+            logger.debug(f"Claude original file not found: {claude_file_path}")
+
+        return None
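For illustration, here is a standalone sketch of the path mapping the new _load_claude_original helper performs. The claude_md_path function and the example paths are hypothetical, not part of the commit; the logic mirrors the lookup above.

import os

def claude_md_path(claude_root: str, pdf_name: str, page: int) -> str:
    # Mirrors _load_claude_original: strip ".pdf", recover the page number
    # embedded in the filename, and rebuild the expected markdown path.
    base = os.path.basename(pdf_name).replace(".pdf", "")
    if "_page" in base:
        name, _, suffix = base.partition("_page")
        base = name
        page = int(suffix) if suffix.isdigit() else page
    return os.path.join(claude_root, os.path.dirname(pdf_name), f"{base}_page{page}_pg1_repeat1.md")

print(claude_md_path("bench_data/claude_original", "s2pdf/pdf_00017_page2.pdf", 2))
# bench_data/claude_original/s2pdf/pdf_00017_page2_pg1_repeat1.md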
     def _load_unique_pdfs_from_jsonl(self) -> List[Dict[str, Any]]:
         """Load unique PDFs from JSONL files in the bench_data folder, tracking all test cases per PDF."""
         jsonl_files = sorted(glob.glob(os.path.join(self.bench_data_folder, "*.jsonl")))
@@ -97,13 +144,15 @@ class OlmOCRBenchDataset(Dataset):
                 if pdf_page_key not in pdf_data:
                     # First time seeing this PDF+page
                     pdf_path = os.path.join(self.pdf_folder, pdf_name)
+                    claude_original = self._load_claude_original(pdf_name, page)
                     pdf_data[pdf_page_key] = {
                         "pdf_path": pdf_path,
                         "pdf_name": pdf_name,
                         "page": page,
                         "jsonl_file": jsonl_file,
                         "test_ids": [test_id],
-                        "entries": [entry]
+                        "entries": [entry],
+                        "claude_original": claude_original
                     }
                 else:
                     # Add test case to existing PDF+page
@@ -167,6 +216,7 @@ class OlmOCRBenchDataset(Dataset):
                 "jsonl_file": jsonl_file,
                 "test_ids": test_ids,
                 "image": image,  # Include the PIL image for processing later
+                "claude_original": sample.get("claude_original"),  # Include claude_original if available
             }
 
         except Exception as e:
@@ -276,6 +326,61 @@ def evaluate_single_completion(args: Tuple[int, Any, str, str, List[str]]) -> Tu
         return i, None
 
+
+def bench_edit_distance_reward(prompts, completions: list[str] | list[list[dict]], claude_original: list[Optional[str]], **kwargs):
+    """
+    Reward function based on edit distance similarity to claude_original files.
+
+    Calculates the normalized edit distance between each completion and its corresponding
+    claude_original reference. Returns 1.0 for perfect match, lower for more distance.
+
+    Args:
+        prompts: List of prompts
+        completions: List of generated completions (model outputs)
+        claude_original: List of claude_original reference texts (one per completion)
+        **kwargs: Additional arguments
+
+    Returns:
+        List of reward scores between 0 and 1, where 1.0 is perfect match
+    """
+    logger.info(f"Running bench edit distance reward function for {len(completions)} completions")
+
+    rewards = []
+    for i, completion in enumerate(completions):
+        # Extract text from completion
+        if isinstance(completion, list):
+            comp_text = completion[0]["content"] if completion else ""
+        elif isinstance(completion, str):
+            comp_text = completion
+        else:
+            comp_text = ""
+
+        # Get the corresponding claude_original reference
+        reference = claude_original[i] if i < len(claude_original) else None
+        if reference is None:
+            logger.warning(f"No claude_original reference for completion {i}")
+            rewards.append(0.0)
+            continue
+
+        # Calculate edit distance
+        dist = distance.Levenshtein.distance(comp_text, reference)
+
+        # Calculate maximum possible distance (length of longer string)
+        max_dist = max(len(comp_text), len(reference))
+
+        # Calculate similarity (1.0 = perfect match, 0.0 = completely different)
+        if max_dist == 0:
+            similarity = 1.0  # Both empty strings
+        else:
+            similarity = 1.0 - (dist / max_dist)
+        rewards.append(max(0.0, similarity))  # Ensure non-negative
+
+    logger.info(f"Bench edit distance rewards range: [{min(rewards) if rewards else 0:.3f}, {max(rewards) if rewards else 0:.3f}]")
+    return rewards
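As a quick sanity check of the normalization above, assuming the distance module is rapidfuzz (consistent with the distance.Levenshtein.distance call), with made-up strings:

from rapidfuzz.distance import Levenshtein

comp_text, reference = "Hello wrld", "Hello world"
dist = Levenshtein.distance(comp_text, reference)              # 1 (one missing character)
similarity = 1.0 - dist / max(len(comp_text), len(reference))  # 1 - 1/11
print(f"{similarity:.3f}")  # 0.909, so near-identical outputs score close to 1.0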
 def medoid_reward(prompts, completions: list[str] | list[list[dict]], **kwargs):
     """
     Reward function based on edit distance to the medoid completion.
@@ -513,15 +618,27 @@ def main():
     )
     parser.add_argument(
         "--reward_bench",
-        action="store_true",
-        default=False,
-        help="Use bench-based reward function (test pass rate)"
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use bench-based reward function with optional weight (default: 1.0)"
     )
     parser.add_argument(
         "--reward_medoid",
-        action="store_true",
-        default=False,
-        help="Use medoid-based reward function (edit distance similarity)"
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use medoid-based reward function with optional weight (default: 1.0)"
     )
+    parser.add_argument(
+        "--reward_bench_edit_distance",
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use bench edit distance reward with optional weight (default: 1.0)"
+    )
 
     args = parser.parse_args()
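The switch from action="store_true" to nargs='?' with const=1.0 is what lets each flag act both as an on/off switch and as a weight. A minimal sketch of the resulting semantics, using a throwaway parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--reward_bench", nargs='?', const=1.0, type=float, default=None)

print(parser.parse_args([]).reward_bench)                         # None -> reward disabled
print(parser.parse_args(["--reward_bench"]).reward_bench)         # 1.0  -> enabled, default weight
print(parser.parse_args(["--reward_bench", "0.5"]).reward_bench)  # 0.5  -> enabled, custom weight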
@@ -632,22 +749,40 @@ def main():
         log_completions=True,
     )
 
-    # Build list of reward functions based on command-line arguments
+    # Build list of reward functions and weights based on command-line arguments
     reward_funcs = []
+    reward_weights = []
+    reward_names = []
 
-    if args.reward_bench:
+    if args.reward_bench is not None:
         reward_funcs.append(olmocr_bench_reward)
-        logger.info("Added bench-based reward function")
+        reward_weights.append(args.reward_bench)
+        reward_names.append("bench")
+        logger.info(f"Added bench-based reward function with weight {args.reward_bench}")
 
-    if args.reward_medoid:
+    if args.reward_medoid is not None:
         reward_funcs.append(medoid_reward)
-        logger.info("Added medoid-based reward function")
+        reward_weights.append(args.reward_medoid)
+        reward_names.append("medoid")
+        logger.info(f"Added medoid-based reward function with weight {args.reward_medoid}")
+
+    if args.reward_bench_edit_distance is not None:
+        reward_funcs.append(bench_edit_distance_reward)
+        reward_weights.append(args.reward_bench_edit_distance)
+        reward_names.append("bench_edit_distance")
+        logger.info(f"Added bench edit distance reward function with weight {args.reward_bench_edit_distance}")
 
     if not reward_funcs:
-        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid")
+        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance")
         return
 
-    logger.info(f"Using {len(reward_funcs)} reward function(s)")
+    # Log summary of reward configuration
+    logger.info(f"\n" + "="*50)
+    logger.info(f"Reward Configuration:")
+    logger.info(f"Using {len(reward_funcs)} reward function(s):")
+    for name, weight in zip(reward_names, reward_weights):
+        logger.info(f"  - {name}: weight={weight}")
+    logger.info("="*50 + "\n")
 
     # Initialize GRPO trainer
     logger.info("Initializing GRPO trainer")
@@ -658,6 +793,7 @@ def main():
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         reward_funcs=reward_funcs,
+        reward_weights=reward_weights,
     )
 
     # Start training
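Assuming the trainer combines per-function scores as a weighted sum, which is what a reward_weights argument conventionally controls (e.g. in TRL's GRPO implementation), the effective reward per completion works out as in this sketch with made-up scores:

# Hypothetical scores from two reward functions for two completions
rewards_per_func = {"bench": [0.8, 0.4], "bench_edit_distance": [0.9, 0.7]}
weights = {"bench": 1.0, "bench_edit_distance": 0.5}

# Weighted sum per completion: 1.0*0.8 + 0.5*0.9 and 1.0*0.4 + 0.5*0.7
combined = [
    sum(weights[name] * scores[i] for name, scores in rewards_per_func.items())
    for i in range(2)
]
print(combined)  # [1.25, 0.75]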