Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-12 16:52:20 +00:00

Adding muon support

parent 6d5711fa3e
commit 0acad6faf6
@@ -165,12 +165,18 @@ class TrainingConfig:
     lr_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)

     # Optimization
-    optim: str = "adamw_torch"
+    optim: str = "adamw_torch"  # "adamw_torch", "muon"
     adam_beta1: float = 0.9
     adam_beta2: float = 0.999
     adam_epsilon: float = 1e-8
     weight_decay: float = 0.01
     max_grad_norm: float = 1.0

+    # Muon optimizer specific settings
+    muon_momentum: float = 0.95
+    muon_lr_multiplier_head: float = 11.0  # Learning rate multiplier for head parameters
+    muon_lr_multiplier_embed: float = 30.0  # Learning rate multiplier for embedding parameters
+    muon_lr_multiplier_scalar: float = 2.0  # Learning rate multiplier for scalar parameters
+
     # Gradient checkpointing
     gradient_checkpointing: bool = False
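With the new fields in place, switching a run to Muon is just a config change. Below is a minimal sketch (editor's example, not part of the commit), assuming `TrainingConfig` is importable from `olmocr.train.config` and that its remaining fields have usable defaults:

```python
from olmocr.train.config import TrainingConfig  # assumed import path

# Select the Muon optimizer; the multipliers scale the base learning rate for the
# Adam-managed groups (lm_head, embeddings, scalar params) relative to the Muon group.
cfg = TrainingConfig(
    optim="muon",
    muon_momentum=0.95,
    muon_lr_multiplier_head=11.0,
    muon_lr_multiplier_embed=30.0,
    muon_lr_multiplier_scalar=2.0,
)
```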
olmocr/train/muon.py (new file, 288 lines)
@@ -0,0 +1,288 @@
# FROM: https://raw.githubusercontent.com/KellerJordan/Muon/refs/heads/master/muon.py

import torch
import torch.distributed as dist


def zeropower_via_newtonschulz5(G, steps: int):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

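# --- Editor's illustration (not part of the original file): what the Newton-Schulz routine
# --- produces for a random matrix. Per the docstring, the singular values of the result are
# --- only pushed roughly toward 1 (something like Uniform(0.5, 1.5)), not exactly to 1.
#
#     G = torch.randn(256, 128)
#     X = zeropower_via_newtonschulz5(G, steps=5)
#     print(torch.linalg.svdvals(X.float()))  # values spread loosely around 1.0
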
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    if update.ndim == 4:  # for the case of conv filters
        update = update.view(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    return update


class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
    advantage that it can be stably run in bfloat16 on the GPU.

    Muon should only be used for hidden weight layers. The input embedding, final output layer,
    and any internal gains or biases should be optimized using a standard method such as AdamW.
    Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
    collapsing their last 3 dimensions.

    Arguments:
        lr: The learning rate, in units of spectral norm per update.
        weight_decay: The AdamW-style weight decay.
        momentum: The momentum. A value of 0.95 here is usually fine.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = group["params"]
            params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
            for base_i in range(len(params))[::dist.get_world_size()]:
                if base_i + dist.get_rank() < len(params):
                    p = params[base_i + dist.get_rank()]
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
                dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])

        return loss

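# --- Editor's note (not part of the original file): written out, one Muon step above for a
# --- 2D weight W with gradient g and momentum buffer buf is
# ---     buf <- beta * buf + (1 - beta) * g                        (beta = group["momentum"])
# ---     u   <- NS5((1 - beta) * g + beta * buf) * sqrt(max(1, rows / cols))
# ---     W   <- (1 - lr * weight_decay) * W - lr * u
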
class SingleDeviceMuon(torch.optim.Optimizer):
    """
    Muon variant for usage in non-distributed settings.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    # continue
                    p.grad = torch.zeros_like(p)  # Force synchronization
                state = self.state[p]
                if len(state) == 0:
                    state["momentum_buffer"] = torch.zeros_like(p)
                update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                p.mul_(1 - group["lr"] * group["weight_decay"])
                p.add_(update.reshape(p.shape), alpha=-group["lr"])

        return loss


def adam_update(grad, buf1, buf2, step, betas, eps):
    buf1.lerp_(grad, 1 - betas[0])
    buf2.lerp_(grad.square(), 1 - betas[1])
    buf1c = buf1 / (1 - betas[0]**step)
    buf2c = buf2 / (1 - betas[1]**step)
    return buf1c / (buf2c.sqrt() + eps)

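# --- Editor's note (not part of the original file): adam_update returns the standard
# --- bias-corrected Adam direction m_hat / (sqrt(v_hat) + eps), with m_hat = m_t / (1 - beta1^t)
# --- and v_hat = v_t / (1 - beta2^t); the learning rate and decoupled weight decay are applied
# --- by the caller inside step().
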
class MuonWithAuxAdam(torch.optim.Optimizer):
    """
    Distributed Muon variant that can be used for all parameters in the network, since it runs an
    internal AdamW for the parameters that are not compatible with Muon. The user must manually
    specify which parameters shall be optimized with Muon and which with Adam by passing in a
    list of param_groups with the `use_muon` flag set.

    The point of this class is to allow the user to have a single optimizer in their code, rather
    than having both a Muon and an Adam which each need to be stepped.

    You can see an example usage below:

    https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
    ```
    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    head_params = [model.lm_head.weight]

    from muon import MuonWithAuxAdam
    adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
    param_groups = [*adam_groups, muon_group]
    optimizer = MuonWithAuxAdam(param_groups)
    ```
    """
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                params = group["params"]
                params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
                for base_i in range(len(params))[::dist.get_world_size()]:
                    if base_i + dist.get_rank() < len(params):
                        p = params[base_i + dist.get_rank()]
                        if p.grad is None:
                            # continue
                            p.grad = torch.zeros_like(p)  # Force synchronization
                        state = self.state[p]
                        if len(state) == 0:
                            state["momentum_buffer"] = torch.zeros_like(p)
                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                        p.mul_(1 - group["lr"] * group["weight_decay"])
                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
                    dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss


class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
    """
    Non-distributed variant of MuonWithAuxAdam.
    """
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss
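
SingleDeviceMuonWithAuxAdam is the variant the training script below uses. A self-contained sketch of how its param groups are expected to look (editor's example with a toy module and an illustrative base learning rate, not taken from the commit):

```python
import torch
from olmocr.train.muon import SingleDeviceMuonWithAuxAdam

# Toy stand-in for a real model; names chosen so the "embed"/"lm_head" routing rules apply.
model = torch.nn.ModuleDict({
    "embed_tokens": torch.nn.Embedding(1000, 64),
    "block": torch.nn.Linear(64, 64),
    "norm": torch.nn.LayerNorm(64),
    "lm_head": torch.nn.Linear(64, 1000, bias=False),
})

hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [p for n, p in model.named_parameters() if "lm_head" in n]

base_lr = 2e-5  # illustrative value only
adam_groups = [
    dict(params=head_params, lr=base_lr * 11.0, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, use_muon=False),
    dict(params=embed_params, lr=base_lr * 30.0, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, use_muon=False),
    dict(params=scalar_params, lr=base_lr * 2.0, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, use_muon=False),
]
muon_group = dict(params=hidden_matrix_params, lr=base_lr, momentum=0.95, weight_decay=0.01, use_muon=True)
optimizer = SingleDeviceMuonWithAuxAdam([*adam_groups, muon_group])
```

Note that `__init__` asserts each group carries exactly the keys shown for its `use_muon` setting, so extra or missing keys raise immediately.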
@@ -15,6 +15,7 @@ from torch.utils.data import ConcatDataset, DataLoader
 from torch.optim import AdamW
 from torch.amp import autocast
 import wandb
+from tqdm import tqdm

 from transformers import (
     AutoProcessor,

@@ -26,6 +27,7 @@ from transformers import (
 from typing import Optional, Dict, Any
 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
+from olmocr.train.muon import SingleDeviceMuonWithAuxAdam

 # Configure logging
 logging.basicConfig(
@@ -321,24 +323,56 @@
     model.to(device)

     # Set up optimizer
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": config.training.weight_decay,
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-    optimizer = AdamW(
-        optimizer_grouped_parameters,
-        lr=float(config.training.learning_rate),
-        betas=(config.training.adam_beta1, config.training.adam_beta2),
-        eps=config.training.adam_epsilon,
-    )
+    if config.training.optim == "adamw_torch":
+        no_decay = ["bias", "LayerNorm.weight"]
+        optimizer_grouped_parameters = [
+            {
+                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "weight_decay": config.training.weight_decay,
+            },
+            {
+                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+                "weight_decay": 0.0,
+            },
+        ]
+        optimizer = AdamW(
+            optimizer_grouped_parameters,
+            lr=float(config.training.learning_rate),
+            betas=(config.training.adam_beta1, config.training.adam_beta2),
+            eps=config.training.adam_epsilon,
+        )
+    elif config.training.optim == "muon":
+        # Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head)
+        hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
+        embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+        scalar_params = [p for p in model.parameters() if p.ndim < 2]
+        head_params = [p for n, p in model.named_parameters() if "lm_head" in n]
+
+        # Create Adam groups with different learning rates
+        adam_groups = [
+            dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False),
+            dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False),
+            dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False)
+        ]
+
+        # Add Adam hyperparameters to groups
+        for g in adam_groups:
+            g["betas"] = (config.training.adam_beta1, config.training.adam_beta2)
+            g["eps"] = config.training.adam_epsilon
+            g["weight_decay"] = config.training.weight_decay
+
+        # Create Muon group
+        muon_group = dict(
+            params=hidden_matrix_params,
+            lr=float(config.training.learning_rate),
+            momentum=config.training.muon_momentum,
+            weight_decay=config.training.weight_decay,
+            use_muon=True
+        )
+
+        # Combine all groups
+        param_groups = [*adam_groups, muon_group]
+        optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
+    else:
+        raise NotImplementedError(f"Optimizer {config.training.optim} not supported in custom loop")
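
The four parameter lists in the muon branch are intended to partition the model's parameters between Muon and the auxiliary Adam. A quick sanity check that could be added after building them (editor's sketch; it assumes no parameter name matches more than one routing rule, which is worth verifying for a new architecture):

```python
# Editor's sketch: confirm the Muon/Adam routing covers every parameter exactly once.
routed = [id(p) for p in (*hidden_matrix_params, *embed_params, *scalar_params, *head_params)]
assert len(routed) == len(set(routed)) == len(list(model.parameters())), \
    "parameter routing dropped or double-counted a parameter"
```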
@@ -428,6 +462,9 @@
             epoch_iterator = iter(train_dataloader)
             break

+    # Create progress bar
+    pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")
+
     while samples_seen < max_train_samples and global_step < max_train_steps:
         try:
             batch = next(epoch_iterator)

@@ -449,6 +486,9 @@
                 num_losses_accumulated += 1
                 samples_seen += config.training.per_device_train_batch_size

+                # Update progress bar
+                pbar.update(config.training.per_device_train_batch_size)
+
                 # Check if we should do a gradient update
                 if samples_seen % samples_per_step == 0 or samples_seen >= max_train_samples:
                     # Clip gradients

@@ -462,6 +502,16 @@
                     global_step += 1
                     current_epoch = samples_seen / len(train_dataset)

+                    # Update progress bar with current stats
+                    current_lr = lr_scheduler.get_last_lr()[0]
+                    avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
+                    pbar.set_postfix({
+                        'loss': f'{avg_loss:.4f}',
+                        'lr': f'{current_lr:.2e}',
+                        'epoch': f'{current_epoch:.2f}',
+                        'step': global_step
+                    })
+
                     # Logging
                     if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
                         avg_train_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0

@@ -505,10 +555,16 @@
             # Check if we've reached our training limit
             if samples_seen >= max_train_samples or global_step >= max_train_steps:
                 break

+    # Close progress bar
+    pbar.close()
+
     # Save the final model
     logger.info("Saving final model...")
     model.save_pretrained(full_output_dir)
     # Save the final checkpoint with step number
     logger.info(f"Saving final checkpoint at step {global_step}...")
     save_checkpoint(
         model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
         full_output_dir, config.training.save_total_limit
     )

     # Log final training state
     final_epoch = samples_seen / len(train_dataset)