Bf16 only

Jake Poznanski 2025-06-30 17:25:53 +00:00
parent 44dd966850
commit bde6f2955e
2 changed files with 1 addition and 8 deletions

@@ -147,11 +147,6 @@ class TrainingConfig:
     gradient_checkpointing: bool = True
     gradient_checkpointing_kwargs: Dict[str, Any] = field(default_factory=lambda: {"use_reentrant": False})
-    # Mixed precision
-    fp16: bool = False
-    bf16: bool = True
-    tf32: bool = True  # Enable TF32 on Ampere GPUs
     # Evaluation and checkpointing
     evaluation_strategy: str = "steps"
     eval_steps: int = 500
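
For context, after this hunk the surviving slice of TrainingConfig reads roughly as in the sketch below. This is a minimal reconstruction from the visible context lines only; the class's other fields (optimizer settings, save strategy, and so on) are omitted:

    from dataclasses import dataclass, field
    from typing import Any, Dict

    @dataclass
    class TrainingConfig:
        gradient_checkpointing: bool = True
        gradient_checkpointing_kwargs: Dict[str, Any] = field(
            default_factory=lambda: {"use_reentrant": False}
        )
        # Evaluation and checkpointing
        # (the fp16/bf16/tf32 toggles that used to sit here are gone)
        evaluation_strategy: str = "steps"
        eval_steps: int = 500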

@@ -178,9 +178,7 @@ def main():
         adam_epsilon=config.training.adam_epsilon,
         weight_decay=config.training.weight_decay,
         max_grad_norm=config.training.max_grad_norm,
-        fp16=config.training.fp16,
-        bf16=config.training.bf16,
-        tf32=config.training.tf32,
+        bf16=True,  # We're sticking with this known good reduced precision option
         eval_strategy=config.training.evaluation_strategy,
         eval_steps=config.training.eval_steps,
         save_strategy=config.training.save_strategy,
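
One consequence of hardcoding bf16=True is that there is no longer a config-level escape hatch on hardware without bf16 support (bf16 autocast requires Ampere-or-newer NVIDIA GPUs). A hypothetical pre-flight guard, not part of this commit, could fail fast before TrainingArguments is constructed:

    import torch

    # Hypothetical check (not in this commit): bail out early if the GPU
    # cannot do bf16, rather than letting the Trainer error mid-setup.
    if not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
        raise RuntimeError("bf16 is hardcoded but this GPU does not support it")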