Adding muon support

Jake Poznanski 2025-07-25 17:29:12 +00:00
parent 6d5711fa3e
commit 0acad6faf6
3 changed files with 365 additions and 15 deletions


@@ -165,13 +165,19 @@ class TrainingConfig:
    lr_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)

    # Optimization
    optim: str = "adamw_torch"
    optim: str = "adamw_torch"  # "adamw_torch", "muon"
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0

    # Muon optimizer specific settings
    muon_momentum: float = 0.95
    muon_lr_multiplier_head: float = 11.0  # Learning rate multiplier for head parameters
    muon_lr_multiplier_embed: float = 30.0  # Learning rate multiplier for embedding parameters
    muon_lr_multiplier_scalar: float = 2.0  # Learning rate multiplier for scalar parameters

    # Gradient checkpointing
    gradient_checkpointing: bool = False
    gradient_checkpointing_kwargs: Dict[str, Any] = field(default_factory=lambda: {"use_reentrant": False})
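For reference, a minimal usage sketch of the new fields when selecting Muon. The field names come from the diff above; the base learning rate value, the direct `TrainingConfig(...)` construction (instead of the project's usual config loading), and the assumption that the remaining fields all have defaults are illustrative only:

```python
# Hypothetical sketch, not part of the commit: selecting the Muon optimizer via TrainingConfig.
# Assumes TrainingConfig is importable from olmocr.train.config and its other fields have defaults.
from olmocr.train.config import TrainingConfig

training = TrainingConfig(
    optim="muon",                   # routes the custom loop to SingleDeviceMuonWithAuxAdam
    learning_rate=2e-5,             # base LR, used as-is for the Muon (hidden-matrix) group
    muon_momentum=0.95,
    muon_lr_multiplier_head=11.0,   # Adam LR for lm_head parameters = 11.0 x base LR
    muon_lr_multiplier_embed=30.0,  # Adam LR for embedding parameters = 30.0 x base LR
    muon_lr_multiplier_scalar=2.0,  # Adam LR for scalar parameters (biases, norms) = 2.0 x base LR
)
```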

olmocr/train/muon.py (new file, 288 lines)

@@ -0,0 +1,288 @@
# FROM: https://raw.githubusercontent.com/KellerJordan/Muon/refs/heads/master/muon.py
import torch
import torch.distributed as dist


def zeropower_via_newtonschulz5(G, steps: int):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    if update.ndim == 4:  # for the case of conv filters
        update = update.view(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    return update


class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
    advantage that it can be stably run in bfloat16 on the GPU.

    Muon should only be used for hidden weight layers. The input embedding, final output layer,
    and any internal gains or biases should be optimized using a standard method such as AdamW.
    Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
    collapsing their last 3 dimensions.

    Arguments:
        lr: The learning rate, in units of spectral norm per update.
        weight_decay: The AdamW-style weight decay.
        momentum: The momentum. A value of 0.95 here is usually fine.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = group["params"]
            params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
            for base_i in range(len(params))[::dist.get_world_size()]:
                if base_i + dist.get_rank() < len(params):
                    p = params[base_i + dist.get_rank()]
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
                dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])

        return loss


class SingleDeviceMuon(torch.optim.Optimizer):
    """
    Muon variant for usage in non-distributed settings.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    # continue
                    p.grad = torch.zeros_like(p)  # Force synchronization
                state = self.state[p]
                if len(state) == 0:
                    state["momentum_buffer"] = torch.zeros_like(p)
                update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                p.mul_(1 - group["lr"] * group["weight_decay"])
                p.add_(update.reshape(p.shape), alpha=-group["lr"])

        return loss


def adam_update(grad, buf1, buf2, step, betas, eps):
    buf1.lerp_(grad, 1 - betas[0])
    buf2.lerp_(grad.square(), 1 - betas[1])
    buf1c = buf1 / (1 - betas[0]**step)
    buf2c = buf2 / (1 - betas[1]**step)
    return buf1c / (buf2c.sqrt() + eps)


class MuonWithAuxAdam(torch.optim.Optimizer):
    """
    Distributed Muon variant that can be used for all parameters in the network, since it runs an
    internal AdamW for the parameters that are not compatible with Muon. The user must manually
    specify which parameters shall be optimized with Muon and which with Adam by passing in a
    list of param_groups with the `use_muon` flag set.

    The point of this class is to allow the user to have a single optimizer in their code, rather
    than having both a Muon and an Adam which each need to be stepped.

    You can see an example usage below:

    https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
    ```
    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    head_params = [model.lm_head.weight]

    from muon import MuonWithAuxAdam
    adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
    param_groups = [*adam_groups, muon_group]
    optimizer = MuonWithAuxAdam(param_groups)
    ```
    """
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                params = group["params"]
                params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
                for base_i in range(len(params))[::dist.get_world_size()]:
                    if base_i + dist.get_rank() < len(params):
                        p = params[base_i + dist.get_rank()]
                        if p.grad is None:
                            # continue
                            p.grad = torch.zeros_like(p)  # Force synchronization
                        state = self.state[p]
                        if len(state) == 0:
                            state["momentum_buffer"] = torch.zeros_like(p)
                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                        p.mul_(1 - group["lr"] * group["weight_decay"])
                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
                    dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss


class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
    """
    Non-distributed variant of MuonWithAuxAdam.
    """
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss
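The Newton-Schulz docstring above claims that the iteration maps G = USV^T to roughly UV^T, with the singular values of the result pushed toward 1. A small sanity-check sketch, not part of the commit and assuming only that the new module is importable, that verifies this on a random matrix:

```python
# Standalone sanity check for the vendored Newton-Schulz orthogonalization.
# The import path simply mirrors the new file added above.
import torch

from olmocr.train.muon import zeropower_via_newtonschulz5

G = torch.randn(256, 512)
X = zeropower_via_newtonschulz5(G, steps=5).float()  # internally runs in bfloat16

# If X is approximately U V^T, its singular values should all sit near 1
# (the docstring says roughly within the 0.5-1.5 band).
svals = torch.linalg.svdvals(X)
print(f"min singular value: {svals.min().item():.3f}")
print(f"max singular value: {svals.max().item():.3f}")
```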


@@ -15,6 +15,7 @@ from torch.utils.data import ConcatDataset, DataLoader
from torch.optim import AdamW
from torch.amp import autocast
import wandb
from tqdm import tqdm

from transformers import (
    AutoProcessor,
@@ -26,6 +27,7 @@ from transformers import (
from typing import Optional, Dict, Any
from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset
from olmocr.train.muon import SingleDeviceMuonWithAuxAdam

# Configure logging
logging.basicConfig(
@@ -321,24 +323,56 @@ def main():
    model.to(device)

    # Set up optimizer
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": config.training.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    if config.training.optim == "adamw_torch":
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": config.training.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=float(config.training.learning_rate),
            betas=(config.training.adam_beta1, config.training.adam_beta2),
            eps=config.training.adam_epsilon,
        )
    elif config.training.optim == "muon":
        # Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head)
        hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
        embed_params = [p for n, p in model.named_parameters() if "embed" in n]
        scalar_params = [p for p in model.parameters() if p.ndim < 2]
        head_params = [p for n, p in model.named_parameters() if "lm_head" in n]

        # Create Adam groups with different learning rates
        adam_groups = [
            dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False),
            dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False),
            dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False)
        ]

        # Add Adam hyperparameters to groups
        for g in adam_groups:
            g["betas"] = (config.training.adam_beta1, config.training.adam_beta2)
            g["eps"] = config.training.adam_epsilon
            g["weight_decay"] = config.training.weight_decay

        # Create Muon group
        muon_group = dict(
            params=hidden_matrix_params,
            lr=float(config.training.learning_rate),
            momentum=config.training.muon_momentum,
            weight_decay=config.training.weight_decay,
            use_muon=True
        )

        # Combine all groups
        param_groups = [*adam_groups, muon_group]
        optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
    else:
        raise NotImplementedError(f"Optimizer {config.training.optim} not supported in custom loop")
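To make the multiplier wiring above concrete, a small illustrative calculation; the base learning rate here is a made-up value, while the multipliers are the config defaults from the diff:

```python
# Illustration only: effective per-group learning rates under the "muon" branch above,
# assuming a hypothetical base learning_rate of 2e-5 and the default multipliers.
base_lr = 2e-5

print(base_lr * 11.0)  # lm_head parameters, Adam group     -> 2.2e-04
print(base_lr * 30.0)  # embedding parameters, Adam group   -> 6.0e-04
print(base_lr * 2.0)   # scalar parameters, Adam group      -> 4.0e-05
print(base_lr)         # hidden matrices, Muon group        -> 2.0e-05 (no multiplier)
```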
@@ -428,6 +462,9 @@ def main():
            epoch_iterator = iter(train_dataloader)
            break

    # Create progress bar
    pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")

    while samples_seen < max_train_samples and global_step < max_train_steps:
        try:
            batch = next(epoch_iterator)
@@ -449,6 +486,9 @@ def main():
        num_losses_accumulated += 1
        samples_seen += config.training.per_device_train_batch_size

        # Update progress bar
        pbar.update(config.training.per_device_train_batch_size)

        # Check if we should do a gradient update
        if samples_seen % samples_per_step == 0 or samples_seen >= max_train_samples:
            # Clip gradients
@@ -462,6 +502,16 @@ def main():
            global_step += 1
            current_epoch = samples_seen / len(train_dataset)

            # Update progress bar with current stats
            current_lr = lr_scheduler.get_last_lr()[0]
            avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
            pbar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{current_lr:.2e}',
                'epoch': f'{current_epoch:.2f}',
                'step': global_step
            })

            # Logging
            if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
                avg_train_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
@@ -506,9 +556,15 @@ def main():
            if samples_seen >= max_train_samples or global_step >= max_train_steps:
                break

    # Save the final model
    logger.info("Saving final model...")
    model.save_pretrained(full_output_dir)

    # Close progress bar
    pbar.close()

    # Save the final checkpoint with step number
    logger.info(f"Saving final checkpoint at step {global_step}...")
    save_checkpoint(
        model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
        full_output_dir, config.training.save_total_limit
    )

    # Log final training state
    final_epoch = samples_seen / len(train_dataset)