Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-13 17:22:13 +00:00)

Commit c86b413d3e (parent 9c520498dd): Adding additional rewards and weights
@@ -62,11 +62,58 @@ class OlmOCRBenchDataset(Dataset):
         if not os.path.exists(self.pdf_folder):
             raise ValueError(f"PDFs folder not found at {self.pdf_folder}")
 
+        # Set claude_original folder path
+        self.claude_original_folder = os.path.join(bench_data_folder, "claude_original")
+        if os.path.exists(self.claude_original_folder):
+            logger.info(f"Found claude_original folder at {self.claude_original_folder}")
+        else:
+            logger.warning(f"No claude_original folder found at {self.claude_original_folder}")
+
         # Load unique PDFs from JSONL files
         self.samples = self._load_unique_pdfs_from_jsonl()
 
         logger.info(f"Created dataset with {len(self.samples)} unique PDF samples")
 
+    def _load_claude_original(self, pdf_name: str, page: int) -> Optional[str]:
+        """Load the claude_original markdown file for a given PDF and page."""
+        if not os.path.exists(self.claude_original_folder):
+            return None
+
+        # Extract the base PDF name and construct the expected filename
+        # pdf_name like "s2pdf/pdf_00017_page2.pdf" -> construct the markdown filename
+        pdf_base = os.path.basename(pdf_name).replace(".pdf", "")
+
+        # Handle case where page is already in the filename
+        if "_page" in pdf_base:
+            pdf_base_parts = pdf_base.split("_page")
+            pdf_base_name = pdf_base_parts[0]
+            # Use the page from the filename if it exists
+            page_from_name = int(pdf_base_parts[1]) if len(pdf_base_parts) > 1 and pdf_base_parts[1].isdigit() else page
+        else:
+            pdf_base_name = pdf_base
+            page_from_name = page
+
+        # Extract folder structure from pdf_name (e.g., "s2pdf/" or "arxiv_math/")
+        pdf_dir = os.path.dirname(pdf_name)
+
+        # Construct the expected claude_original filename
+        # Format: pdf_00017_page2_pg1_repeat1.md
+        claude_filename = f"{pdf_base_name}_page{page_from_name}_pg1_repeat1.md"
+
+        # Build the full path to the claude_original file
+        claude_file_path = os.path.join(self.claude_original_folder, pdf_dir, claude_filename)
+
+        if os.path.exists(claude_file_path):
+            try:
+                with open(claude_file_path, 'r', encoding='utf-8') as f:
+                    return f.read()
+            except Exception as e:
+                logger.warning(f"Failed to read claude_original file {claude_file_path}: {e}")
+        else:
+            logger.debug(f"Claude original file not found: {claude_file_path}")
+
+        return None
+
     def _load_unique_pdfs_from_jsonl(self) -> List[Dict[str, Any]]:
         """Load unique PDFs from JSONL files in the bench_data folder, tracking all test cases per PDF."""
         jsonl_files = sorted(glob.glob(os.path.join(self.bench_data_folder, "*.jsonl")))
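For orientation, the path construction in _load_claude_original maps a bench PDF entry to its reference transcription roughly as follows. This is a standalone sketch of the same string handling; the folder names are illustrative, not taken from a real bench_data checkout:

import os

pdf_name = "s2pdf/pdf_00017_page2.pdf"                      # example from the comment above
pdf_base = os.path.basename(pdf_name).replace(".pdf", "")   # "pdf_00017_page2"
base_name, page_str = pdf_base.split("_page")               # ("pdf_00017", "2")
md_path = os.path.join("bench_data", "claude_original",
                       os.path.dirname(pdf_name),
                       f"{base_name}_page{page_str}_pg1_repeat1.md")
print(md_path)  # bench_data/claude_original/s2pdf/pdf_00017_page2_pg1_repeat1.md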
@@ -97,13 +144,15 @@ class OlmOCRBenchDataset(Dataset):
             if pdf_page_key not in pdf_data:
                 # First time seeing this PDF+page
                 pdf_path = os.path.join(self.pdf_folder, pdf_name)
+                claude_original = self._load_claude_original(pdf_name, page)
                 pdf_data[pdf_page_key] = {
                     "pdf_path": pdf_path,
                     "pdf_name": pdf_name,
                     "page": page,
                     "jsonl_file": jsonl_file,
                     "test_ids": [test_id],
-                    "entries": [entry]
+                    "entries": [entry],
+                    "claude_original": claude_original
                 }
             else:
                 # Add test case to existing PDF+page
@@ -167,6 +216,7 @@ class OlmOCRBenchDataset(Dataset):
             "jsonl_file": jsonl_file,
             "test_ids": test_ids,
             "image": image,  # Include the PIL image for processing later
+            "claude_original": sample.get("claude_original"),  # Include claude_original if available
         }
 
     except Exception as e:
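A note on plumbing: TRL-style reward functions receive any extra dataset columns as keyword arguments matched by name, so carrying "claude_original" through the returned sample dict here is presumably what lets bench_edit_distance_reward (added below) receive a claude_original list alongside prompts and completions. This reading is inferred from the signatures; the trainer's forwarding logic is not shown in this diff.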
@@ -276,6 +326,61 @@ def evaluate_single_completion(args: Tuple[int, Any, str, str, List[str]]) -> Tu
         return i, None
 
 
+def bench_edit_distance_reward(prompts, completions: list[str] | list[list[dict]], claude_original: list[Optional[str]], **kwargs):
+    """
+    Reward function based on edit distance similarity to claude_original files.
+
+    Calculates the normalized edit distance between each completion and its corresponding
+    claude_original reference. Returns 1.0 for perfect match, lower for more distance.
+
+    Args:
+        prompts: List of prompts
+        completions: List of generated completions (model outputs)
+        claude_original: List of claude_original reference texts (one per completion)
+        **kwargs: Additional arguments
+
+    Returns:
+        List of reward scores between 0 and 1, where 1.0 is perfect match
+    """
+    logger.info(f"Running bench edit distance reward function for {len(completions)} completions")
+
+    rewards = []
+
+    for i, completion in enumerate(completions):
+        # Extract text from completion
+        if isinstance(completion, list):
+            comp_text = completion[0]["content"] if completion else ""
+        elif isinstance(completion, str):
+            comp_text = completion
+        else:
+            comp_text = ""
+
+        # Get the corresponding claude_original reference
+        reference = claude_original[i] if i < len(claude_original) else None
+
+        if reference is None:
+            logger.warning(f"No claude_original reference for completion {i}")
+            rewards.append(0.0)
+            continue
+
+        # Calculate edit distance
+        dist = distance.Levenshtein.distance(comp_text, reference)
+
+        # Calculate maximum possible distance (length of longer string)
+        max_dist = max(len(comp_text), len(reference))
+
+        # Calculate similarity (1.0 = perfect match, 0.0 = completely different)
+        if max_dist == 0:
+            similarity = 1.0  # Both empty strings
+        else:
+            similarity = 1.0 - (dist / max_dist)
+
+        rewards.append(max(0.0, similarity))  # Ensure non-negative
+
+    logger.info(f"Bench edit distance rewards range: [{min(rewards) if rewards else 0:.3f}, {max(rewards) if rewards else 0:.3f}]")
+
+    return rewards
+
+
 def medoid_reward(prompts, completions: list[str] | list[list[dict]], **kwargs):
     """
     Reward function based on edit distance to the medoid completion.
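The normalization in bench_edit_distance_reward divides the raw Levenshtein distance by the length of the longer string, which is the largest distance possible under unit edit costs, so the similarity always lands in [0, 1]. A quick standalone illustration, assuming the distance module here is rapidfuzz's (which matches the distance.Levenshtein.distance call above):

from rapidfuzz import distance  # assumption: the import this script uses

comp_text = "The quick brown fox"
reference = "The quick brown fox jumps"

dist = distance.Levenshtein.distance(comp_text, reference)  # 6: insert " jumps"
max_dist = max(len(comp_text), len(reference))              # 25
similarity = 1.0 - dist / max_dist                          # 0.76

For default weights, rapidfuzz's distance.Levenshtein.normalized_similarity should return the same value in one call.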
@@ -513,15 +618,27 @@ def main():
     )
     parser.add_argument(
         "--reward_bench",
-        action="store_true",
-        default=False,
-        help="Use bench-based reward function (test pass rate)"
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use bench-based reward function with optional weight (default: 1.0)"
     )
     parser.add_argument(
         "--reward_medoid",
-        action="store_true",
-        default=False,
-        help="Use medoid-based reward function (edit distance similarity)"
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use medoid-based reward function with optional weight (default: 1.0)"
+    )
+    parser.add_argument(
+        "--reward_bench_edit_distance",
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use bench edit distance reward with optional weight (default: 1.0)"
     )
 
     args = parser.parse_args()
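The switch from action="store_true" to nargs='?' with const=1.0 is what lets each flag double as an on/off switch and a weight: omitted, the value stays None (reward disabled); given bare, argparse substitutes const; given a number, type=float parses it. A minimal standalone sketch of that behavior:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--reward_bench", nargs='?', const=1.0, type=float, default=None)

print(p.parse_args([]).reward_bench)                         # None -> reward disabled
print(p.parse_args(["--reward_bench"]).reward_bench)         # 1.0  -> enabled at default weight
print(p.parse_args(["--reward_bench", "0.5"]).reward_bench)  # 0.5  -> enabled at explicit weight

So, for example, --reward_bench --reward_bench_edit_distance 0.5 enables the bench reward at weight 1.0 and the edit-distance reward at half weight.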
@@ -632,22 +749,40 @@ def main():
         log_completions=True,
     )
 
-    # Build list of reward functions based on command-line arguments
+    # Build list of reward functions and weights based on command-line arguments
     reward_funcs = []
+    reward_weights = []
+    reward_names = []
 
-    if args.reward_bench:
+    if args.reward_bench is not None:
         reward_funcs.append(olmocr_bench_reward)
-        logger.info("Added bench-based reward function")
+        reward_weights.append(args.reward_bench)
+        reward_names.append("bench")
+        logger.info(f"Added bench-based reward function with weight {args.reward_bench}")
 
-    if args.reward_medoid:
+    if args.reward_medoid is not None:
         reward_funcs.append(medoid_reward)
-        logger.info("Added medoid-based reward function")
+        reward_weights.append(args.reward_medoid)
+        reward_names.append("medoid")
+        logger.info(f"Added medoid-based reward function with weight {args.reward_medoid}")
+
+    if args.reward_bench_edit_distance is not None:
+        reward_funcs.append(bench_edit_distance_reward)
+        reward_weights.append(args.reward_bench_edit_distance)
+        reward_names.append("bench_edit_distance")
+        logger.info(f"Added bench edit distance reward function with weight {args.reward_bench_edit_distance}")
 
     if not reward_funcs:
-        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid")
+        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance")
         return
 
-    logger.info(f"Using {len(reward_funcs)} reward function(s)")
+    # Log summary of reward configuration
+    logger.info(f"\n" + "="*50)
+    logger.info(f"Reward Configuration:")
+    logger.info(f"Using {len(reward_funcs)} reward function(s):")
+    for name, weight in zip(reward_names, reward_weights):
+        logger.info(f"  - {name}: weight={weight}")
+    logger.info("="*50 + "\n")
+
     # Initialize GRPO trainer
     logger.info("Initializing GRPO trainer")
@@ -658,6 +793,7 @@ def main():
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         reward_funcs=reward_funcs,
+        reward_weights=reward_weights,
     )
 
     # Start training