From 0acad6faf6ee980c69f98cec27aa7e7a4c7b0555 Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Fri, 25 Jul 2025 17:29:12 +0000 Subject: [PATCH] Adding muon support --- olmocr/train/config.py | 8 +- olmocr/train/muon.py | 288 +++++++++++++++++++++++++++++++++++++++++ olmocr/train/train.py | 84 ++++++++++-- 3 files changed, 365 insertions(+), 15 deletions(-) create mode 100644 olmocr/train/muon.py diff --git a/olmocr/train/config.py b/olmocr/train/config.py index ff2e3a1..37e46d5 100644 --- a/olmocr/train/config.py +++ b/olmocr/train/config.py @@ -165,12 +165,18 @@ class TrainingConfig: lr_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict) # Optimization - optim: str = "adamw_torch" + optim: str = "adamw_torch" # "adamw_torch", "muon" adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-8 weight_decay: float = 0.01 max_grad_norm: float = 1.0 + + # Muon optimizer specific settings + muon_momentum: float = 0.95 + muon_lr_multiplier_head: float = 11.0 # Learning rate multiplier for head parameters + muon_lr_multiplier_embed: float = 30.0 # Learning rate multiplier for embedding parameters + muon_lr_multiplier_scalar: float = 2.0 # Learning rate multiplier for scalar parameters # Gradient checkpointing gradient_checkpointing: bool = False diff --git a/olmocr/train/muon.py b/olmocr/train/muon.py new file mode 100644 index 0000000..771e9c7 --- /dev/null +++ b/olmocr/train/muon.py @@ -0,0 +1,288 @@ +# FROM: https://raw.githubusercontent.com/KellerJordan/Muon/refs/heads/master/muon.py + +import torch +import torch.distributed as dist + + +def zeropower_via_newtonschulz5(G, steps: int): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
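As a rough numerical illustration of this property (a sketch, not part of the upstream file; it assumes the module is importable as `olmocr.train.muon`, which this patch creates), the singular values of the output land near 1 rather than exactly on it:

```
import torch

from olmocr.train.muon import zeropower_via_newtonschulz5

# Orthogonalize a random gradient-shaped matrix with 5 Newton-Schulz steps.
G = torch.randn(256, 128)
X = zeropower_via_newtonschulz5(G, steps=5)

# The singular values are expected to be scattered roughly in [0.5, 1.5] rather than all
# equal to 1, i.e. the result approximates UV^T from the SVD of G without computing the SVD.
S = torch.linalg.svdvals(X.float())
print(S.min().item(), S.max().item())
```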
+ """ + assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.mT + B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True): + momentum.lerp_(grad, 1 - beta) + update = grad.lerp_(momentum, beta) if nesterov else momentum + if update.ndim == 4: # for the case of conv filters + update = update.view(len(update), -1) + update = zeropower_via_newtonschulz5(update, steps=ns_steps) + update *= max(1, grad.size(-2) / grad.size(-1))**0.5 + return update + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the + advantage that it can be stably run in bfloat16 on the GPU. + + Muon should only be used for hidden weight layers. The input embedding, final output layer, + and any internal gains or biases should be optimized using a standard method such as AdamW. + Hidden convolutional weights can be trained using Muon by viewing them as 2D and then + collapsing their last 3 dimensions. + + Arguments: + lr: The learning rate, in units of spectral norm per update. + weight_decay: The AdamW-style weight decay. + momentum: The momentum. A value of 0.95 here is usually fine. + """ + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter) + params = sorted(params, key=lambda x: x.size(), reverse=True) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size()) + for base_i in range(len(params))[::dist.get_world_size()]: + if base_i + dist.get_rank() < len(params): + p = params[base_i + dist.get_rank()] + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update.reshape(p.shape), alpha=-group["lr"]) + dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + + return loss + + +class SingleDeviceMuon(torch.optim.Optimizer): + """ + Muon variant for usage in non-distributed settings. 
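For illustration only (a sketch using a hypothetical toy model, not code from this patch), a single-device setup hands Muon the hidden 2D weights and keeps everything else on AdamW, as the Muon docstring above prescribes:

```
import torch

from olmocr.train.muon import SingleDeviceMuon

# Hypothetical toy model used purely for illustration.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))

# Muon is intended for hidden weight matrices only; biases and other <2D tensors use AdamW.
hidden_weights = [p for p in model.parameters() if p.ndim >= 2]
other_params = [p for p in model.parameters() if p.ndim < 2]

muon_opt = SingleDeviceMuon(hidden_weights, lr=0.02, momentum=0.95, weight_decay=0.0)
adamw_opt = torch.optim.AdamW(other_params, lr=3e-4)

# One toy step: compute a loss, backpropagate, then step both optimizers.
loss = model(torch.randn(8, 64)).square().mean()
loss.backward()
muon_opt.step()
adamw_opt.step()
muon_opt.zero_grad()
adamw_opt.zero_grad()
```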
+ """ + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update.reshape(p.shape), alpha=-group["lr"]) + + return loss + + +def adam_update(grad, buf1, buf2, step, betas, eps): + buf1.lerp_(grad, 1 - betas[0]) + buf2.lerp_(grad.square(), 1 - betas[1]) + buf1c = buf1 / (1 - betas[0]**step) + buf2c = buf2 / (1 - betas[1]**step) + return buf1c / (buf2c.sqrt() + eps) + + +class MuonWithAuxAdam(torch.optim.Optimizer): + """ + Distributed Muon variant that can be used for all parameters in the network, since it runs an + internal AdamW for the parameters that are not compatible with Muon. The user must manually + specify which parameters shall be optimized with Muon and which with Adam by passing in a + list of param_groups with the `use_muon` flag set. + + The point of this class is to allow the user to have a single optimizer in their code, rather + than having both a Muon and an Adam which each need to be stepped. + + You can see an example usage below: + + https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470 + ``` + hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] + embed_params = [p for n, p in model.named_parameters() if "embed" in n] + scalar_params = [p for p in model.parameters() if p.ndim < 2] + head_params = [model.lm_head.weight] + + from muon import MuonWithAuxAdam + adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)] + adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups] + muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True) + param_groups = [*adam_groups, muon_group] + optimizer = MuonWithAuxAdam(param_groups) + ``` + """ + def __init__(self, param_groups): + for group in param_groups: + assert "use_muon" in group + if group["use_muon"]: + group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True) + # defaults + group["lr"] = group.get("lr", 0.02) + group["momentum"] = group.get("momentum", 0.95) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"]) + else: + # defaults + group["lr"] = group.get("lr", 3e-4) + group["betas"] = group.get("betas", (0.9, 0.95)) + group["eps"] = group.get("eps", 1e-10) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"]) + super().__init__(param_groups, dict()) + + @torch.no_grad() + def step(self, closure=None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + params = group["params"] + params_pad = params + 
[torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size()) + for base_i in range(len(params))[::dist.get_world_size()]: + if base_i + dist.get_rank() < len(params): + p = params[base_i + dist.get_rank()] + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update.reshape(p.shape), alpha=-group["lr"]) + dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + else: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + state["step"] = 0 + state["step"] += 1 + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], + state["step"], group["betas"], group["eps"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + + return loss + + +class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer): + """ + Non-distributed variant of MuonWithAuxAdam. + """ + def __init__(self, param_groups): + for group in param_groups: + assert "use_muon" in group + if group["use_muon"]: + # defaults + group["lr"] = group.get("lr", 0.02) + group["momentum"] = group.get("momentum", 0.95) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"]) + else: + # defaults + group["lr"] = group.get("lr", 3e-4) + group["betas"] = group.get("betas", (0.9, 0.95)) + group["eps"] = group.get("eps", 1e-10) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"]) + super().__init__(param_groups, dict()) + + @torch.no_grad() + def step(self, closure=None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update.reshape(p.shape), alpha=-group["lr"]) + else: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + state["step"] = 0 + state["step"] += 1 + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], + state["step"], group["betas"], group["eps"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + + return loss \ No newline at end of file diff --git a/olmocr/train/train.py b/olmocr/train/train.py index 26e6843..350f494 100644 --- a/olmocr/train/train.py +++ b/olmocr/train/train.py @@ -15,6 +15,7 @@ from torch.utils.data import ConcatDataset, DataLoader from torch.optim import AdamW from torch.amp import autocast import wandb +from tqdm import tqdm from transformers import ( 
AutoProcessor, @@ -26,6 +27,7 @@ from transformers import ( from typing import Optional, Dict, Any from olmocr.train.config import Config from olmocr.train.dataloader import BaseMarkdownPDFDataset +from olmocr.train.muon import SingleDeviceMuonWithAuxAdam # Configure logging logging.basicConfig( @@ -321,24 +323,56 @@ def main(): model.to(device) # Set up optimizer - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": config.training.weight_decay, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] if config.training.optim == "adamw_torch": + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": config.training.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] optimizer = AdamW( optimizer_grouped_parameters, lr=float(config.training.learning_rate), betas=(config.training.adam_beta1, config.training.adam_beta2), eps=config.training.adam_epsilon, ) + elif config.training.optim == "muon": + # Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head) + hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n] + embed_params = [p for n, p in model.named_parameters() if "embed" in n] + scalar_params = [p for p in model.parameters() if p.ndim < 2] + head_params = [p for n, p in model.named_parameters() if "lm_head" in n] + + # Create Adam groups with different learning rates + adam_groups = [ + dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False), + dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False), + dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False) + ] + + # Add Adam hyperparameters to groups + for g in adam_groups: + g["betas"] = (config.training.adam_beta1, config.training.adam_beta2) + g["eps"] = config.training.adam_epsilon + g["weight_decay"] = config.training.weight_decay + + # Create Muon group + muon_group = dict( + params=hidden_matrix_params, + lr=float(config.training.learning_rate), + momentum=config.training.muon_momentum, + weight_decay=config.training.weight_decay, + use_muon=True + ) + + # Combine all groups + param_groups = [*adam_groups, muon_group] + optimizer = SingleDeviceMuonWithAuxAdam(param_groups) else: raise NotImplementedError(f"Optimizer {config.training.optim} not supported in custom loop") @@ -428,6 +462,9 @@ def main(): epoch_iterator = iter(train_dataloader) break + # Create progress bar + pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples") + while samples_seen < max_train_samples and global_step < max_train_steps: try: batch = next(epoch_iterator) @@ -449,6 +486,9 @@ def main(): num_losses_accumulated += 1 samples_seen += config.training.per_device_train_batch_size + # Update progress bar + pbar.update(config.training.per_device_train_batch_size) + # Check if we should do a gradient update if samples_seen % samples_per_step == 0 or samples_seen >= 
max_train_samples: # Clip gradients @@ -462,6 +502,16 @@ def main(): global_step += 1 current_epoch = samples_seen / len(train_dataset) + # Update progress bar with current stats + current_lr = lr_scheduler.get_last_lr()[0] + avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0 + pbar.set_postfix({ + 'loss': f'{avg_loss:.4f}', + 'lr': f'{current_lr:.2e}', + 'epoch': f'{current_epoch:.2f}', + 'step': global_step + }) + # Logging if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0: avg_train_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0 @@ -505,10 +555,16 @@ def main(): # Check if we've reached our training limit if samples_seen >= max_train_samples or global_step >= max_train_steps: break + + # Close progress bar + pbar.close() - # Save the final model - logger.info("Saving final model...") - model.save_pretrained(full_output_dir) + # Save the final checkpoint with step number + logger.info(f"Saving final checkpoint at step {global_step}...") + save_checkpoint( + model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, + full_output_dir, config.training.save_total_limit + ) # Log final training state final_epoch = samples_seen / len(train_dataset)
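For reference, here is a minimal sketch (not part of the patch) of how the new `muon` branch in `train.py` splits parameters and scales learning rates, using a hypothetical toy model whose attribute names match the `embed`/`lm_head` filters and the default multipliers added to `TrainingConfig` (head 11.0, embed 30.0, scalar 2.0):

```
import torch

from olmocr.train.muon import SingleDeviceMuonWithAuxAdam

# Toy stand-in for the real model, named so that the "embed" / "lm_head" filters below match.
class ToyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(100, 32)
        self.block = torch.nn.Linear(32, 32)
        self.lm_head = torch.nn.Linear(32, 100, bias=False)

model = ToyLM()
base_lr = 2e-5  # hypothetical base learning rate (config.training.learning_rate)

# Same split as the patch: hidden 2D matrices go to Muon, everything else to the aux AdamW.
hidden = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
embed = [p for n, p in model.named_parameters() if "embed" in n]
scalars = [p for p in model.parameters() if p.ndim < 2]
head = [p for n, p in model.named_parameters() if "lm_head" in n]

adam_groups = [
    dict(params=head, lr=base_lr * 11.0, use_muon=False),    # muon_lr_multiplier_head
    dict(params=embed, lr=base_lr * 30.0, use_muon=False),   # muon_lr_multiplier_embed
    dict(params=scalars, lr=base_lr * 2.0, use_muon=False),  # muon_lr_multiplier_scalar
]
for g in adam_groups:
    g.update(betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

muon_group = dict(params=hidden, lr=base_lr, momentum=0.95, weight_decay=0.01, use_muon=True)
optimizer = SingleDeviceMuonWithAuxAdam([*adam_groups, muon_group])
```

With the defaults above, a base learning rate of 2e-5 gives effective Adam rates of 2.2e-4 (head), 6e-4 (embeddings), and 4e-5 (scalars), while the hidden matrices are updated by Muon at the base rate.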