Fixing lora

Jake Poznanski 2025-10-28 22:40:38 +00:00
parent 4768db63d2
commit d6915b7044


@@ -36,6 +36,46 @@ logging.basicConfig(
logger = logging.getLogger(__name__)


def prepare_lora_model(model: torch.nn.Module, model_cfg) -> torch.nn.Module:
    """Wrap the model with a LoRA adapter according to the configuration."""
    try:
        from peft import LoraConfig, get_peft_model
    except ImportError as exc:  # pragma: no cover - optional dependency guard
        raise ImportError("LoRA training requires the `peft` package. Install it with `pip install peft`.") from exc

    lora_kwargs = dict(
        r=model_cfg.lora_rank,
        lora_alpha=model_cfg.lora_alpha,
        lora_dropout=model_cfg.lora_dropout,
        target_modules=model_cfg.lora_target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )
    if model_cfg.lora_modules_to_save:
        lora_kwargs["modules_to_save"] = model_cfg.lora_modules_to_save

    lora_config = LoraConfig(**lora_kwargs)
    model = get_peft_model(model, lora_config)

    if hasattr(model, "config"):
        model.config.base_model_name_or_path = model_cfg.name
    base_model = getattr(model, "base_model", None)
    if base_model is not None:
        inner_model = getattr(base_model, "model", None)
        if inner_model is not None and hasattr(inner_model, "config"):
            inner_model.config._name_or_path = model_cfg.name

    if hasattr(model, "print_trainable_parameters"):
        model.print_trainable_parameters()
    return model


def is_lora_checkpoint(checkpoint_dir: str) -> bool:
    """Detect whether a checkpoint directory contains LoRA adapter weights."""
    return os.path.exists(os.path.join(checkpoint_dir, "adapter_config.json"))


class QwenDataCollator:
    """Data collator for vision-language models that handles numpy arrays."""
@@ -140,9 +180,29 @@ def load_checkpoint(
    lr_scheduler: Any,
    checkpoint_dir: str,
    device: torch.device,
    *,
    base_model_path: Optional[str] = None,
    use_lora: bool = False,
) -> tuple[torch.nn.Module, Dict[str, Any]]:
    """Load model, optimizer, scheduler, and training state from checkpoint."""
    model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
    checkpoint_has_lora = is_lora_checkpoint(checkpoint_dir)
    if checkpoint_has_lora or use_lora:
        if base_model_path is None:
            raise ValueError("base_model_path must be provided when loading LoRA checkpoints.")
        try:
            from peft import PeftModel
        except ImportError as exc:  # pragma: no cover - optional dependency guard
            raise ImportError("Loading a LoRA checkpoint requires the `peft` package. Install it with `pip install peft`.") from exc
        base_model = model_class.from_pretrained(base_model_path, **init_kwargs)
        model = PeftModel.from_pretrained(base_model, checkpoint_dir, is_trainable=True)
        if hasattr(model, "config"):
            model.config.base_model_name_or_path = base_model_path
    else:
        model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)

    model.to(device)
    optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device))
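Editor's note: the save side is not part of this diff, but a hedged sketch of the layout load_checkpoint expects may help. The helper name and directory handling are assumptions; the grounded pieces are optimizer.pt above and the fact that PeftModel.save_pretrained writes adapter_config.json, the marker is_lora_checkpoint tests for.

    import os
    import torch

    def save_checkpoint_sketch(model, optimizer, checkpoint_dir: str) -> None:
        # For a PeftModel, save_pretrained writes adapter_config.json plus adapter
        # weights only; a plain HF model writes config.json plus full weights, so
        # is_lora_checkpoint can tell the two layouts apart on resume.
        os.makedirs(checkpoint_dir, exist_ok=True)
        model.save_pretrained(checkpoint_dir)
        torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))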
@@ -280,6 +340,15 @@ def main():
    else:
        raise NotImplementedError()

    if config.model.use_lora:
        logger.info("Applying LoRA adapters as specified in the config.")
        model = prepare_lora_model(model, config.model)
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        trainable_ratio = (trainable_params / total_params * 100) if total_params else 0.0
        logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_ratio:.2f}%)")

    # Enable gradient checkpointing if configured
    if config.training.gradient_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)
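Editor's note: when the two blocks above are combined, a frozen LoRA base plus gradient checkpointing can silently yield no gradients unless the input embeddings require grad. A hedged sketch of the usual guard, meant to sit in main() next to the code above; whether this loop actually needs it depends on the transformers/peft versions in use.

    # Hedged sketch only; the config field names are the ones used above.
    if config.model.use_lora and config.training.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            # transformers' PreTrainedModel hook that makes embedding outputs require
            # grad, so checkpointed segments still backprop into the LoRA adapters.
            model.enable_input_require_grads()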
@@ -370,15 +439,19 @@ def main():
        logger.info("Model compilation complete")

    # Set up optimizer
    trainable_named_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    if not trainable_named_params:
        raise ValueError("No trainable parameters found. Check model fine-tuning configuration.")

    if config.training.optim == "adamw_torch":
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "params": [p for n, p in trainable_named_params if not any(nd in n for nd in no_decay)],
                "weight_decay": config.training.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "params": [p for n, p in trainable_named_params if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
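Editor's note: to show where the groups above land, a sketch of the AdamW construction they feed. Only adam_epsilon is visible in the next hunk; learning_rate and the beta fields are assumed names patterned after it. With LoRA active, trainable_named_params holds only adapter (and modules_to_save) weights, so frozen base parameters never acquire optimizer state.

    # Sketch; config fields other than adam_epsilon are assumptions.
    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters,
        lr=float(config.training.learning_rate),
        betas=(float(config.training.adam_beta1), float(config.training.adam_beta2)),
        eps=float(config.training.adam_epsilon),
    )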
@@ -389,11 +462,14 @@ def main():
        eps=float(config.training.adam_epsilon),
    )
    elif config.training.optim == "muon":
        if config.model.use_lora:
            raise NotImplementedError("LoRA training is not currently supported with the Muon optimizer in this loop.")
        # Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head)
        hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
        embed_params = [p for n, p in model.named_parameters() if "embed" in n]
        scalar_params = [p for p in model.parameters() if p.ndim < 2]
        head_params = [p for n, p in model.named_parameters() if "lm_head" in n]
        hidden_matrix_params = [p for n, p in trainable_named_params if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
        embed_params = [p for n, p in trainable_named_params if "embed" in n]
        scalar_params = [p for n, p in trainable_named_params if p.ndim < 2]
        head_params = [p for n, p in trainable_named_params if "lm_head" in n]

        # Create Adam groups with different learning rates
        adam_groups = [
@@ -447,7 +523,16 @@ def main():
    best_metric = float("inf") if not config.training.greater_is_better else -float("inf")

    if found_resumable_checkpoint:
        model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device)
        model, state = load_checkpoint(
            model_class,
            model_init_kwargs,
            optimizer,
            lr_scheduler,
            found_resumable_checkpoint,
            device,
            base_model_path=config.model.name,
            use_lora=config.model.use_lora,
        )
        global_step = state["global_step"]
        best_metric = state["best_metric"]
        samples_seen = state["samples_seen"]