mirror of https://github.com/allenai/olmocr.git (synced 2025-12-04 19:21:08 +00:00)

Merge branch 'jakep/lorafix'
commit 3d2c977ac5
@@ -28,7 +28,7 @@ dataset:

train:
- name: finetuning_data
root_dir: /root/test-berkshire-data
root_dir: /root/test-berkshire-data/train
pipeline: &basic_pipeline
- name: FrontMatterParser
front_matter_class: PageResponse

@@ -50,7 +50,7 @@ dataset:

eval:
- name: eval_finetuning_data
root_dir: /root/test-berkshire-data
root_dir: /root/test-berkshire-data/test
pipeline: *basic_pipeline

# Training configuration
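The train and eval entries now point at separate train/ and test/ subdirectories of the dataset root instead of sharing one directory. A minimal sketch of checking that the layout the updated config expects is in place (the check itself is illustrative, not part of this commit):

import os

root = "/root/test-berkshire-data"
for split in ("train", "test"):
    split_dir = os.path.join(root, split)
    assert os.path.isdir(split_dir), f"expected split directory missing: {split_dir}"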
@@ -36,12 +36,15 @@ import json
import os
import shutil
import tempfile
from typing import Optional

import boto3
import requests
import torch
from botocore.exceptions import ClientError
from smart_open import smart_open
from tqdm import tqdm
from transformers import AutoConfig, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration

try:
from safetensors.torch import load_file, save_file

@@ -59,6 +62,12 @@ TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json
# Supported model architectures
SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]

# Map architectures to corresponding model classes
MODEL_CLASS_MAP = {
"Qwen2VLForConditionalGeneration": Qwen2VLForConditionalGeneration,
"Qwen2_5_VLForConditionalGeneration": Qwen2_5_VLForConditionalGeneration,
}

# Files to exclude from copying (training-related files)
# Supports exact matches and glob patterns
EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin", "*.pt", "*.pth"}
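EXCLUDED_FILES mixes literal names with glob patterns, so a simple membership test is not enough. A hedged sketch of how such patterns could be applied with fnmatch (the helper name is hypothetical; the commit's actual matching code is not shown in this hunk):

import fnmatch

def is_excluded(filename: str) -> bool:
    # True for exact entries ("optimizer.pt") and for glob patterns ("*.pt", "*.pth").
    return any(fnmatch.fnmatch(filename, pattern) for pattern in EXCLUDED_FILES)

# is_excluded("rng_state.pth") -> True; is_excluded("model.safetensors") -> False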
@@ -79,6 +88,42 @@ def is_s3_path(path: str) -> bool:
return path.startswith("s3://")


def join_path(base: str, *parts: str) -> str:
"""Join paths for local and S3-style URIs."""
if not parts:
return base
if is_s3_path(base):
cleaned = [base.rstrip("/")]
for part in parts:
cleaned.append(part.strip("/"))
return "/".join(segment for segment in cleaned if segment)
return os.path.join(base, *parts)


def load_json_if_exists(path: str) -> Optional[dict]:
"""Load JSON from a path if it exists, otherwise return None."""
try:
with smart_open(path, "r") as handle:
return json.load(handle)
except FileNotFoundError:
return None
except ClientError as exc:
error_code = exc.response.get("Error", {}).get("Code")
if error_code in {"NoSuchKey", "404"}:
return None
raise
except OSError as exc:
if "No such file" in str(exc):
return None
raise


def load_adapter_config(source_path: str) -> Optional[dict]:
"""Return the LoRA adapter configuration if present for the given source."""
adapter_config_path = join_path(source_path, "adapter_config.json")
return load_json_if_exists(adapter_config_path)


def download_file_from_hf(filename: str, destination_dir: str, hf_base_url: str) -> None:
"""Download a file from Hugging Face model repository."""
url = f"{hf_base_url}/{filename}"
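join_path keeps S3 URIs intact rather than handing them to os.path.join, which only applies to local paths. Illustrative calls (paths are made up):

join_path("s3://my-bucket/checkpoints/step-500/", "adapter_config.json")
# -> "s3://my-bucket/checkpoints/step-500/adapter_config.json"

join_path("/data/checkpoints/step-500", "adapter_config.json")
# -> "/data/checkpoints/step-500/adapter_config.json"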
@@ -254,6 +299,61 @@ def get_weight_files(dir_path: str) -> list[str]:
return weight_files


def merge_lora_adapter_checkpoint(adapter_dir: str, base_model_name: str, output_dir: str) -> str:
"""Merge a LoRA adapter into its base model and save the merged weights.

Returns:
The detected architecture string of the merged model.
"""
try:
from peft import PeftModel
except ImportError as exc:  # pragma: no cover - optional dependency guard
raise ImportError("Merging LoRA adapters requires the `peft` package. Install it with `pip install peft`.") from exc

print(f"Merging LoRA adapter from {adapter_dir} with base model '{base_model_name}'...")
base_config = AutoConfig.from_pretrained(base_model_name, trust_remote_code=True)

architecture = None
for arch in base_config.architectures or []:
if arch in MODEL_CLASS_MAP:
architecture = arch
break

if architecture is None:
raise ValueError(
f"Base model '{base_model_name}' uses an unsupported architecture: {base_config.architectures}. "
f"Supported architectures: {SUPPORTED_ARCHITECTURES}"
)

model_class = MODEL_CLASS_MAP[architecture]
base_model = model_class.from_pretrained(
base_model_name,
trust_remote_code=True,
torch_dtype="auto",
)

lora_model = PeftModel.from_pretrained(base_model, adapter_dir, is_trainable=False)
merged_model = lora_model.merge_and_unload()
merged_model = merged_model.to("cpu")

if hasattr(merged_model, "config"):
merged_model.config._name_or_path = base_model_name
merged_model.config.base_model_name_or_path = base_model_name

os.makedirs(output_dir, exist_ok=True)
merged_model.save_pretrained(output_dir)
print(f"✓ Saved merged model to {output_dir}")

# Explicit cleanup
del merged_model
del lora_model
del base_model
if torch.cuda.is_available():
torch.cuda.empty_cache()

return architecture


def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
"""Prepare OlmOCR checkpoint(s) for deployment, with support for souping."""
print(f"Preparing {'souped ' if len(sources) > 1 else ''}checkpoint from {len(sources)} source(s) to {dest_path}")
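A hedged sketch of invoking the merge helper on its own; the adapter path, base model name, and output directory are illustrative, and in prepare_checkpoints the base model name is read from adapter_config.json rather than passed by hand:

architecture = merge_lora_adapter_checkpoint(
    adapter_dir="/tmp/lora-adapter",                # contains adapter_config.json and adapter weights
    base_model_name="Qwen/Qwen2.5-VL-7B-Instruct",  # illustrative; normally base_model_name_or_path
    output_dir="/tmp/merged-checkpoint",
)
print(architecture)  # e.g. "Qwen2_5_VLForConditionalGeneration"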
@@ -261,146 +361,179 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
sources = [source.rstrip("/") for source in sources]
dest_path = dest_path.rstrip("/")

# Detect architectures
architectures = []
source_infos = []
for source in sources:
config_path = f"{source}/config.json" if is_s3_path(source) else os.path.join(source, "config.json")
arch = detect_checkpoint_architecture(config_path)
architectures.append(arch)

# Check all same
if len(set(architectures)) > 1:
raise ValueError("All sources must have the same architecture")

architecture = architectures[0]

# Get the appropriate HF model ID and base URL
hf_model_id = HF_MODEL_IDS[architecture]
hf_base_url = f"https://huggingface.co/{hf_model_id}/resolve/main"
print(f"Using HuggingFace model: {hf_model_id}")

if len(sources) == 1:
source_path = sources[0]
# Single checkpoint: copy as before
print("\nCopying model files...")
if is_s3_path(source_path) and is_s3_path(dest_path):
# S3 to S3
source_bucket, source_prefix = parse_s3_path(source_path)
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_s3_to_s3(source_bucket, source_prefix, dest_bucket, dest_prefix)
elif is_s3_path(source_path) and not is_s3_path(dest_path):
# S3 to local
source_bucket, source_prefix = parse_s3_path(source_path)
copy_s3_to_local(source_bucket, source_prefix, dest_path)
elif not is_s3_path(source_path) and is_s3_path(dest_path):
# Local to S3
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_local_to_s3(source_path, dest_bucket, dest_prefix)
adapter_config = load_adapter_config(source)
if adapter_config is not None:
source_infos.append({"path": source, "is_lora": True, "adapter_config": adapter_config})
else:
# Local to local
copy_local_to_local(source_path, dest_path)
else:
# Souping multiple checkpoints
config_path = join_path(source, "config.json")
arch = detect_checkpoint_architecture(config_path)
source_infos.append({"path": source, "is_lora": False, "architecture": arch})

num_lora_sources = sum(1 for info in source_infos if info["is_lora"])
final_architecture: Optional[str] = None

if num_lora_sources > 0:
if len(source_infos) > 1:
raise ValueError("LoRA adapter checkpoints can only be processed individually, not during souping.")

source_info = source_infos[0]
source_path = source_info["path"]
adapter_config = source_info["adapter_config"]
base_model_name = adapter_config.get("base_model_name_or_path")
if not base_model_name:
raise ValueError("adapter_config.json is missing 'base_model_name_or_path'; cannot merge LoRA adapter.")

with tempfile.TemporaryDirectory() as temp_dir:
# Download all sources to local temp dirs
source_temps = []
for i, source in enumerate(sources):
source_temp = os.path.join(temp_dir, f"source_{i}")
if is_s3_path(source):
bucket, prefix = parse_s3_path(source)
copy_s3_to_local(bucket, prefix, source_temp)
else:
copy_local_to_local(source, source_temp)
source_temps.append(source_temp)
adapter_local_dir = os.path.join(temp_dir, "adapter")
print("\nDownloading LoRA adapter locally for merging...")
if is_s3_path(source_path):
bucket, prefix = parse_s3_path(source_path)
copy_s3_to_local(bucket, prefix, adapter_local_dir)
else:
copy_local_to_local(source_path, adapter_local_dir)

first_source = source_temps[0]
merged_dir = os.path.join(temp_dir, "merged")
final_architecture = merge_lora_adapter_checkpoint(adapter_local_dir, base_model_name, merged_dir)

# Get weight files
weight_full_paths = get_weight_files(first_source)
weight_rel_paths = [os.path.relpath(p, first_source) for p in weight_full_paths]

# Verify others have same weight files
for i in range(1, len(sources)):
other_dir = source_temps[i]
other_weights = [os.path.relpath(p, other_dir) for p in get_weight_files(other_dir)]
if set(other_weights) != set(weight_rel_paths):
raise ValueError(f"Source {sources[i]} has different weight files")

# Create souped_dir
souped_dir = os.path.join(temp_dir, "souped")
# Copy first source (including its weights, which will be overwritten)
copy_local_to_local(first_source, souped_dir)

# Average weights
for rel_path in tqdm(weight_rel_paths, desc="Averaging weight files"):
all_paths = [os.path.join(st, rel_path) for st in source_temps]
file_path = all_paths[0]
souped_path = os.path.join(souped_dir, rel_path)
os.makedirs(os.path.dirname(souped_path), exist_ok=True)

if file_path.endswith(".safetensors"):
sum_state = load_file(file_path, device="cpu")
# Store original dtypes for each tensor
original_dtypes = {k: v.dtype for k, v in sum_state.items()}
# Upconvert to at least fp32 for accurate averaging
for k in sum_state:
if sum_state[k].dtype in (torch.float16, torch.bfloat16):
sum_state[k] = sum_state[k].to(torch.float32)

for other_path in all_paths[1:]:
other_state = load_file(other_path, device="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
# Upconvert other state to match sum_state dtype
if other_state[k].dtype in (torch.float16, torch.bfloat16):
other_state[k] = other_state[k].to(torch.float32)
sum_state[k] += other_state[k]
del other_state

n = len(all_paths)
for k in sum_state:
sum_state[k] /= n
# Cast back to original dtype
sum_state[k] = sum_state[k].to(original_dtypes[k])
save_file(sum_state, souped_path)
elif file_path.endswith(".bin"):
sum_state = torch.load(file_path, map_location="cpu")
# Store original dtypes for each tensor
original_dtypes = {k: v.dtype for k, v in sum_state.items()}
# Upconvert to at least fp32 for accurate averaging
for k in sum_state:
if sum_state[k].dtype in (torch.float16, torch.bfloat16):
sum_state[k] = sum_state[k].to(torch.float32)

for other_path in all_paths[1:]:
other_state = torch.load(other_path, map_location="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
# Upconvert other state to match sum_state dtype
if other_state[k].dtype in (torch.float16, torch.bfloat16):
other_state[k] = other_state[k].to(torch.float32)
sum_state[k] += other_state[k]
del other_state

n = len(all_paths)
for k in sum_state:
sum_state[k] /= n
# Cast back to original dtype
sum_state[k] = sum_state[k].to(original_dtypes[k])
torch.save(sum_state, souped_path)
else:
print(f"Skipping unknown weight file: {rel_path}")
continue

# Now copy souped_dir to dest_path
print("\nCopying souped model files to destination...")
print("\nCopying merged model files to destination...")
if is_s3_path(dest_path):
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_local_to_s3(souped_dir, dest_bucket, dest_prefix)
copy_local_to_s3(merged_dir, dest_bucket, dest_prefix)
else:
copy_local_to_local(souped_dir, dest_path)
copy_local_to_local(merged_dir, dest_path)
else:
architectures = [info["architecture"] for info in source_infos]
if len(set(architectures)) > 1:
raise ValueError("All sources must have the same architecture")

final_architecture = architectures[0]

if len(sources) == 1:
source_path = sources[0]
print("\nCopying model files...")
if is_s3_path(source_path) and is_s3_path(dest_path):
source_bucket, source_prefix = parse_s3_path(source_path)
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_s3_to_s3(source_bucket, source_prefix, dest_bucket, dest_prefix)
elif is_s3_path(source_path) and not is_s3_path(dest_path):
source_bucket, source_prefix = parse_s3_path(source_path)
copy_s3_to_local(source_bucket, source_prefix, dest_path)
elif not is_s3_path(source_path) and is_s3_path(dest_path):
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_local_to_s3(source_path, dest_bucket, dest_prefix)
else:
copy_local_to_local(source_path, dest_path)
else:
# Souping multiple checkpoints
with tempfile.TemporaryDirectory() as temp_dir:
# Download all sources to local temp dirs
source_temps = []
for i, source in enumerate(sources):
source_temp = os.path.join(temp_dir, f"source_{i}")
if is_s3_path(source):
bucket, prefix = parse_s3_path(source)
copy_s3_to_local(bucket, prefix, source_temp)
else:
copy_local_to_local(source, source_temp)
source_temps.append(source_temp)

first_source = source_temps[0]

# Get weight files
weight_full_paths = get_weight_files(first_source)
weight_rel_paths = [os.path.relpath(p, first_source) for p in weight_full_paths]

# Verify others have same weight files
for i in range(1, len(sources)):
other_dir = source_temps[i]
other_weights = [os.path.relpath(p, other_dir) for p in get_weight_files(other_dir)]
if set(other_weights) != set(weight_rel_paths):
raise ValueError(f"Source {sources[i]} has different weight files")

# Create souped_dir
souped_dir = os.path.join(temp_dir, "souped")
# Copy first source (including its weights, which will be overwritten)
copy_local_to_local(first_source, souped_dir)

# Average weights
for rel_path in tqdm(weight_rel_paths, desc="Averaging weight files"):
all_paths = [os.path.join(st, rel_path) for st in source_temps]
file_path = all_paths[0]
souped_path = os.path.join(souped_dir, rel_path)
os.makedirs(os.path.dirname(souped_path), exist_ok=True)

if file_path.endswith(".safetensors"):
sum_state = load_file(file_path, device="cpu")
# Store original dtypes for each tensor
original_dtypes = {k: v.dtype for k, v in sum_state.items()}
# Upconvert to at least fp32 for accurate averaging
for k in sum_state:
if sum_state[k].dtype in (torch.float16, torch.bfloat16):
sum_state[k] = sum_state[k].to(torch.float32)

for other_path in all_paths[1:]:
other_state = load_file(other_path, device="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
# Upconvert other state to match sum_state dtype
if other_state[k].dtype in (torch.float16, torch.bfloat16):
other_state[k] = other_state[k].to(torch.float32)
sum_state[k] += other_state[k]
del other_state

n = len(all_paths)
for k in sum_state:
sum_state[k] /= n
# Cast back to original dtype
sum_state[k] = sum_state[k].to(original_dtypes[k])
save_file(sum_state, souped_path)
elif file_path.endswith(".bin"):
sum_state = torch.load(file_path, map_location="cpu")
# Store original dtypes for each tensor
original_dtypes = {k: v.dtype for k, v in sum_state.items()}
# Upconvert to at least fp32 for accurate averaging
for k in sum_state:
if sum_state[k].dtype in (torch.float16, torch.bfloat16):
sum_state[k] = sum_state[k].to(torch.float32)

for other_path in all_paths[1:]:
other_state = torch.load(other_path, map_location="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
# Upconvert other state to match sum_state dtype
if other_state[k].dtype in (torch.float16, torch.bfloat16):
other_state[k] = other_state[k].to(torch.float32)
sum_state[k] += other_state[k]
del other_state

n = len(all_paths)
for k in sum_state:
sum_state[k] /= n
# Cast back to original dtype
sum_state[k] = sum_state[k].to(original_dtypes[k])
torch.save(sum_state, souped_path)
else:
print(f"Skipping unknown weight file: {rel_path}")
continue

# Now copy souped_dir to dest_path
print("\nCopying souped model files to destination...")
if is_s3_path(dest_path):
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_local_to_s3(souped_dir, dest_bucket, dest_prefix)
else:
copy_local_to_local(souped_dir, dest_path)

if final_architecture is None:
raise ValueError("Unable to determine the architecture of the prepared checkpoint.")

hf_model_id = HF_MODEL_IDS[final_architecture]
hf_base_url = f"https://huggingface.co/{hf_model_id}/resolve/main"
print(f"Using HuggingFace model: {hf_model_id}")

# Download tokenizer files from Hugging Face
print("\nDownloading tokenizer files from Hugging Face...")
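Taken together, a hedged sketch of the two paths through the updated prepare_checkpoints (paths are illustrative; the signature is the one shown above):

# A single LoRA adapter checkpoint is detected via adapter_config.json, merged into
# its base model, and the merged weights are copied to the destination.
prepare_checkpoints(
    sources=["s3://my-bucket/checkpoints/lora-step-500"],
    dest_path="s3://my-bucket/deploy/olmocr-merged",
)

# Multiple full checkpoints are souped (weights averaged); mixing in a LoRA source raises.
prepare_checkpoints(
    sources=["/ckpts/step-400", "/ckpts/step-500"],
    dest_path="/deploy/olmocr-souped",
)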
@@ -36,6 +36,46 @@ logging.basicConfig(
logger = logging.getLogger(__name__)


def prepare_lora_model(model: torch.nn.Module, model_cfg) -> torch.nn.Module:
"""Wrap the model with a LoRA adapter according to the configuration."""
try:
from peft import LoraConfig, get_peft_model
except ImportError as exc:  # pragma: no cover - optional dependency guard
raise ImportError("LoRA training requires the `peft` package. Install it with `pip install peft`.") from exc

lora_kwargs = dict(
r=model_cfg.lora_rank,
lora_alpha=model_cfg.lora_alpha,
lora_dropout=model_cfg.lora_dropout,
target_modules=model_cfg.lora_target_modules,
bias="none",
task_type="CAUSAL_LM",
)
if model_cfg.lora_modules_to_save:
lora_kwargs["modules_to_save"] = model_cfg.lora_modules_to_save

lora_config = LoraConfig(**lora_kwargs)
model = get_peft_model(model, lora_config)

if hasattr(model, "config"):
model.config.base_model_name_or_path = model_cfg.name
base_model = getattr(model, "base_model", None)
if base_model is not None:
inner_model = getattr(base_model, "model", None)
if inner_model is not None and hasattr(inner_model, "config"):
inner_model.config._name_or_path = model_cfg.name

if hasattr(model, "print_trainable_parameters"):
model.print_trainable_parameters()

return model


def is_lora_checkpoint(checkpoint_dir: str) -> bool:
"""Detect whether a checkpoint directory contains LoRA adapter weights."""
return os.path.exists(os.path.join(checkpoint_dir, "adapter_config.json"))


class QwenDataCollator:
"""Data collator for vision-language models that handles numpy arrays."""
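A hedged sketch of what prepare_lora_model expects from the model config; the field values are illustrative stand-ins for whatever the training YAML provides, and `model` is assumed to be an already-loaded Qwen2-VL / Qwen2.5-VL model:

from types import SimpleNamespace

model_cfg = SimpleNamespace(
    name="Qwen/Qwen2.5-VL-7B-Instruct",  # assumed base model id
    lora_rank=16,
    lora_alpha=32,
    lora_dropout=0.05,
    lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_modules_to_save=None,
)
model = prepare_lora_model(model, model_cfg)  # wraps with peft and prints the trainable summary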
@@ -140,9 +180,29 @@ def load_checkpoint(
lr_scheduler: Any,
checkpoint_dir: str,
device: torch.device,
*,
base_model_path: Optional[str] = None,
use_lora: bool = False,
) -> tuple[torch.nn.Module, Dict[str, Any]]:
"""Load model, optimizer, scheduler, and training state from checkpoint."""
model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
checkpoint_has_lora = is_lora_checkpoint(checkpoint_dir)

if checkpoint_has_lora or use_lora:
if base_model_path is None:
raise ValueError("base_model_path must be provided when loading LoRA checkpoints.")

try:
from peft import PeftModel
except ImportError as exc:  # pragma: no cover - optional dependency guard
raise ImportError("Loading a LoRA checkpoint requires the `peft` package. Install it with `pip install peft`.") from exc

base_model = model_class.from_pretrained(base_model_path, **init_kwargs)
model = PeftModel.from_pretrained(base_model, checkpoint_dir, is_trainable=True)
if hasattr(model, "config"):
model.config.base_model_name_or_path = base_model_path
else:
model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)

model.to(device)

optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device))
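For orientation, a sketch of what a resumable LoRA checkpoint directory is assumed to contain: the peft adapter files that is_lora_checkpoint keys on plus the training state restored above. Names other than adapter_config.json and optimizer.pt follow common peft/trainer conventions and are assumptions, not shown in this diff:

lora_checkpoint_files = [
    "adapter_config.json",        # presence of this file triggers the LoRA branch
    "adapter_model.safetensors",  # adapter weights written by peft (assumed name)
    "optimizer.pt",               # optimizer state loaded above
    "scheduler.pt",               # scheduler state (assumed, mirrors EXCLUDED_FILES in the prepare script)
    "trainer_state.json",         # global_step / best_metric / samples_seen bookkeeping (assumed)
]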
@@ -280,6 +340,15 @@ def main():
else:
raise NotImplementedError()

if config.model.use_lora:
logger.info("Applying LoRA adapters as specified in the config.")
model = prepare_lora_model(model, config.model)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
trainable_ratio = (trainable_params / total_params * 100) if total_params else 0.0
logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_ratio:.2f}%)")

# Enable gradient checkpointing if configured
if config.training.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)
@@ -370,15 +439,19 @@ def main():
logger.info("Model compilation complete")

# Set up optimizer
trainable_named_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
if not trainable_named_params:
raise ValueError("No trainable parameters found. Check model fine-tuning configuration.")

if config.training.optim == "adamw_torch":
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"params": [p for n, p in trainable_named_params if not any(nd in n for nd in no_decay)],
"weight_decay": config.training.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"params": [p for n, p in trainable_named_params if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
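Building the AdamW groups from trainable_named_params matters for LoRA: most base-model weights are frozen, and including them would allocate optimizer state for parameters that never receive gradients. A small illustrative check, not part of the commit:

frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
assert trainable > 0, "LoRA wrapping left nothing trainable"
print(f"frozen: {frozen:,}  trainable: {trainable:,}")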
@@ -389,11 +462,14 @@ def main():
eps=float(config.training.adam_epsilon),
)
elif config.training.optim == "muon":
if config.model.use_lora:
raise NotImplementedError("LoRA training is not currently supported with the Muon optimizer in this loop.")

# Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head)
hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [p for n, p in model.named_parameters() if "lm_head" in n]
hidden_matrix_params = [p for n, p in trainable_named_params if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
embed_params = [p for n, p in trainable_named_params if "embed" in n]
scalar_params = [p for n, p in trainable_named_params if p.ndim < 2]
head_params = [p for n, p in trainable_named_params if "lm_head" in n]

# Create Adam groups with different learning rates
adam_groups = [
@@ -447,7 +523,16 @@ def main():
best_metric = float("inf") if not config.training.greater_is_better else -float("inf")

if found_resumable_checkpoint:
model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device)
model, state = load_checkpoint(
model_class,
model_init_kwargs,
optimizer,
lr_scheduler,
found_resumable_checkpoint,
device,
base_model_path=config.model.name,
use_lora=config.model.use_lora,
)
global_step = state["global_step"]
best_metric = state["best_metric"]
samples_seen = state["samples_seen"]