From 214c44df36046a81583c3a94a9cf482894e61398 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Fri, 27 Jun 2025 21:16:22 +0000
Subject: [PATCH] Reporting to wandb, better eval dataset loading

---
 olmocr/train/config.py                   |  2 +-
 olmocr/train/configs/example_config.yaml |  2 ++
 olmocr/train/train.py                    | 18 +++++++++++-------
 3 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/olmocr/train/config.py b/olmocr/train/config.py
index 8a8c203..fa6bba0 100644
--- a/olmocr/train/config.py
+++ b/olmocr/train/config.py
@@ -173,7 +173,7 @@ class TrainingConfig:
     logging_strategy: str = "steps"
     logging_steps: int = 10
     logging_first_step: bool = True
-    report_to: List[str] = field(default_factory=lambda: ["tensorboard"])
+    report_to: List[str] = field(default_factory=lambda: ["wandb"])
 
     # Other training settings
     seed: int = 42
diff --git a/olmocr/train/configs/example_config.yaml b/olmocr/train/configs/example_config.yaml
index 10c6d42..bc44f01 100644
--- a/olmocr/train/configs/example_config.yaml
+++ b/olmocr/train/configs/example_config.yaml
@@ -82,4 +82,6 @@ training:
 
   metric_for_best_model: eval_loss
   greater_is_better: false
 
+  report_to:
+    - wandb
\ No newline at end of file
diff --git a/olmocr/train/train.py b/olmocr/train/train.py
index 522ee48..7433b21 100644
--- a/olmocr/train/train.py
+++ b/olmocr/train/train.py
@@ -136,21 +136,24 @@ def main():
 
     # Create evaluation datasets
     logger.info("Creating evaluation datasets...")
-    eval_datasets = []
+    eval_datasets = {}
     for i, dataset_cfg in enumerate(config.dataset.eval):
         root_dir = dataset_cfg['root_dir']
         pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
 
-        logger.info(f"Creating evaluation dataset {i+1} from: {root_dir}")
+        # Use dataset name if provided, otherwise fall back to an indexed default
+        dataset_name = dataset_cfg.get('name', f"eval_dataset_{i+1}")
+
+        logger.info(f"Creating evaluation dataset '{dataset_name}' from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
         logger.info(f"Found {len(dataset)} samples")
         if len(dataset) > 0:
-            eval_datasets.append(dataset)
+            eval_datasets[dataset_name] = dataset
 
 
-    # Combine all evaluation datasets
-    eval_dataset = ConcatDataset(eval_datasets) if len(eval_datasets) > 1 else eval_datasets[0]
-    logger.info(f"Total evaluation samples: {len(eval_dataset)}")
+    # Log total evaluation samples across all datasets
+    total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
+    logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")
 
     # Set up training arguments
     training_args = TrainingArguments(
@@ -194,6 +197,7 @@ def main():
         dataloader_drop_last=config.training.dataloader_drop_last,
         dataloader_num_workers=config.training.dataloader_num_workers,
         remove_unused_columns=config.training.remove_unused_columns,
+        eval_on_start=True,
         run_name=config.run_name,
     )
 
@@ -213,7 +217,7 @@ def main():
         model=model,
         args=training_args,
         train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
+        eval_dataset=eval_datasets,
         data_collator=create_data_collator(),
         callbacks=callbacks,
     )
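
For context, a minimal sketch of the Trainer behavior this patch relies on, assuming a transformers release new enough to ship `eval_on_start`; the `build_trainer` helper and the split names are hypothetical, not part of olmocr:

    from typing import Dict

    from torch.utils.data import Dataset
    from transformers import PreTrainedModel, Trainer, TrainingArguments


    def build_trainer(
        model: PreTrainedModel,
        train_dataset: Dataset,
        eval_datasets: Dict[str, Dataset],  # e.g. {"scanned": ds_a, "born_digital": ds_b}
    ) -> Trainer:
        """Hypothetical helper mirroring the patch: wandb reporting, an eval
        pass before step 0, and one named eval split per dict entry."""
        args = TrainingArguments(
            output_dir="out",
            report_to=["wandb"],    # send train/eval metrics to Weights & Biases
            eval_strategy="steps",
            eval_steps=500,
            eval_on_start=True,     # evaluate once before training begins
        )
        # Passing a dict (rather than a single Dataset) makes Trainer evaluate
        # each entry separately and prefix its metrics with the key, e.g.
        # eval_scanned_loss and eval_born_digital_loss.
        return Trainer(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_datasets,
        )

Each split then surfaces as its own loss curve in wandb, which is what the switch from ConcatDataset to a dict of named datasets buys.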