This commit is contained in:
Jake Poznanski 2025-08-21 17:01:40 +00:00
parent d08068218c
commit a443b89854

View File

@ -170,32 +170,6 @@ class OlmOCRDataset(Dataset):
# Return None if processing fails
return None
def collate_fn(batch):
"""Custom collate function to handle the new batch format with prompts and metadata."""
# Filter out None values
batch = [item for item in batch if item is not None]
if not batch:
return None
# Collect all components
prompts = [item["prompt"] for item in batch]
images = [item["image"] for item in batch]
pdf_paths = [item["pdf_path"] for item in batch]
jsonl_files = [item["jsonl_file"] for item in batch]
test_ids = [item["test_ids"] for item in batch]
# Return batch with all required information
return {
"prompts": prompts,
"images": images,
"pdf_paths": pdf_paths,
"jsonl_files": jsonl_files,
"test_ids": test_ids,
}
def simple_length_reward(completions: List[str], **kwargs) -> List[float]:
"""
Simple reward function that rewards completions close to 100 tokens.
@ -420,8 +394,7 @@ def main():
processing_class=processor,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
reward_function=simple_length_reward,
data_collator=collate_fn,
reward_funcs=simple_length_reward,
)
# Start training