Mirror of https://github.com/allenai/olmocr.git, synced 2025-12-04 19:21:08 +00:00

Fixing lora

This commit is contained in:
parent 4768db63d2
commit d6915b7044
@@ -36,6 +36,46 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)


+def prepare_lora_model(model: torch.nn.Module, model_cfg) -> torch.nn.Module:
+    """Wrap the model with a LoRA adapter according to the configuration."""
+    try:
+        from peft import LoraConfig, get_peft_model
+    except ImportError as exc:  # pragma: no cover - optional dependency guard
+        raise ImportError("LoRA training requires the `peft` package. Install it with `pip install peft`.") from exc
+
+    lora_kwargs = dict(
+        r=model_cfg.lora_rank,
+        lora_alpha=model_cfg.lora_alpha,
+        lora_dropout=model_cfg.lora_dropout,
+        target_modules=model_cfg.lora_target_modules,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    if model_cfg.lora_modules_to_save:
+        lora_kwargs["modules_to_save"] = model_cfg.lora_modules_to_save
+
+    lora_config = LoraConfig(**lora_kwargs)
+    model = get_peft_model(model, lora_config)
+
+    if hasattr(model, "config"):
+        model.config.base_model_name_or_path = model_cfg.name
+    base_model = getattr(model, "base_model", None)
+    if base_model is not None:
+        inner_model = getattr(base_model, "model", None)
+        if inner_model is not None and hasattr(inner_model, "config"):
+            inner_model.config._name_or_path = model_cfg.name
+
+    if hasattr(model, "print_trainable_parameters"):
+        model.print_trainable_parameters()
+
+    return model
+
+
+def is_lora_checkpoint(checkpoint_dir: str) -> bool:
+    """Detect whether a checkpoint directory contains LoRA adapter weights."""
+    return os.path.exists(os.path.join(checkpoint_dir, "adapter_config.json"))
+
+
 class QwenDataCollator:
     """Data collator for vision-language models that handles numpy arrays."""

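A minimal usage sketch for prepare_lora_model (not part of the commit): the SimpleNamespace below stands in for the config object's model section, and the model name plus all lora_* values are illustrative assumptions, not values taken from olmocr's configs.

# Illustrative only: the config fields mirror the attributes prepare_lora_model reads.
from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM

model_cfg = SimpleNamespace(
    name="Qwen/Qwen2.5-0.5B",  # assumed base checkpoint; any causal LM works for the sketch
    lora_rank=16,
    lora_alpha=32,
    lora_dropout=0.05,
    lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_modules_to_save=None,  # falsy, so modules_to_save is left out of LoraConfig
)

model = AutoModelForCausalLM.from_pretrained(model_cfg.name, torch_dtype=torch.bfloat16)
model = prepare_lora_model(model, model_cfg)  # wraps the model with peft and prints the trainable summary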
@@ -140,9 +180,29 @@ def load_checkpoint(
     lr_scheduler: Any,
     checkpoint_dir: str,
     device: torch.device,
+    *,
+    base_model_path: Optional[str] = None,
+    use_lora: bool = False,
 ) -> tuple[torch.nn.Module, Dict[str, Any]]:
     """Load model, optimizer, scheduler, and training state from checkpoint."""
-    model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
+    checkpoint_has_lora = is_lora_checkpoint(checkpoint_dir)
+
+    if checkpoint_has_lora or use_lora:
+        if base_model_path is None:
+            raise ValueError("base_model_path must be provided when loading LoRA checkpoints.")
+
+        try:
+            from peft import PeftModel
+        except ImportError as exc:  # pragma: no cover - optional dependency guard
+            raise ImportError("Loading a LoRA checkpoint requires the `peft` package. Install it with `pip install peft`.") from exc
+
+        base_model = model_class.from_pretrained(base_model_path, **init_kwargs)
+        model = PeftModel.from_pretrained(base_model, checkpoint_dir, is_trainable=True)
+        if hasattr(model, "config"):
+            model.config.base_model_name_or_path = base_model_path
+    else:
+        model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
+
     model.to(device)

     optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device))
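The resume path above hinges on adapter_config.json, which peft writes next to the adapter weights when a PeftModel checkpoint is saved. A standalone sketch of the same detect-then-attach flow, with hypothetical paths and model names:

# Hypothetical paths; adapter_config.json is what is_lora_checkpoint() looks for.
from peft import PeftModel
from transformers import AutoModelForCausalLM

ckpt_dir = "checkpoints/step_1000"  # assumed adapter checkpoint directory
base_name = "Qwen/Qwen2.5-0.5B"     # assumed base model the adapters were trained on

if is_lora_checkpoint(ckpt_dir):
    base = AutoModelForCausalLM.from_pretrained(base_name)
    model = PeftModel.from_pretrained(base, ckpt_dir, is_trainable=True)  # keep adapters trainable to resume
else:
    model = AutoModelForCausalLM.from_pretrained(ckpt_dir)                # full (non-LoRA) checkpoint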
@@ -280,6 +340,15 @@ def main():
     else:
         raise NotImplementedError()

+    if config.model.use_lora:
+        logger.info("Applying LoRA adapters as specified in the config.")
+        model = prepare_lora_model(model, config.model)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    trainable_ratio = (trainable_params / total_params * 100) if total_params else 0.0
+    logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_ratio:.2f}%)")
+
     # Enable gradient checkpointing if configured
     if config.training.gradient_checkpointing:
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)
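The parameter summary above is plain arithmetic over requires_grad flags; a self-contained toy illustration (the numbers come from the toy module, not from any real run):

import torch

toy = torch.nn.Linear(10, 10)   # 100 weight + 10 bias parameters
toy.bias.requires_grad_(False)  # freeze the bias to mimic partial fine-tuning

total = sum(p.numel() for p in toy.parameters())
trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable:,} / {total:,} ({trainable / total * 100:.2f}%)")
# -> Trainable parameters: 100 / 110 (90.91%)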
@@ -370,15 +439,19 @@ def main():
         logger.info("Model compilation complete")

     # Set up optimizer
+    trainable_named_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
+    if not trainable_named_params:
+        raise ValueError("No trainable parameters found. Check model fine-tuning configuration.")
+
     if config.training.optim == "adamw_torch":
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
             {
-                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "params": [p for n, p in trainable_named_params if not any(nd in n for nd in no_decay)],
                 "weight_decay": config.training.weight_decay,
             },
             {
-                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+                "params": [p for n, p in trainable_named_params if any(nd in n for nd in no_decay)],
                 "weight_decay": 0.0,
             },
         ]
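A small sketch of how the no_decay name matching splits the trainable parameters into the two AdamW groups; the Block module and its attribute names are invented for illustration:

import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)
        self.LayerNorm = torch.nn.LayerNorm(8)  # attribute name chosen so "LayerNorm.weight" matches

no_decay = ["bias", "LayerNorm.weight"]
trainable_named_params = [(n, p) for n, p in Block().named_parameters() if p.requires_grad]

decay_names = [n for n, _ in trainable_named_params if not any(nd in n for nd in no_decay)]
no_decay_names = [n for n, _ in trainable_named_params if any(nd in n for nd in no_decay)]
print(decay_names)     # ['proj.weight']
print(no_decay_names)  # ['proj.bias', 'LayerNorm.weight', 'LayerNorm.bias']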
@@ -389,11 +462,14 @@ def main():
             eps=float(config.training.adam_epsilon),
         )
     elif config.training.optim == "muon":
+        if config.model.use_lora:
+            raise NotImplementedError("LoRA training is not currently supported with the Muon optimizer in this loop.")
+
         # Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head)
-        hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
-        embed_params = [p for n, p in model.named_parameters() if "embed" in n]
-        scalar_params = [p for p in model.parameters() if p.ndim < 2]
-        head_params = [p for n, p in model.named_parameters() if "lm_head" in n]
+        hidden_matrix_params = [p for n, p in trainable_named_params if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
+        embed_params = [p for n, p in trainable_named_params if "embed" in n]
+        scalar_params = [p for n, p in trainable_named_params if p.ndim < 2]
+        head_params = [p for n, p in trainable_named_params if "lm_head" in n]

         # Create Adam groups with different learning rates
         adam_groups = [
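The Muon branch partitions the same trainable_named_params list by tensor shape and name; an illustrative split on a tiny invented module (TinyLM is not the real model):

import torch

class TinyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(100, 16)
        self.mlp = torch.nn.Linear(16, 16)
        self.lm_head = torch.nn.Linear(16, 100, bias=False)

trainable_named_params = [(n, p) for n, p in TinyLM().named_parameters() if p.requires_grad]

hidden_matrix = [n for n, p in trainable_named_params if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
embed = [n for n, _ in trainable_named_params if "embed" in n]
scalar = [n for n, p in trainable_named_params if p.ndim < 2]
head = [n for n, _ in trainable_named_params if "lm_head" in n]
print(hidden_matrix)        # ['mlp.weight']  -> Muon
print(embed, scalar, head)  # ['embed_tokens.weight'] ['mlp.bias'] ['lm_head.weight']  -> Adam groups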
@@ -447,7 +523,16 @@ def main():
     best_metric = float("inf") if not config.training.greater_is_better else -float("inf")

     if found_resumable_checkpoint:
-        model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device)
+        model, state = load_checkpoint(
+            model_class,
+            model_init_kwargs,
+            optimizer,
+            lr_scheduler,
+            found_resumable_checkpoint,
+            device,
+            base_model_path=config.model.name,
+            use_lora=config.model.use_lora,
+        )
         global_step = state["global_step"]
         best_metric = state["best_metric"]
         samples_seen = state["samples_seen"]