mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-12 00:32:45 +00:00
Reporting to wandb, better eval dataset loading
This commit is contained in:
parent
600d967fe6
commit
214c44df36
@ -173,7 +173,7 @@ class TrainingConfig:
|
||||
logging_strategy: str = "steps"
|
||||
logging_steps: int = 10
|
||||
logging_first_step: bool = True
|
||||
report_to: List[str] = field(default_factory=lambda: ["tensorboard"])
|
||||
report_to: List[str] = field(default_factory=lambda: ["wandb"])
|
||||
|
||||
# Other training settings
|
||||
seed: int = 42
|
||||
|
@ -82,4 +82,6 @@ training:
|
||||
metric_for_best_model: eval_loss
|
||||
greater_is_better: false
|
||||
|
||||
report_to:
|
||||
- wandb
|
||||
|
@ -136,21 +136,24 @@ def main():
|
||||
|
||||
# Create evaluation datasets
|
||||
logger.info("Creating evaluation datasets...")
|
||||
eval_datasets = []
|
||||
eval_datasets = {}
|
||||
for i, dataset_cfg in enumerate(config.dataset.eval):
|
||||
root_dir = dataset_cfg['root_dir']
|
||||
pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
|
||||
|
||||
logger.info(f"Creating evaluation dataset {i+1} from: {root_dir}")
|
||||
# Use dataset name if provided, otherwise use root_dir as name
|
||||
dataset_name = dataset_cfg.get('name', f"eval_dataset_{i+1}")
|
||||
|
||||
logger.info(f"Creating evaluation dataset '{dataset_name}' from: {root_dir}")
|
||||
dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
|
||||
logger.info(f"Found {len(dataset)} samples")
|
||||
|
||||
if len(dataset) > 0:
|
||||
eval_datasets.append(dataset)
|
||||
eval_datasets[dataset_name] = dataset
|
||||
|
||||
# Combine all evaluation datasets
|
||||
eval_dataset = ConcatDataset(eval_datasets) if len(eval_datasets) > 1 else eval_datasets[0]
|
||||
logger.info(f"Total evaluation samples: {len(eval_dataset)}")
|
||||
# Log total evaluation samples across all datasets
|
||||
total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
|
||||
logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")
|
||||
|
||||
# Set up training arguments
|
||||
training_args = TrainingArguments(
|
||||
@ -194,6 +197,7 @@ def main():
|
||||
dataloader_drop_last=config.training.dataloader_drop_last,
|
||||
dataloader_num_workers=config.training.dataloader_num_workers,
|
||||
remove_unused_columns=config.training.remove_unused_columns,
|
||||
eval_on_start=True,
|
||||
run_name=config.run_name,
|
||||
)
|
||||
|
||||
@ -213,7 +217,7 @@ def main():
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
eval_dataset=eval_datasets,
|
||||
data_collator=create_data_collator(),
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user