Mirror of https://github.com/allenai/olmocr.git (synced 2025-12-05 11:41:19 +00:00)

Commit 3d2c977ac5: Merge branch 'jakep/lorafix'
@@ -28,7 +28,7 @@ dataset:

   train:
     - name: finetuning_data
-      root_dir: /root/test-berkshire-data
+      root_dir: /root/test-berkshire-data/train
       pipeline: &basic_pipeline
         - name: FrontMatterParser
           front_matter_class: PageResponse
@@ -50,7 +50,7 @@ dataset:

   eval:
     - name: eval_finetuning_data
-      root_dir: /root/test-berkshire-data
+      root_dir: /root/test-berkshire-data/test
       pipeline: *basic_pipeline

 # Training configuration

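The two config hunks above stop pointing train and eval at the same root_dir and instead expect separate train/ and test/ subdirectories under /root/test-berkshire-data. A minimal sketch of producing that layout from a flat directory; the split ratio, seed, and helper name are illustrative and not part of this commit:

import os
import random
import shutil


def split_dataset(root_dir: str, test_fraction: float = 0.1, seed: int = 0) -> None:
    # Move files from a flat root_dir into train/ and test/ subdirectories.
    files = sorted(f for f in os.listdir(root_dir) if os.path.isfile(os.path.join(root_dir, f)))
    random.Random(seed).shuffle(files)
    n_test = int(len(files) * test_fraction)
    splits = {"test": files[:n_test], "train": files[n_test:]}
    for split, names in splits.items():
        os.makedirs(os.path.join(root_dir, split), exist_ok=True)
        for name in names:
            shutil.move(os.path.join(root_dir, name), os.path.join(root_dir, split, name))


# split_dataset("/root/test-berkshire-data")  # afterwards train/ and test/ match the config above
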
@@ -36,12 +36,15 @@ import json
 import os
 import shutil
 import tempfile
+from typing import Optional

 import boto3
 import requests
 import torch
+from botocore.exceptions import ClientError
 from smart_open import smart_open
 from tqdm import tqdm
+from transformers import AutoConfig, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration

 try:
     from safetensors.torch import load_file, save_file
@@ -59,6 +62,12 @@ TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json
 # Supported model architectures
 SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]

+# Map architectures to corresponding model classes
+MODEL_CLASS_MAP = {
+    "Qwen2VLForConditionalGeneration": Qwen2VLForConditionalGeneration,
+    "Qwen2_5_VLForConditionalGeneration": Qwen2_5_VLForConditionalGeneration,
+}
+
 # Files to exclude from copying (training-related files)
 # Supports exact matches and glob patterns
 EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin", "*.pt", "*.pth"}
@@ -79,6 +88,42 @@ def is_s3_path(path: str) -> bool:
     return path.startswith("s3://")


+def join_path(base: str, *parts: str) -> str:
+    """Join paths for local and S3-style URIs."""
+    if not parts:
+        return base
+    if is_s3_path(base):
+        cleaned = [base.rstrip("/")]
+        for part in parts:
+            cleaned.append(part.strip("/"))
+        return "/".join(segment for segment in cleaned if segment)
+    return os.path.join(base, *parts)
+
+
+def load_json_if_exists(path: str) -> Optional[dict]:
+    """Load JSON from a path if it exists, otherwise return None."""
+    try:
+        with smart_open(path, "r") as handle:
+            return json.load(handle)
+    except FileNotFoundError:
+        return None
+    except ClientError as exc:
+        error_code = exc.response.get("Error", {}).get("Code")
+        if error_code in {"NoSuchKey", "404"}:
+            return None
+        raise
+    except OSError as exc:
+        if "No such file" in str(exc):
+            return None
+        raise
+
+
+def load_adapter_config(source_path: str) -> Optional[dict]:
+    """Return the LoRA adapter configuration if present for the given source."""
+    adapter_config_path = join_path(source_path, "adapter_config.json")
+    return load_json_if_exists(adapter_config_path)
+
+
 def download_file_from_hf(filename: str, destination_dir: str, hf_base_url: str) -> None:
     """Download a file from Hugging Face model repository."""
     url = f"{hf_base_url}/{filename}"
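join_path, load_json_if_exists, and load_adapter_config above give the script one code path for local directories and s3:// URIs: the adapter config is read through smart_open, and a missing file simply returns None. A short usage sketch; the paths are hypothetical:

# Hypothetical checkpoint locations.
print(join_path("/checkpoints/step-500", "config.json"))
# -> /checkpoints/step-500/config.json
print(join_path("s3://my-bucket/checkpoints/step-500/", "config.json"))
# -> s3://my-bucket/checkpoints/step-500/config.json

adapter_config = load_adapter_config("s3://my-bucket/checkpoints/step-500")
if adapter_config is None:
    print("full model checkpoint")
else:
    print("LoRA adapter for", adapter_config.get("base_model_name_or_path"))
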
@@ -254,6 +299,61 @@ def get_weight_files(dir_path: str) -> list[str]:
     return weight_files


+def merge_lora_adapter_checkpoint(adapter_dir: str, base_model_name: str, output_dir: str) -> str:
+    """Merge a LoRA adapter into its base model and save the merged weights.
+
+    Returns:
+        The detected architecture string of the merged model.
+    """
+    try:
+        from peft import PeftModel
+    except ImportError as exc:  # pragma: no cover - optional dependency guard
+        raise ImportError("Merging LoRA adapters requires the `peft` package. Install it with `pip install peft`.") from exc
+
+    print(f"Merging LoRA adapter from {adapter_dir} with base model '{base_model_name}'...")
+    base_config = AutoConfig.from_pretrained(base_model_name, trust_remote_code=True)
+
+    architecture = None
+    for arch in base_config.architectures or []:
+        if arch in MODEL_CLASS_MAP:
+            architecture = arch
+            break
+
+    if architecture is None:
+        raise ValueError(
+            f"Base model '{base_model_name}' uses an unsupported architecture: {base_config.architectures}. "
+            f"Supported architectures: {SUPPORTED_ARCHITECTURES}"
+        )
+
+    model_class = MODEL_CLASS_MAP[architecture]
+    base_model = model_class.from_pretrained(
+        base_model_name,
+        trust_remote_code=True,
+        torch_dtype="auto",
+    )
+
+    lora_model = PeftModel.from_pretrained(base_model, adapter_dir, is_trainable=False)
+    merged_model = lora_model.merge_and_unload()
+    merged_model = merged_model.to("cpu")
+
+    if hasattr(merged_model, "config"):
+        merged_model.config._name_or_path = base_model_name
+        merged_model.config.base_model_name_or_path = base_model_name
+
+    os.makedirs(output_dir, exist_ok=True)
+    merged_model.save_pretrained(output_dir)
+    print(f"✓ Saved merged model to {output_dir}")
+
+    # Explicit cleanup
+    del merged_model
+    del lora_model
+    del base_model
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return architecture
+
+
 def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
     """Prepare OlmOCR checkpoint(s) for deployment, with support for souping."""
     print(f"Preparing {'souped ' if len(sources) > 1 else ''}checkpoint from {len(sources)} source(s) to {dest_path}")
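merge_lora_adapter_checkpoint resolves the base model's architecture from its config, loads it with the matching class, attaches the adapter with peft, folds the LoRA weights in via merge_and_unload(), and writes a plain checkpoint to output_dir. A hedged sketch of calling it directly; the directories and base model name are placeholders, and peft must be installed:

# Placeholder paths; any directory holding adapter_config.json and adapter weights works the same way.
architecture = merge_lora_adapter_checkpoint(
    adapter_dir="/tmp/my-lora-adapter",
    base_model_name="Qwen/Qwen2.5-VL-7B-Instruct",
    output_dir="/tmp/my-merged-model",
)
print(architecture)  # e.g. "Qwen2_5_VLForConditionalGeneration"
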
@@ -261,43 +361,69 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
     sources = [source.rstrip("/") for source in sources]
     dest_path = dest_path.rstrip("/")

-    # Detect architectures
-    architectures = []
+    source_infos = []
     for source in sources:
-        config_path = f"{source}/config.json" if is_s3_path(source) else os.path.join(source, "config.json")
+        adapter_config = load_adapter_config(source)
+        if adapter_config is not None:
+            source_infos.append({"path": source, "is_lora": True, "adapter_config": adapter_config})
+        else:
+            config_path = join_path(source, "config.json")
             arch = detect_checkpoint_architecture(config_path)
-        architectures.append(arch)
+            source_infos.append({"path": source, "is_lora": False, "architecture": arch})

-    # Check all same
+    num_lora_sources = sum(1 for info in source_infos if info["is_lora"])
+    final_architecture: Optional[str] = None
+
+    if num_lora_sources > 0:
+        if len(source_infos) > 1:
+            raise ValueError("LoRA adapter checkpoints can only be processed individually, not during souping.")
+
+        source_info = source_infos[0]
+        source_path = source_info["path"]
+        adapter_config = source_info["adapter_config"]
+        base_model_name = adapter_config.get("base_model_name_or_path")
+        if not base_model_name:
+            raise ValueError("adapter_config.json is missing 'base_model_name_or_path'; cannot merge LoRA adapter.")
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            adapter_local_dir = os.path.join(temp_dir, "adapter")
+            print("\nDownloading LoRA adapter locally for merging...")
+            if is_s3_path(source_path):
+                bucket, prefix = parse_s3_path(source_path)
+                copy_s3_to_local(bucket, prefix, adapter_local_dir)
+            else:
+                copy_local_to_local(source_path, adapter_local_dir)
+
+            merged_dir = os.path.join(temp_dir, "merged")
+            final_architecture = merge_lora_adapter_checkpoint(adapter_local_dir, base_model_name, merged_dir)
+
+            print("\nCopying merged model files to destination...")
+            if is_s3_path(dest_path):
+                dest_bucket, dest_prefix = parse_s3_path(dest_path)
+                copy_local_to_s3(merged_dir, dest_bucket, dest_prefix)
+            else:
+                copy_local_to_local(merged_dir, dest_path)
+    else:
+        architectures = [info["architecture"] for info in source_infos]
         if len(set(architectures)) > 1:
             raise ValueError("All sources must have the same architecture")

-    architecture = architectures[0]
+        final_architecture = architectures[0]

-    # Get the appropriate HF model ID and base URL
-    hf_model_id = HF_MODEL_IDS[architecture]
-    hf_base_url = f"https://huggingface.co/{hf_model_id}/resolve/main"
-    print(f"Using HuggingFace model: {hf_model_id}")
-
         if len(sources) == 1:
             source_path = sources[0]
-            # Single checkpoint: copy as before
             print("\nCopying model files...")
             if is_s3_path(source_path) and is_s3_path(dest_path):
-                # S3 to S3
                 source_bucket, source_prefix = parse_s3_path(source_path)
                 dest_bucket, dest_prefix = parse_s3_path(dest_path)
                 copy_s3_to_s3(source_bucket, source_prefix, dest_bucket, dest_prefix)
             elif is_s3_path(source_path) and not is_s3_path(dest_path):
-                # S3 to local
                 source_bucket, source_prefix = parse_s3_path(source_path)
                 copy_s3_to_local(source_bucket, source_prefix, dest_path)
             elif not is_s3_path(source_path) and is_s3_path(dest_path):
-                # Local to S3
                 dest_bucket, dest_prefix = parse_s3_path(dest_path)
                 copy_local_to_s3(source_path, dest_bucket, dest_prefix)
             else:
-                # Local to local
                 copy_local_to_local(source_path, dest_path)
         else:
             # Souping multiple checkpoints
@@ -402,6 +528,13 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
             else:
                 copy_local_to_local(souped_dir, dest_path)

+    if final_architecture is None:
+        raise ValueError("Unable to determine the architecture of the prepared checkpoint.")
+
+    hf_model_id = HF_MODEL_IDS[final_architecture]
+    hf_base_url = f"https://huggingface.co/{hf_model_id}/resolve/main"
+    print(f"Using HuggingFace model: {hf_model_id}")
+
     # Download tokenizer files from Hugging Face
     print("\nDownloading tokenizer files from Hugging Face...")

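With these changes prepare_checkpoints merges a single LoRA adapter source into its base model before the shared tokenizer-download step, while souping is still restricted to full checkpoints with matching architectures. A sketch of the two call shapes; bucket names and paths are hypothetical:

# One LoRA adapter: downloaded, merged into its base model, written to the destination.
prepare_checkpoints(["s3://my-bucket/checkpoints/step-500"], "s3://my-bucket/prepared/step-500")

# Souping: several full (non-LoRA) checkpoints combined into one destination.
prepare_checkpoints(["/ckpts/run-a/step-900", "/ckpts/run-b/step-900"], "/ckpts/souped")
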
@@ -36,6 +36,46 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)


+def prepare_lora_model(model: torch.nn.Module, model_cfg) -> torch.nn.Module:
+    """Wrap the model with a LoRA adapter according to the configuration."""
+    try:
+        from peft import LoraConfig, get_peft_model
+    except ImportError as exc:  # pragma: no cover - optional dependency guard
+        raise ImportError("LoRA training requires the `peft` package. Install it with `pip install peft`.") from exc
+
+    lora_kwargs = dict(
+        r=model_cfg.lora_rank,
+        lora_alpha=model_cfg.lora_alpha,
+        lora_dropout=model_cfg.lora_dropout,
+        target_modules=model_cfg.lora_target_modules,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    if model_cfg.lora_modules_to_save:
+        lora_kwargs["modules_to_save"] = model_cfg.lora_modules_to_save
+
+    lora_config = LoraConfig(**lora_kwargs)
+    model = get_peft_model(model, lora_config)
+
+    if hasattr(model, "config"):
+        model.config.base_model_name_or_path = model_cfg.name
+    base_model = getattr(model, "base_model", None)
+    if base_model is not None:
+        inner_model = getattr(base_model, "model", None)
+        if inner_model is not None and hasattr(inner_model, "config"):
+            inner_model.config._name_or_path = model_cfg.name
+
+    if hasattr(model, "print_trainable_parameters"):
+        model.print_trainable_parameters()
+
+    return model
+
+
+def is_lora_checkpoint(checkpoint_dir: str) -> bool:
+    """Detect whether a checkpoint directory contains LoRA adapter weights."""
+    return os.path.exists(os.path.join(checkpoint_dir, "adapter_config.json"))
+
+
 class QwenDataCollator:
     """Data collator for vision-language models that handles numpy arrays."""

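prepare_lora_model only reads a handful of attributes from the model config object: lora_rank, lora_alpha, lora_dropout, lora_target_modules, lora_modules_to_save, and name. A minimal stand-in config that would satisfy it; the field values are illustrative and not taken from this commit:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class LoraModelConfig:
    # Only the fields the new LoRA helpers access; values are examples.
    name: str = "Qwen/Qwen2.5-VL-7B-Instruct"  # base model id, placeholder
    use_lora: bool = True
    lora_rank: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: List[str] = field(default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"])
    lora_modules_to_save: Optional[List[str]] = None


# model = prepare_lora_model(model, LoraModelConfig())
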
@@ -140,9 +180,29 @@ def load_checkpoint(
     lr_scheduler: Any,
     checkpoint_dir: str,
     device: torch.device,
+    *,
+    base_model_path: Optional[str] = None,
+    use_lora: bool = False,
 ) -> tuple[torch.nn.Module, Dict[str, Any]]:
     """Load model, optimizer, scheduler, and training state from checkpoint."""
+    checkpoint_has_lora = is_lora_checkpoint(checkpoint_dir)
+
+    if checkpoint_has_lora or use_lora:
+        if base_model_path is None:
+            raise ValueError("base_model_path must be provided when loading LoRA checkpoints.")
+
+        try:
+            from peft import PeftModel
+        except ImportError as exc:  # pragma: no cover - optional dependency guard
+            raise ImportError("Loading a LoRA checkpoint requires the `peft` package. Install it with `pip install peft`.") from exc
+
+        base_model = model_class.from_pretrained(base_model_path, **init_kwargs)
+        model = PeftModel.from_pretrained(base_model, checkpoint_dir, is_trainable=True)
+        if hasattr(model, "config"):
+            model.config.base_model_name_or_path = base_model_path
+    else:
         model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
+
     model.to(device)

     optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device))
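load_checkpoint now branches on is_lora_checkpoint: a directory containing adapter_config.json is reloaded by rebuilding the base model and attaching the adapter with PeftModel.from_pretrained(..., is_trainable=True), which is why base_model_path becomes mandatory for LoRA runs. A small sketch of the detection; the directories are placeholders:

# Placeholder checkpoint directories illustrating the two branches.
full_ckpt = "/ckpts/run/checkpoint-500"       # config.json + full weights
lora_ckpt = "/ckpts/lora-run/checkpoint-500"  # adapter_config.json + adapter weights

print(is_lora_checkpoint(full_ckpt))  # False -> model_class.from_pretrained(checkpoint_dir, ...)
print(is_lora_checkpoint(lora_ckpt))  # True  -> base model + PeftModel.from_pretrained(..., is_trainable=True)
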
@@ -280,6 +340,15 @@ def main():
     else:
         raise NotImplementedError()

+    if config.model.use_lora:
+        logger.info("Applying LoRA adapters as specified in the config.")
+        model = prepare_lora_model(model, config.model)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    trainable_ratio = (trainable_params / total_params * 100) if total_params else 0.0
+    logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_ratio:.2f}%)")
+
     # Enable gradient checkpointing if configured
     if config.training.gradient_checkpointing:
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)
@@ -370,15 +439,19 @@ def main():
         logger.info("Model compilation complete")

     # Set up optimizer
+    trainable_named_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
+    if not trainable_named_params:
+        raise ValueError("No trainable parameters found. Check model fine-tuning configuration.")
+
     if config.training.optim == "adamw_torch":
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
             {
-                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "params": [p for n, p in trainable_named_params if not any(nd in n for nd in no_decay)],
                 "weight_decay": config.training.weight_decay,
             },
             {
-                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+                "params": [p for n, p in trainable_named_params if any(nd in n for nd in no_decay)],
                 "weight_decay": 0.0,
             },
         ]
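Filtering the optimizer groups down to parameters with requires_grad=True keeps the frozen base-model weights out of AdamW's state, which is the point of LoRA fine-tuning. A self-contained sketch of the same grouping pattern on a toy model (no peft required):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
for p in model[0].parameters():  # freeze the first layer, as LoRA freezes the base model
    p.requires_grad = False

trainable_named_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
no_decay = ["bias", "LayerNorm.weight"]
groups = [
    {"params": [p for n, p in trainable_named_params if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
    {"params": [p for n, p in trainable_named_params if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(groups, lr=1e-4)
print(sum(p.numel() for g in groups for p in g["params"]))  # 18: only the unfrozen layer's parameters
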
@@ -389,11 +462,14 @@ def main():
             eps=float(config.training.adam_epsilon),
         )
     elif config.training.optim == "muon":
+        if config.model.use_lora:
+            raise NotImplementedError("LoRA training is not currently supported with the Muon optimizer in this loop.")
+
         # Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head)
-        hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
-        embed_params = [p for n, p in model.named_parameters() if "embed" in n]
-        scalar_params = [p for p in model.parameters() if p.ndim < 2]
-        head_params = [p for n, p in model.named_parameters() if "lm_head" in n]
+        hidden_matrix_params = [p for n, p in trainable_named_params if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
+        embed_params = [p for n, p in trainable_named_params if "embed" in n]
+        scalar_params = [p for n, p in trainable_named_params if p.ndim < 2]
+        head_params = [p for n, p in trainable_named_params if "lm_head" in n]

         # Create Adam groups with different learning rates
         adam_groups = [
@@ -447,7 +523,16 @@ def main():
     best_metric = float("inf") if not config.training.greater_is_better else -float("inf")

     if found_resumable_checkpoint:
-        model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device)
+        model, state = load_checkpoint(
+            model_class,
+            model_init_kwargs,
+            optimizer,
+            lr_scheduler,
+            found_resumable_checkpoint,
+            device,
+            base_model_path=config.model.name,
+            use_lora=config.model.use_lora,
+        )
         global_step = state["global_step"]
         best_metric = state["best_metric"]
         samples_seen = state["samples_seen"]