diff --git a/olmocr/train/grpo_train.py b/olmocr/train/grpo_train.py index e1c2818..0940e33 100644 --- a/olmocr/train/grpo_train.py +++ b/olmocr/train/grpo_train.py @@ -170,9 +170,10 @@ class OlmOCRDataset(Dataset): # Return None if processing fails return None -def simple_length_reward(completions_ids, **kwargs): +def simple_length_reward(completion_ids, **kwargs): """Reward function that assigns higher scores to longer completions (in terms of token count).""" - return [float(len(ids)) for ids in completions_ids] + logger.info(f"Reward function called {kwargs}") + return [float(len(ids)) for ids in completion_ids] def main():