Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-13 01:02:26 +00:00
Adding additional rewards and weights

parent 9c520498dd
commit c86b413d3e
@@ -62,11 +62,58 @@ class OlmOCRBenchDataset(Dataset):
         if not os.path.exists(self.pdf_folder):
             raise ValueError(f"PDFs folder not found at {self.pdf_folder}")

+        # Set claude_original folder path
+        self.claude_original_folder = os.path.join(bench_data_folder, "claude_original")
+        if os.path.exists(self.claude_original_folder):
+            logger.info(f"Found claude_original folder at {self.claude_original_folder}")
+        else:
+            logger.warning(f"No claude_original folder found at {self.claude_original_folder}")
+
         # Load unique PDFs from JSONL files
         self.samples = self._load_unique_pdfs_from_jsonl()

         logger.info(f"Created dataset with {len(self.samples)} unique PDF samples")

+    def _load_claude_original(self, pdf_name: str, page: int) -> Optional[str]:
+        """Load the claude_original markdown file for a given PDF and page."""
+        if not os.path.exists(self.claude_original_folder):
+            return None
+
+        # Extract the base PDF name and construct the expected filename
+        # pdf_name like "s2pdf/pdf_00017_page2.pdf" -> construct the markdown filename
+        pdf_base = os.path.basename(pdf_name).replace(".pdf", "")
+
+        # Handle case where page is already in the filename
+        if "_page" in pdf_base:
+            pdf_base_parts = pdf_base.split("_page")
+            pdf_base_name = pdf_base_parts[0]
+            # Use the page from the filename if it exists
+            page_from_name = int(pdf_base_parts[1]) if len(pdf_base_parts) > 1 and pdf_base_parts[1].isdigit() else page
+        else:
+            pdf_base_name = pdf_base
+            page_from_name = page
+
+        # Extract folder structure from pdf_name (e.g., "s2pdf/" or "arxiv_math/")
+        pdf_dir = os.path.dirname(pdf_name)
+
+        # Construct the expected claude_original filename
+        # Format: pdf_00017_page2_pg1_repeat1.md
+        claude_filename = f"{pdf_base_name}_page{page_from_name}_pg1_repeat1.md"
+
+        # Build the full path to the claude_original file
+        claude_file_path = os.path.join(self.claude_original_folder, pdf_dir, claude_filename)
+
+        if os.path.exists(claude_file_path):
+            try:
+                with open(claude_file_path, 'r', encoding='utf-8') as f:
+                    return f.read()
+            except Exception as e:
+                logger.warning(f"Failed to read claude_original file {claude_file_path}: {e}")
+        else:
+            logger.debug(f"Claude original file not found: {claude_file_path}")
+
+        return None
+
     def _load_unique_pdfs_from_jsonl(self) -> List[Dict[str, Any]]:
         """Load unique PDFs from JSONL files in the bench_data folder, tracking all test cases per PDF."""
         jsonl_files = sorted(glob.glob(os.path.join(self.bench_data_folder, "*.jsonl")))
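The new `_load_claude_original` resolves a reference markdown file from a PDF name. A minimal standalone sketch of that mapping, using the example name from the comments above (the bench_data folder layout is assumed):

import os

claude_original_folder = "bench_data/claude_original"  # assumed layout
pdf_name = "s2pdf/pdf_00017_page2.pdf"                 # example from the comments above

pdf_base = os.path.basename(pdf_name).replace(".pdf", "")  # "pdf_00017_page2"
base_name, page_str = pdf_base.split("_page")              # "pdf_00017", "2"
claude_filename = f"{base_name}_page{page_str}_pg1_repeat1.md"
path = os.path.join(claude_original_folder, os.path.dirname(pdf_name), claude_filename)
print(path)  # bench_data/claude_original/s2pdf/pdf_00017_page2_pg1_repeat1.md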
@@ -97,13 +144,15 @@ class OlmOCRBenchDataset(Dataset):
                 if pdf_page_key not in pdf_data:
                     # First time seeing this PDF+page
                     pdf_path = os.path.join(self.pdf_folder, pdf_name)
+                    claude_original = self._load_claude_original(pdf_name, page)
                     pdf_data[pdf_page_key] = {
                         "pdf_path": pdf_path,
                         "pdf_name": pdf_name,
                         "page": page,
                         "jsonl_file": jsonl_file,
                         "test_ids": [test_id],
-                        "entries": [entry]
+                        "entries": [entry],
+                        "claude_original": claude_original
                     }
                 else:
                     # Add test case to existing PDF+page
@@ -167,6 +216,7 @@ class OlmOCRBenchDataset(Dataset):
                 "jsonl_file": jsonl_file,
                 "test_ids": test_ids,
                 "image": image,  # Include the PIL image for processing later
+                "claude_original": sample.get("claude_original"),  # Include claude_original if available
             }

         except Exception as e:
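With the two hunks above, each unique PDF+page sample now carries its claude_original text (or None) alongside the test metadata. A sketch of the resulting sample shape, with placeholder values:

sample = {
    "pdf_path": "bench_data/pdfs/s2pdf/pdf_00017_page2.pdf",  # placeholder
    "pdf_name": "s2pdf/pdf_00017_page2.pdf",
    "page": 1,
    "jsonl_file": "bench_data/tests.jsonl",                   # placeholder
    "test_ids": ["test_0001", "test_0002"],
    "entries": [{"id": "test_0001"}, {"id": "test_0002"}],    # raw JSONL entries
    "claude_original": "# Extracted page text...",            # None if the .md file is missing
}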
@@ -276,6 +326,61 @@ def evaluate_single_completion(args: Tuple[int, Any, str, str, List[str]]) -> Tu
         return i, None


+def bench_edit_distance_reward(prompts, completions: list[str] | list[list[dict]], claude_original: list[Optional[str]], **kwargs):
+    """
+    Reward function based on edit distance similarity to claude_original files.
+
+    Calculates the normalized edit distance between each completion and its corresponding
+    claude_original reference. Returns 1.0 for a perfect match, lower for more distance.
+
+    Args:
+        prompts: List of prompts
+        completions: List of generated completions (model outputs)
+        claude_original: List of claude_original reference texts (one per completion)
+        **kwargs: Additional arguments
+
+    Returns:
+        List of reward scores between 0 and 1, where 1.0 is a perfect match
+    """
+    logger.info(f"Running bench edit distance reward function for {len(completions)} completions")
+
+    rewards = []
+
+    for i, completion in enumerate(completions):
+        # Extract text from completion
+        if isinstance(completion, list):
+            comp_text = completion[0]["content"] if completion else ""
+        elif isinstance(completion, str):
+            comp_text = completion
+        else:
+            comp_text = ""
+
+        # Get the corresponding claude_original reference
+        reference = claude_original[i] if i < len(claude_original) else None
+
+        if reference is None:
+            logger.warning(f"No claude_original reference for completion {i}")
+            rewards.append(0.0)
+            continue
+
+        # Calculate edit distance
+        dist = distance.Levenshtein.distance(comp_text, reference)
+
+        # Calculate maximum possible distance (length of longer string)
+        max_dist = max(len(comp_text), len(reference))
+
+        # Calculate similarity (1.0 = perfect match, 0.0 = completely different)
+        if max_dist == 0:
+            similarity = 1.0  # Both empty strings
+        else:
+            similarity = 1.0 - (dist / max_dist)
+
+        rewards.append(max(0.0, similarity))  # Ensure non-negative
+
+    logger.info(f"Bench edit distance rewards range: [{min(rewards) if rewards else 0:.3f}, {max(rewards) if rewards else 0:.3f}]")
+    return rewards
+
+
 def medoid_reward(prompts, completions: list[str] | list[list[dict]], **kwargs):
     """
     Reward function based on edit distance to the medoid completion.
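The reward above is the normalized similarity 1 - dist / max(len(a), len(b)). A standalone check of the arithmetic, assuming the `distance` module used here comes from rapidfuzz:

from rapidfuzz import distance  # assumed import backing `distance` above

comp_text = "Hello world"
reference = "Hello word"

dist = distance.Levenshtein.distance(comp_text, reference)  # 1 (one deleted character)
max_dist = max(len(comp_text), len(reference))              # 11
similarity = 1.0 - dist / max_dist
print(round(similarity, 3))  # 0.909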
@@ -513,15 +618,27 @@ def main():
     )
     parser.add_argument(
         "--reward_bench",
-        action="store_true",
-        default=False,
-        help="Use bench-based reward function (test pass rate)"
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use bench-based reward function with optional weight (default: 1.0)"
     )
     parser.add_argument(
         "--reward_medoid",
-        action="store_true",
-        default=False,
-        help="Use medoid-based reward function (edit distance similarity)"
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use medoid-based reward function with optional weight (default: 1.0)"
     )
+    parser.add_argument(
+        "--reward_bench_edit_distance",
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use bench edit distance reward with optional weight (default: 1.0)"
+    )

     args = parser.parse_args()
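Switching these flags from action="store_true" to nargs='?' with a const keeps the bare-flag form working while accepting an optional weight. A minimal standalone demonstration of the three cases:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--reward_bench", nargs='?', const=1.0, type=float, default=None)

print(p.parse_args([]).reward_bench)                         # None (flag omitted)
print(p.parse_args(["--reward_bench"]).reward_bench)         # 1.0 (bare flag -> const)
print(p.parse_args(["--reward_bench", "0.5"]).reward_bench)  # 0.5 (explicit weight)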
@@ -632,22 +749,40 @@ def main():
         log_completions=True,
     )

-    # Build list of reward functions based on command-line arguments
+    # Build list of reward functions and weights based on command-line arguments
     reward_funcs = []
+    reward_weights = []
+    reward_names = []

-    if args.reward_bench:
+    if args.reward_bench is not None:
         reward_funcs.append(olmocr_bench_reward)
-        logger.info("Added bench-based reward function")
+        reward_weights.append(args.reward_bench)
+        reward_names.append("bench")
+        logger.info(f"Added bench-based reward function with weight {args.reward_bench}")

-    if args.reward_medoid:
+    if args.reward_medoid is not None:
         reward_funcs.append(medoid_reward)
-        logger.info("Added medoid-based reward function")
+        reward_weights.append(args.reward_medoid)
+        reward_names.append("medoid")
+        logger.info(f"Added medoid-based reward function with weight {args.reward_medoid}")

+    if args.reward_bench_edit_distance is not None:
+        reward_funcs.append(bench_edit_distance_reward)
+        reward_weights.append(args.reward_bench_edit_distance)
+        reward_names.append("bench_edit_distance")
+        logger.info(f"Added bench edit distance reward function with weight {args.reward_bench_edit_distance}")
+
     if not reward_funcs:
-        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid")
+        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance")
         return

-    logger.info(f"Using {len(reward_funcs)} reward function(s)")
+    # Log summary of reward configuration
+    logger.info("\n" + "=" * 50)
+    logger.info("Reward Configuration:")
+    logger.info(f"Using {len(reward_funcs)} reward function(s):")
+    for name, weight in zip(reward_names, reward_weights):
+        logger.info(f"  - {name}: weight={weight}")
+    logger.info("=" * 50 + "\n")

     # Initialize GRPO trainer
     logger.info("Initializing GRPO trainer")
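A sketch of how the weights are expected to act, assuming the trainer combines per-function rewards as a weighted sum (the usual GRPO convention; the values below are made up for illustration):

bench_rewards = [0.8, 0.2]  # hypothetical outputs of olmocr_bench_reward
edit_rewards = [0.9, 0.5]   # hypothetical outputs of bench_edit_distance_reward
weights = [1.0, 0.3]        # e.g. --reward_bench 1.0 --reward_bench_edit_distance 0.3

totals = [weights[0] * b + weights[1] * e for b, e in zip(bench_rewards, edit_rewards)]
print(totals)  # approximately [1.07, 0.35]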
@@ -658,6 +793,7 @@ def main():
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         reward_funcs=reward_funcs,
+        reward_weights=reward_weights,
     )

     # Start training
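A hypothetical invocation using the new weighted flags (the script name is assumed, not part of the commit):

# bare flag -> weight 1.0; explicit value overrides
python train_grpo.py --reward_bench --reward_bench_edit_distance 0.3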