mirror of https://github.com/allenai/olmocr.git (synced 2025-10-13 17:22:13 +00:00)

Adding muon support

Commit 0acad6faf6 (parent 6d5711fa3e)
@@ -165,13 +165,19 @@ class TrainingConfig:
     lr_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)

     # Optimization
-    optim: str = "adamw_torch"
+    optim: str = "adamw_torch"  # "adamw_torch", "muon"
     adam_beta1: float = 0.9
     adam_beta2: float = 0.999
     adam_epsilon: float = 1e-8
     weight_decay: float = 0.01
     max_grad_norm: float = 1.0

+    # Muon optimizer specific settings
+    muon_momentum: float = 0.95
+    muon_lr_multiplier_head: float = 11.0  # Learning rate multiplier for head parameters
+    muon_lr_multiplier_embed: float = 30.0  # Learning rate multiplier for embedding parameters
+    muon_lr_multiplier_scalar: float = 2.0  # Learning rate multiplier for scalar parameters
+
     # Gradient checkpointing
     gradient_checkpointing: bool = False
     gradient_checkpointing_kwargs: Dict[str, Any] = field(default_factory=lambda: {"use_reentrant": False})
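For reference, a minimal sketch (not part of the commit) of how the new fields might be set when building a config by hand. It assumes the remaining TrainingConfig fields have usable defaults and that the dataclass is importable from olmocr/train/config.py alongside Config; the values shown are just the defaults from the diff above.

```python
# Hypothetical usage sketch: select the Muon optimizer through TrainingConfig.
# Import path and default-constructibility are assumptions, not confirmed by the commit.
from olmocr.train.config import TrainingConfig

training = TrainingConfig(
    optim="muon",                   # routes the custom loop to SingleDeviceMuonWithAuxAdam
    muon_momentum=0.95,             # momentum for the Muon (hidden-matrix) group
    muon_lr_multiplier_head=11.0,   # head params train at learning_rate * 11.0 in the aux Adam group
    muon_lr_multiplier_embed=30.0,  # embedding params at learning_rate * 30.0
    muon_lr_multiplier_scalar=2.0,  # scalar (ndim < 2) params at learning_rate * 2.0
)
```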
olmocr/train/muon.py (new file, 288 lines)
# FROM: https://raw.githubusercontent.com/KellerJordan/Muon/refs/heads/master/muon.py

import torch
import torch.distributed as dist


def zeropower_via_newtonschulz5(G, steps: int):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    if update.ndim == 4:  # for the case of conv filters
        update = update.view(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    return update


class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
    advantage that it can be stably run in bfloat16 on the GPU.

    Muon should only be used for hidden weight layers. The input embedding, final output layer,
    and any internal gains or biases should be optimized using a standard method such as AdamW.
    Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
    collapsing their last 3 dimensions.

    Arguments:
        lr: The learning rate, in units of spectral norm per update.
        weight_decay: The AdamW-style weight decay.
        momentum: The momentum. A value of 0.95 here is usually fine.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = group["params"]
            params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
            for base_i in range(len(params))[::dist.get_world_size()]:
                if base_i + dist.get_rank() < len(params):
                    p = params[base_i + dist.get_rank()]
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
                dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])

        return loss


class SingleDeviceMuon(torch.optim.Optimizer):
    """
    Muon variant for usage in non-distributed settings.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    # continue
                    p.grad = torch.zeros_like(p)  # Force synchronization
                state = self.state[p]
                if len(state) == 0:
                    state["momentum_buffer"] = torch.zeros_like(p)
                update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                p.mul_(1 - group["lr"] * group["weight_decay"])
                p.add_(update.reshape(p.shape), alpha=-group["lr"])

        return loss


def adam_update(grad, buf1, buf2, step, betas, eps):
    buf1.lerp_(grad, 1 - betas[0])
    buf2.lerp_(grad.square(), 1 - betas[1])
    buf1c = buf1 / (1 - betas[0]**step)
    buf2c = buf2 / (1 - betas[1]**step)
    return buf1c / (buf2c.sqrt() + eps)


class MuonWithAuxAdam(torch.optim.Optimizer):
    """
    Distributed Muon variant that can be used for all parameters in the network, since it runs an
    internal AdamW for the parameters that are not compatible with Muon. The user must manually
    specify which parameters shall be optimized with Muon and which with Adam by passing in a
    list of param_groups with the `use_muon` flag set.

    The point of this class is to allow the user to have a single optimizer in their code, rather
    than having both a Muon and an Adam which each need to be stepped.

    You can see an example usage below:

    https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
    ```
    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    head_params = [model.lm_head.weight]

    from muon import MuonWithAuxAdam
    adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
    param_groups = [*adam_groups, muon_group]
    optimizer = MuonWithAuxAdam(param_groups)
    ```
    """
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                params = group["params"]
                params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
                for base_i in range(len(params))[::dist.get_world_size()]:
                    if base_i + dist.get_rank() < len(params):
                        p = params[base_i + dist.get_rank()]
                        if p.grad is None:
                            # continue
                            p.grad = torch.zeros_like(p)  # Force synchronization
                        state = self.state[p]
                        if len(state) == 0:
                            state["momentum_buffer"] = torch.zeros_like(p)
                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                        p.mul_(1 - group["lr"] * group["weight_decay"])
                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
                    dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss


class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
    """
    Non-distributed variant of MuonWithAuxAdam.
    """
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss
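As a quick sanity check of the Newton-Schulz orthogonalization described in the docstring above: the singular values of the output should land in a loose band around 1 rather than exactly at 1, since the iteration targets US'V^T with S' roughly in (0.5, 1.5). A small sketch, not part of the commit:

```python
# Sketch: verify zeropower_via_newtonschulz5 roughly orthogonalizes a random matrix.
import torch
from olmocr.train.muon import zeropower_via_newtonschulz5

G = torch.randn(256, 512)
X = zeropower_via_newtonschulz5(G, steps=5)

# The result keeps G's shape and is computed in bfloat16; cast up before taking the SVD.
assert X.shape == G.shape and X.dtype == torch.bfloat16
svals = torch.linalg.svdvals(X.float())
print(f"singular values in [{svals.min().item():.2f}, {svals.max().item():.2f}]")  # expected: roughly within (0.5, 1.5)
```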
@@ -15,6 +15,7 @@ from torch.utils.data import ConcatDataset, DataLoader
 from torch.optim import AdamW
 from torch.amp import autocast
 import wandb
+from tqdm import tqdm

 from transformers import (
     AutoProcessor,
@@ -26,6 +27,7 @@ from transformers import (
 from typing import Optional, Dict, Any
 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
+from olmocr.train.muon import SingleDeviceMuonWithAuxAdam

 # Configure logging
 logging.basicConfig(
@@ -321,24 +323,56 @@ def main():
     model.to(device)

     # Set up optimizer
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": config.training.weight_decay,
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
     if config.training.optim == "adamw_torch":
+        no_decay = ["bias", "LayerNorm.weight"]
+        optimizer_grouped_parameters = [
+            {
+                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "weight_decay": config.training.weight_decay,
+            },
+            {
+                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+                "weight_decay": 0.0,
+            },
+        ]
         optimizer = AdamW(
             optimizer_grouped_parameters,
             lr=float(config.training.learning_rate),
             betas=(config.training.adam_beta1, config.training.adam_beta2),
             eps=config.training.adam_epsilon,
         )
+    elif config.training.optim == "muon":
+        # Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head)
+        hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
+        embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+        scalar_params = [p for p in model.parameters() if p.ndim < 2]
+        head_params = [p for n, p in model.named_parameters() if "lm_head" in n]
+
+        # Create Adam groups with different learning rates
+        adam_groups = [
+            dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False),
+            dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False),
+            dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False)
+        ]
+
+        # Add Adam hyperparameters to groups
+        for g in adam_groups:
+            g["betas"] = (config.training.adam_beta1, config.training.adam_beta2)
+            g["eps"] = config.training.adam_epsilon
+            g["weight_decay"] = config.training.weight_decay
+
+        # Create Muon group
+        muon_group = dict(
+            params=hidden_matrix_params,
+            lr=float(config.training.learning_rate),
+            momentum=config.training.muon_momentum,
+            weight_decay=config.training.weight_decay,
+            use_muon=True
+        )
+
+        # Combine all groups
+        param_groups = [*adam_groups, muon_group]
+        optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
     else:
         raise NotImplementedError(f"Optimizer {config.training.optim} not supported in custom loop")

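To make the grouping in the `muon` branch concrete, here is a toy illustration of how the name/ndim filters partition parameters. The module names are made up for the example (the real Qwen2-VL module names differ), but the filter expressions are the ones from the diff above.

```python
# Toy sketch with illustrative module names; only the filter logic matches the diff.
import torch.nn as nn

model = nn.ModuleDict({
    "embed_tokens": nn.Embedding(100, 32),      # name contains "embed"   -> aux Adam group
    "proj": nn.Linear(32, 32),                  # 2D hidden matrix        -> Muon group
    "lm_head": nn.Linear(32, 100, bias=False),  # name contains "lm_head" -> aux Adam group
})

hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [p for n, p in model.named_parameters() if "lm_head" in n]

print(len(hidden_matrix_params), len(embed_params), len(scalar_params), len(head_params))  # 1 1 1 1
```

Biases and norm gains fall into scalar_params (ndim < 2), so even with optim="muon" they are trained by the auxiliary Adam groups, which matches the guidance in the Muon docstring.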
@@ -428,6 +462,9 @@ def main():
             epoch_iterator = iter(train_dataloader)
             break

+    # Create progress bar
+    pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")
+
     while samples_seen < max_train_samples and global_step < max_train_steps:
         try:
             batch = next(epoch_iterator)
@@ -449,6 +486,9 @@ def main():
             num_losses_accumulated += 1
             samples_seen += config.training.per_device_train_batch_size
+
+            # Update progress bar
+            pbar.update(config.training.per_device_train_batch_size)

         # Check if we should do a gradient update
         if samples_seen % samples_per_step == 0 or samples_seen >= max_train_samples:
             # Clip gradients
@@ -462,6 +502,16 @@ def main():
             global_step += 1
             current_epoch = samples_seen / len(train_dataset)

+            # Update progress bar with current stats
+            current_lr = lr_scheduler.get_last_lr()[0]
+            avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
+            pbar.set_postfix({
+                'loss': f'{avg_loss:.4f}',
+                'lr': f'{current_lr:.2e}',
+                'epoch': f'{current_epoch:.2f}',
+                'step': global_step
+            })
+
             # Logging
             if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
                 avg_train_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
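The progress-bar changes follow tqdm's manual-update pattern: an explicit total, update() per batch, set_postfix() for live stats, and close() at the end. A self-contained sketch of the same pattern, with dummy values standing in for the trainer's state:

```python
# Sketch of the tqdm pattern used above; numbers and stats are dummy placeholders.
from tqdm import tqdm

max_train_samples, batch_size, samples_seen, global_step = 1000, 8, 0, 0
pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")
while samples_seen < max_train_samples:
    samples_seen += batch_size
    global_step += 1
    pbar.update(batch_size)  # advance by the number of samples just processed
    pbar.set_postfix({'loss': '0.0000', 'lr': '1.00e-05',
                      'epoch': f'{samples_seen / max_train_samples:.2f}', 'step': global_step})
pbar.close()
```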
@@ -506,9 +556,15 @@ def main():
         if samples_seen >= max_train_samples or global_step >= max_train_steps:
             break

-    # Save the final model
-    logger.info("Saving final model...")
-    model.save_pretrained(full_output_dir)
+    # Close progress bar
+    pbar.close()
+
+    # Save the final checkpoint with step number
+    logger.info(f"Saving final checkpoint at step {global_step}...")
+    save_checkpoint(
+        model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
+        full_output_dir, config.training.save_total_limit
+    )

     # Log final training state
     final_epoch = samples_seen / len(train_dataset)