Reporting to wandb, better eval dataset loading

Jake Poznanski 2025-06-27 21:16:22 +00:00
parent 600d967fe6
commit 214c44df36
3 changed files with 14 additions and 8 deletions


@@ -173,7 +173,7 @@ class TrainingConfig:
     logging_strategy: str = "steps"
     logging_steps: int = 10
     logging_first_step: bool = True
-    report_to: List[str] = field(default_factory=lambda: ["tensorboard"])
+    report_to: List[str] = field(default_factory=lambda: ["wandb"])

     # Other training settings
     seed: int = 42
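
For reference, report_to is the standard Hugging Face TrainingArguments switch for logging integrations, so this default now starts a Weights & Biases run automatically instead of writing TensorBoard logs. A minimal sketch of what the new default opts into (project and run names below are hypothetical, not from this repo):

    import os
    from transformers import TrainingArguments

    # The W&B callback reads the project name from the environment.
    os.environ.setdefault("WANDB_PROJECT", "my-project")  # hypothetical project

    args = TrainingArguments(
        output_dir="out",
        report_to=["wandb"],   # same default this commit sets in TrainingConfig
        run_name="demo-run",   # surfaces as the W&B run name
        logging_steps=10,
        logging_first_step=True,
    )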


@@ -82,4 +82,6 @@ training:
   metric_for_best_model: eval_loss
   greater_is_better: false
+  report_to:
+    - wandb
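
This config-file override mirrors the new dataclass default above, so runs launched from this YAML and runs relying on the bare TrainingConfig both report to wandb. The list form matches the field's List[str] type and would allow several integrations at once, e.g. report_to: [wandb, tensorboard].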


@@ -136,21 +136,24 @@ def main():
     # Create evaluation datasets
     logger.info("Creating evaluation datasets...")
-    eval_datasets = []
+    eval_datasets = {}

     for i, dataset_cfg in enumerate(config.dataset.eval):
         root_dir = dataset_cfg['root_dir']
         pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)

-        logger.info(f"Creating evaluation dataset {i+1} from: {root_dir}")
+        # Use dataset name if provided, otherwise fall back to an indexed default
+        dataset_name = dataset_cfg.get('name', f"eval_dataset_{i+1}")
+        logger.info(f"Creating evaluation dataset '{dataset_name}' from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
         logger.info(f"Found {len(dataset)} samples")

         if len(dataset) > 0:
-            eval_datasets.append(dataset)
+            eval_datasets[dataset_name] = dataset

-    # Combine all evaluation datasets
-    eval_dataset = ConcatDataset(eval_datasets) if len(eval_datasets) > 1 else eval_datasets[0]
-    logger.info(f"Total evaluation samples: {len(eval_dataset)}")
+    # Log total evaluation samples across all datasets
+    total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
+    logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")

     # Set up training arguments
     training_args = TrainingArguments(
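
The list-to-dict change above is what enables per-dataset reporting: Trainer accepts a dictionary for eval_dataset and evaluates each entry separately, prepending the dict key to the metric names. A small self-contained sketch of the resulting behavior, with hypothetical dataset names standing in for the ones a config would provide:

    from torch.utils.data import Dataset

    class ToyDataset(Dataset):  # stand-in for BaseMarkdownPDFDataset
        def __init__(self, n):
            self.n = n
        def __len__(self):
            return self.n
        def __getitem__(self, idx):
            return {"input_ids": [idx]}

    # Keys here play the role of dataset_cfg['name'] in this commit.
    eval_datasets = {"clean_pdfs": ToyDataset(8), "scanned_pdfs": ToyDataset(4)}

    # Passed as Trainer(eval_dataset=eval_datasets, ...), evaluation runs once
    # per entry and each metric is prefixed with its key, e.g.
    #   eval_clean_pdfs_loss and eval_scanned_pdfs_loss
    # instead of a single eval_loss over one ConcatDataset.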
@@ -194,6 +197,7 @@ def main():
         dataloader_drop_last=config.training.dataloader_drop_last,
         dataloader_num_workers=config.training.dataloader_num_workers,
         remove_unused_columns=config.training.remove_unused_columns,
+        eval_on_start=True,
         run_name=config.run_name,
     )
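
eval_on_start=True makes the Trainer run one full evaluation pass before the first training step (the flag exists in recent transformers releases), so each wandb run starts with baseline metrics for every eval dataset. Note it is hard-coded here rather than read from TrainingConfig like the surrounding arguments.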
@@ -213,7 +217,7 @@ def main():
         model=model,
         args=training_args,
         train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
+        eval_dataset=eval_datasets,
         data_collator=create_data_collator(),
         callbacks=callbacks,
     )
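
One interaction worth double-checking: with a dict-valued eval_dataset, metrics are logged as eval_<name>_loss rather than a bare eval_loss, so the metric_for_best_model: eval_loss setting in the YAML above may no longer match any reported metric; for the default-named first dataset it should surface as something like eval_eval_dataset_1_loss.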