Report to wandb; improve eval dataset loading

This commit is contained in:
Jake Poznanski 2025-06-27 21:16:22 +00:00
parent 600d967fe6
commit 214c44df36
3 changed files with 14 additions and 8 deletions

View File

@ -173,7 +173,7 @@ class TrainingConfig:
logging_strategy: str = "steps"
logging_steps: int = 10
logging_first_step: bool = True
report_to: List[str] = field(default_factory=lambda: ["tensorboard"])
report_to: List[str] = field(default_factory=lambda: ["wandb"])
# Other training settings
seed: int = 42

View File

@ -82,4 +82,6 @@ training:
metric_for_best_model: eval_loss
greater_is_better: false
report_to:
- wandb

View File

@ -136,21 +136,24 @@ def main():
# Create evaluation datasets
logger.info("Creating evaluation datasets...")
eval_datasets = []
eval_datasets = {}
for i, dataset_cfg in enumerate(config.dataset.eval):
root_dir = dataset_cfg['root_dir']
pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
logger.info(f"Creating evaluation dataset {i+1} from: {root_dir}")
# Use dataset name if provided, otherwise use root_dir as name
dataset_name = dataset_cfg.get('name', f"eval_dataset_{i+1}")
logger.info(f"Creating evaluation dataset '{dataset_name}' from: {root_dir}")
dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
logger.info(f"Found {len(dataset)} samples")
if len(dataset) > 0:
eval_datasets.append(dataset)
eval_datasets[dataset_name] = dataset
# Combine all evaluation datasets
eval_dataset = ConcatDataset(eval_datasets) if len(eval_datasets) > 1 else eval_datasets[0]
logger.info(f"Total evaluation samples: {len(eval_dataset)}")
# Log total evaluation samples across all datasets
total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")
# Set up training arguments
training_args = TrainingArguments(
@ -194,6 +197,7 @@ def main():
dataloader_drop_last=config.training.dataloader_drop_last,
dataloader_num_workers=config.training.dataloader_num_workers,
remove_unused_columns=config.training.remove_unused_columns,
eval_on_start=True,
run_name=config.run_name,
)
@ -213,7 +217,7 @@ def main():
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
eval_dataset=eval_datasets,
data_collator=create_data_collator(),
callbacks=callbacks,
)