From d6915b704464c03fb97784e4423df30e5c78839d Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 28 Oct 2025 22:40:38 +0000
Subject: [PATCH] Fixing lora

---
 olmocr/train/train.py | 101 ++++++++++++++++++++++++++++++++++++++----
 1 file changed, 93 insertions(+), 8 deletions(-)

diff --git a/olmocr/train/train.py b/olmocr/train/train.py
index 35150fd..f77542d 100644
--- a/olmocr/train/train.py
+++ b/olmocr/train/train.py
@@ -36,6 +36,46 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
+def prepare_lora_model(model: torch.nn.Module, model_cfg) -> torch.nn.Module:
+    """Wrap the model with a LoRA adapter according to the configuration."""
+    try:
+        from peft import LoraConfig, get_peft_model
+    except ImportError as exc:  # pragma: no cover - optional dependency guard
+        raise ImportError("LoRA training requires the `peft` package. Install it with `pip install peft`.") from exc
+
+    lora_kwargs = dict(
+        r=model_cfg.lora_rank,
+        lora_alpha=model_cfg.lora_alpha,
+        lora_dropout=model_cfg.lora_dropout,
+        target_modules=model_cfg.lora_target_modules,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    if model_cfg.lora_modules_to_save:
+        lora_kwargs["modules_to_save"] = model_cfg.lora_modules_to_save
+
+    lora_config = LoraConfig(**lora_kwargs)
+    model = get_peft_model(model, lora_config)
+
+    if hasattr(model, "config"):
+        model.config.base_model_name_or_path = model_cfg.name
+    base_model = getattr(model, "base_model", None)
+    if base_model is not None:
+        inner_model = getattr(base_model, "model", None)
+        if inner_model is not None and hasattr(inner_model, "config"):
+            inner_model.config._name_or_path = model_cfg.name
+
+    if hasattr(model, "print_trainable_parameters"):
+        model.print_trainable_parameters()
+
+    return model
+
+
+def is_lora_checkpoint(checkpoint_dir: str) -> bool:
+    """Detect whether a checkpoint directory contains LoRA adapter weights."""
+    return os.path.exists(os.path.join(checkpoint_dir, "adapter_config.json"))
+
+
 class QwenDataCollator:
     """Data collator for vision-language models that handles numpy arrays."""
 
@@ -140,9 +180,29 @@ def load_checkpoint(
     lr_scheduler: Any,
     checkpoint_dir: str,
     device: torch.device,
+    *,
+    base_model_path: Optional[str] = None,
+    use_lora: bool = False,
 ) -> tuple[torch.nn.Module, Dict[str, Any]]:
     """Load model, optimizer, scheduler, and training state from checkpoint."""
-    model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
+    checkpoint_has_lora = is_lora_checkpoint(checkpoint_dir)
+
+    if checkpoint_has_lora or use_lora:
+        if base_model_path is None:
+            raise ValueError("base_model_path must be provided when loading LoRA checkpoints.")
+
+        try:
+            from peft import PeftModel
+        except ImportError as exc:  # pragma: no cover - optional dependency guard
+            raise ImportError("Loading a LoRA checkpoint requires the `peft` package. Install it with `pip install peft`.") from exc
+
+        base_model = model_class.from_pretrained(base_model_path, **init_kwargs)
+        model = PeftModel.from_pretrained(base_model, checkpoint_dir, is_trainable=True)
+        if hasattr(model, "config"):
+            model.config.base_model_name_or_path = base_model_path
+    else:
+        model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
+
     model.to(device)
 
     optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device))
@@ -280,6 +340,15 @@ def main():
     else:
         raise NotImplementedError()
 
+    if config.model.use_lora:
+        logger.info("Applying LoRA adapters as specified in the config.")
+        model = prepare_lora_model(model, config.model)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    trainable_ratio = (trainable_params / total_params * 100) if total_params else 0.0
+    logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_ratio:.2f}%)")
+
     # Enable gradient checkpointing if configured
     if config.training.gradient_checkpointing:
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)
@@ -370,15 +439,19 @@ def main():
         logger.info("Model compilation complete")
 
     # Set up optimizer
+    trainable_named_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
+    if not trainable_named_params:
+        raise ValueError("No trainable parameters found. Check model fine-tuning configuration.")
+
     if config.training.optim == "adamw_torch":
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
             {
-                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "params": [p for n, p in trainable_named_params if not any(nd in n for nd in no_decay)],
                 "weight_decay": config.training.weight_decay,
            },
             {
-                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+                "params": [p for n, p in trainable_named_params if any(nd in n for nd in no_decay)],
                 "weight_decay": 0.0,
             },
         ]
@@ -389,11 +462,14 @@ def main():
             eps=float(config.training.adam_epsilon),
         )
     elif config.training.optim == "muon":
+        if config.model.use_lora:
+            raise NotImplementedError("LoRA training is not currently supported with the Muon optimizer in this loop.")
+
         # Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head)
-        hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
-        embed_params = [p for n, p in model.named_parameters() if "embed" in n]
-        scalar_params = [p for p in model.parameters() if p.ndim < 2]
-        head_params = [p for n, p in model.named_parameters() if "lm_head" in n]
+        hidden_matrix_params = [p for n, p in trainable_named_params if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
+        embed_params = [p for n, p in trainable_named_params if "embed" in n]
+        scalar_params = [p for n, p in trainable_named_params if p.ndim < 2]
+        head_params = [p for n, p in trainable_named_params if "lm_head" in n]
 
         # Create Adam groups with different learning rates
         adam_groups = [
@@ -447,7 +523,16 @@ def main():
     best_metric = float("inf") if not config.training.greater_is_better else -float("inf")
 
     if found_resumable_checkpoint:
-        model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device)
+        model, state = load_checkpoint(
+            model_class,
+            model_init_kwargs,
+            optimizer,
+            lr_scheduler,
+            found_resumable_checkpoint,
+            device,
+            base_model_path=config.model.name,
+            use_lora=config.model.use_lora,
+        )
         global_step = state["global_step"]
         best_metric = state["best_metric"]
         samples_seen = state["samples_seen"]
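
Reviewer note (not part of the commit): a minimal sketch of how an adapter directory written by this loop (any checkpoint containing adapter_config.json) could be loaded back onto the base model for inference. The base model class and all paths below are assumptions for illustration; train.py only ever sees them as `model_class` and `config.model.name`.

    from peft import PeftModel
    from transformers import Qwen2VLForConditionalGeneration  # assumed base model class

    # Hypothetical model name and checkpoint path, for illustration only.
    base = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    model = PeftModel.from_pretrained(base, "outputs/checkpoint-1000")  # dir with adapter_config.json
    model = model.merge_and_unload()  # optionally fold the LoRA weights back into the base model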