Merge branch 'jakep/lorafix'

Jake Poznanski 2025-10-30 16:20:53 +00:00
commit 3d2c977ac5
3 changed files with 360 additions and 142 deletions

View File

@@ -28,7 +28,7 @@ dataset:
train:
- name: finetuning_data
root_dir: /root/test-berkshire-data
root_dir: /root/test-berkshire-data/train
pipeline: &basic_pipeline
- name: FrontMatterParser
front_matter_class: PageResponse
@@ -50,7 +50,7 @@ dataset:
eval:
- name: eval_finetuning_data
root_dir: /root/test-berkshire-data
root_dir: /root/test-berkshire-data/test
pipeline: *basic_pipeline
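# Implied data layout (illustrative sketch, not part of this change):
#   /root/test-berkshire-data/
#     train/   <- training split read by finetuning_data
#     test/    <- held-out split read by eval_finetuning_data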
# Training configuration

View File

@@ -36,12 +36,15 @@ import json
import os
import shutil
import tempfile
from typing import Optional
import boto3
import requests
import torch
from botocore.exceptions import ClientError
from smart_open import smart_open
from tqdm import tqdm
from transformers import AutoConfig, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
try:
from safetensors.torch import load_file, save_file
@@ -59,6 +62,12 @@ TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json
# Supported model architectures
SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]
# Map architectures to corresponding model classes
MODEL_CLASS_MAP = {
"Qwen2VLForConditionalGeneration": Qwen2VLForConditionalGeneration,
"Qwen2_5_VLForConditionalGeneration": Qwen2_5_VLForConditionalGeneration,
}
# Files to exclude from copying (training-related files)
# Supports exact matches and glob patterns
EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin", "*.pt", "*.pth"}
@@ -79,6 +88,42 @@ def is_s3_path(path: str) -> bool:
return path.startswith("s3://")
def join_path(base: str, *parts: str) -> str:
"""Join paths for local and S3-style URIs."""
if not parts:
return base
if is_s3_path(base):
cleaned = [base.rstrip("/")]
for part in parts:
cleaned.append(part.strip("/"))
return "/".join(segment for segment in cleaned if segment)
return os.path.join(base, *parts)
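# Illustrative usage (a sketch, not part of this change): join_path keeps S3 URIs
# slash-joined and defers to os.path.join for local paths, e.g.
#   join_path("s3://bucket/ckpt/", "adapter_config.json") -> "s3://bucket/ckpt/adapter_config.json"
#   join_path("/tmp/ckpt", "adapter_config.json")         -> "/tmp/ckpt/adapter_config.json"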
def load_json_if_exists(path: str) -> Optional[dict]:
"""Load JSON from a path if it exists, otherwise return None."""
try:
with smart_open(path, "r") as handle:
return json.load(handle)
except FileNotFoundError:
return None
except ClientError as exc:
error_code = exc.response.get("Error", {}).get("Code")
if error_code in {"NoSuchKey", "404"}:
return None
raise
except OSError as exc:
if "No such file" in str(exc):
return None
raise
def load_adapter_config(source_path: str) -> Optional[dict]:
"""Return the LoRA adapter configuration if present for the given source."""
adapter_config_path = join_path(source_path, "adapter_config.json")
return load_json_if_exists(adapter_config_path)
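# Illustrative sketch (not part of this change): the presence of adapter_config.json is
# what marks a source as a LoRA checkpoint later in prepare_checkpoints, e.g.
#   cfg = load_adapter_config("s3://bucket/checkpoints/step-500")  # hypothetical path
#   if cfg is not None:
#       base_model = cfg.get("base_model_name_or_path")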
def download_file_from_hf(filename: str, destination_dir: str, hf_base_url: str) -> None:
"""Download a file from Hugging Face model repository."""
url = f"{hf_base_url}/{filename}"
@@ -254,6 +299,61 @@ def get_weight_files(dir_path: str) -> list[str]:
return weight_files
def merge_lora_adapter_checkpoint(adapter_dir: str, base_model_name: str, output_dir: str) -> str:
"""Merge a LoRA adapter into its base model and save the merged weights.
Returns:
The detected architecture string of the merged model.
"""
try:
from peft import PeftModel
except ImportError as exc: # pragma: no cover - optional dependency guard
raise ImportError("Merging LoRA adapters requires the `peft` package. Install it with `pip install peft`.") from exc
print(f"Merging LoRA adapter from {adapter_dir} with base model '{base_model_name}'...")
base_config = AutoConfig.from_pretrained(base_model_name, trust_remote_code=True)
architecture = None
for arch in base_config.architectures or []:
if arch in MODEL_CLASS_MAP:
architecture = arch
break
if architecture is None:
raise ValueError(
f"Base model '{base_model_name}' uses an unsupported architecture: {base_config.architectures}. "
f"Supported architectures: {SUPPORTED_ARCHITECTURES}"
)
model_class = MODEL_CLASS_MAP[architecture]
base_model = model_class.from_pretrained(
base_model_name,
trust_remote_code=True,
torch_dtype="auto",
)
lora_model = PeftModel.from_pretrained(base_model, adapter_dir, is_trainable=False)
merged_model = lora_model.merge_and_unload()
merged_model = merged_model.to("cpu")
if hasattr(merged_model, "config"):
merged_model.config._name_or_path = base_model_name
merged_model.config.base_model_name_or_path = base_model_name
os.makedirs(output_dir, exist_ok=True)
merged_model.save_pretrained(output_dir)
print(f"✓ Saved merged model to {output_dir}")
# Explicit cleanup
del merged_model
del lora_model
del base_model
if torch.cuda.is_available():
torch.cuda.empty_cache()
return architecture
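# Illustrative usage sketch (not part of this change); the paths and base model id
# below are hypothetical:
#   arch = merge_lora_adapter_checkpoint(
#       adapter_dir="/tmp/adapter",
#       base_model_name="Qwen/Qwen2.5-VL-7B-Instruct",
#       output_dir="/tmp/merged",
#   )
#   assert arch in SUPPORTED_ARCHITECTURES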
def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
"""Prepare OlmOCR checkpoint(s) for deployment, with support for souping."""
print(f"Preparing {'souped ' if len(sources) > 1 else ''}checkpoint from {len(sources)} source(s) to {dest_path}")
@@ -261,146 +361,179 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
sources = [source.rstrip("/") for source in sources]
dest_path = dest_path.rstrip("/")
# Detect architectures
architectures = []
source_infos = []
for source in sources:
config_path = f"{source}/config.json" if is_s3_path(source) else os.path.join(source, "config.json")
arch = detect_checkpoint_architecture(config_path)
architectures.append(arch)
# Check all same
if len(set(architectures)) > 1:
raise ValueError("All sources must have the same architecture")
architecture = architectures[0]
# Get the appropriate HF model ID and base URL
hf_model_id = HF_MODEL_IDS[architecture]
hf_base_url = f"https://huggingface.co/{hf_model_id}/resolve/main"
print(f"Using HuggingFace model: {hf_model_id}")
if len(sources) == 1:
source_path = sources[0]
# Single checkpoint: copy as before
print("\nCopying model files...")
if is_s3_path(source_path) and is_s3_path(dest_path):
# S3 to S3
source_bucket, source_prefix = parse_s3_path(source_path)
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_s3_to_s3(source_bucket, source_prefix, dest_bucket, dest_prefix)
elif is_s3_path(source_path) and not is_s3_path(dest_path):
# S3 to local
source_bucket, source_prefix = parse_s3_path(source_path)
copy_s3_to_local(source_bucket, source_prefix, dest_path)
elif not is_s3_path(source_path) and is_s3_path(dest_path):
# Local to S3
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_local_to_s3(source_path, dest_bucket, dest_prefix)
adapter_config = load_adapter_config(source)
if adapter_config is not None:
source_infos.append({"path": source, "is_lora": True, "adapter_config": adapter_config})
else:
# Local to local
copy_local_to_local(source_path, dest_path)
else:
# Souping multiple checkpoints
config_path = join_path(source, "config.json")
arch = detect_checkpoint_architecture(config_path)
source_infos.append({"path": source, "is_lora": False, "architecture": arch})
num_lora_sources = sum(1 for info in source_infos if info["is_lora"])
final_architecture: Optional[str] = None
if num_lora_sources > 0:
if len(source_infos) > 1:
raise ValueError("LoRA adapter checkpoints can only be processed individually, not during souping.")
source_info = source_infos[0]
source_path = source_info["path"]
adapter_config = source_info["adapter_config"]
base_model_name = adapter_config.get("base_model_name_or_path")
if not base_model_name:
raise ValueError("adapter_config.json is missing 'base_model_name_or_path'; cannot merge LoRA adapter.")
with tempfile.TemporaryDirectory() as temp_dir:
# Download all sources to local temp dirs
source_temps = []
for i, source in enumerate(sources):
source_temp = os.path.join(temp_dir, f"source_{i}")
if is_s3_path(source):
bucket, prefix = parse_s3_path(source)
copy_s3_to_local(bucket, prefix, source_temp)
else:
copy_local_to_local(source, source_temp)
source_temps.append(source_temp)
adapter_local_dir = os.path.join(temp_dir, "adapter")
print("\nDownloading LoRA adapter locally for merging...")
if is_s3_path(source_path):
bucket, prefix = parse_s3_path(source_path)
copy_s3_to_local(bucket, prefix, adapter_local_dir)
else:
copy_local_to_local(source_path, adapter_local_dir)
first_source = source_temps[0]
merged_dir = os.path.join(temp_dir, "merged")
final_architecture = merge_lora_adapter_checkpoint(adapter_local_dir, base_model_name, merged_dir)
# Get weight files
weight_full_paths = get_weight_files(first_source)
weight_rel_paths = [os.path.relpath(p, first_source) for p in weight_full_paths]
# Verify others have same weight files
for i in range(1, len(sources)):
other_dir = source_temps[i]
other_weights = [os.path.relpath(p, other_dir) for p in get_weight_files(other_dir)]
if set(other_weights) != set(weight_rel_paths):
raise ValueError(f"Source {sources[i]} has different weight files")
# Create souped_dir
souped_dir = os.path.join(temp_dir, "souped")
# Copy first source (including its weights, which will be overwritten)
copy_local_to_local(first_source, souped_dir)
# Average weights
for rel_path in tqdm(weight_rel_paths, desc="Averaging weight files"):
all_paths = [os.path.join(st, rel_path) for st in source_temps]
file_path = all_paths[0]
souped_path = os.path.join(souped_dir, rel_path)
os.makedirs(os.path.dirname(souped_path), exist_ok=True)
if file_path.endswith(".safetensors"):
sum_state = load_file(file_path, device="cpu")
# Store original dtypes for each tensor
original_dtypes = {k: v.dtype for k, v in sum_state.items()}
# Upconvert to at least fp32 for accurate averaging
for k in sum_state:
if sum_state[k].dtype in (torch.float16, torch.bfloat16):
sum_state[k] = sum_state[k].to(torch.float32)
for other_path in all_paths[1:]:
other_state = load_file(other_path, device="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
# Upconvert other state to match sum_state dtype
if other_state[k].dtype in (torch.float16, torch.bfloat16):
other_state[k] = other_state[k].to(torch.float32)
sum_state[k] += other_state[k]
del other_state
n = len(all_paths)
for k in sum_state:
sum_state[k] /= n
# Cast back to original dtype
sum_state[k] = sum_state[k].to(original_dtypes[k])
save_file(sum_state, souped_path)
elif file_path.endswith(".bin"):
sum_state = torch.load(file_path, map_location="cpu")
# Store original dtypes for each tensor
original_dtypes = {k: v.dtype for k, v in sum_state.items()}
# Upconvert to at least fp32 for accurate averaging
for k in sum_state:
if sum_state[k].dtype in (torch.float16, torch.bfloat16):
sum_state[k] = sum_state[k].to(torch.float32)
for other_path in all_paths[1:]:
other_state = torch.load(other_path, map_location="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
# Upconvert other state to match sum_state dtype
if other_state[k].dtype in (torch.float16, torch.bfloat16):
other_state[k] = other_state[k].to(torch.float32)
sum_state[k] += other_state[k]
del other_state
n = len(all_paths)
for k in sum_state:
sum_state[k] /= n
# Cast back to original dtype
sum_state[k] = sum_state[k].to(original_dtypes[k])
torch.save(sum_state, souped_path)
else:
print(f"Skipping unknown weight file: {rel_path}")
continue
# Now copy souped_dir to dest_path
print("\nCopying souped model files to destination...")
print("\nCopying merged model files to destination...")
if is_s3_path(dest_path):
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_local_to_s3(souped_dir, dest_bucket, dest_prefix)
copy_local_to_s3(merged_dir, dest_bucket, dest_prefix)
else:
copy_local_to_local(souped_dir, dest_path)
copy_local_to_local(merged_dir, dest_path)
else:
architectures = [info["architecture"] for info in source_infos]
if len(set(architectures)) > 1:
raise ValueError("All sources must have the same architecture")
final_architecture = architectures[0]
if len(sources) == 1:
source_path = sources[0]
print("\nCopying model files...")
if is_s3_path(source_path) and is_s3_path(dest_path):
source_bucket, source_prefix = parse_s3_path(source_path)
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_s3_to_s3(source_bucket, source_prefix, dest_bucket, dest_prefix)
elif is_s3_path(source_path) and not is_s3_path(dest_path):
source_bucket, source_prefix = parse_s3_path(source_path)
copy_s3_to_local(source_bucket, source_prefix, dest_path)
elif not is_s3_path(source_path) and is_s3_path(dest_path):
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_local_to_s3(source_path, dest_bucket, dest_prefix)
else:
copy_local_to_local(source_path, dest_path)
else:
# Souping multiple checkpoints
with tempfile.TemporaryDirectory() as temp_dir:
# Download all sources to local temp dirs
source_temps = []
for i, source in enumerate(sources):
source_temp = os.path.join(temp_dir, f"source_{i}")
if is_s3_path(source):
bucket, prefix = parse_s3_path(source)
copy_s3_to_local(bucket, prefix, source_temp)
else:
copy_local_to_local(source, source_temp)
source_temps.append(source_temp)
first_source = source_temps[0]
# Get weight files
weight_full_paths = get_weight_files(first_source)
weight_rel_paths = [os.path.relpath(p, first_source) for p in weight_full_paths]
# Verify others have same weight files
for i in range(1, len(sources)):
other_dir = source_temps[i]
other_weights = [os.path.relpath(p, other_dir) for p in get_weight_files(other_dir)]
if set(other_weights) != set(weight_rel_paths):
raise ValueError(f"Source {sources[i]} has different weight files")
# Create souped_dir
souped_dir = os.path.join(temp_dir, "souped")
# Copy first source (including its weights, which will be overwritten)
copy_local_to_local(first_source, souped_dir)
# Average weights
for rel_path in tqdm(weight_rel_paths, desc="Averaging weight files"):
all_paths = [os.path.join(st, rel_path) for st in source_temps]
file_path = all_paths[0]
souped_path = os.path.join(souped_dir, rel_path)
os.makedirs(os.path.dirname(souped_path), exist_ok=True)
if file_path.endswith(".safetensors"):
sum_state = load_file(file_path, device="cpu")
# Store original dtypes for each tensor
original_dtypes = {k: v.dtype for k, v in sum_state.items()}
# Upconvert to at least fp32 for accurate averaging
for k in sum_state:
if sum_state[k].dtype in (torch.float16, torch.bfloat16):
sum_state[k] = sum_state[k].to(torch.float32)
for other_path in all_paths[1:]:
other_state = load_file(other_path, device="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
# Upconvert other state to match sum_state dtype
if other_state[k].dtype in (torch.float16, torch.bfloat16):
other_state[k] = other_state[k].to(torch.float32)
sum_state[k] += other_state[k]
del other_state
n = len(all_paths)
for k in sum_state:
sum_state[k] /= n
# Cast back to original dtype
sum_state[k] = sum_state[k].to(original_dtypes[k])
save_file(sum_state, souped_path)
elif file_path.endswith(".bin"):
sum_state = torch.load(file_path, map_location="cpu")
# Store original dtypes for each tensor
original_dtypes = {k: v.dtype for k, v in sum_state.items()}
# Upconvert to at least fp32 for accurate averaging
for k in sum_state:
if sum_state[k].dtype in (torch.float16, torch.bfloat16):
sum_state[k] = sum_state[k].to(torch.float32)
for other_path in all_paths[1:]:
other_state = torch.load(other_path, map_location="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
# Upconvert other state to match sum_state dtype
if other_state[k].dtype in (torch.float16, torch.bfloat16):
other_state[k] = other_state[k].to(torch.float32)
sum_state[k] += other_state[k]
del other_state
n = len(all_paths)
for k in sum_state:
sum_state[k] /= n
# Cast back to original dtype
sum_state[k] = sum_state[k].to(original_dtypes[k])
torch.save(sum_state, souped_path)
else:
print(f"Skipping unknown weight file: {rel_path}")
continue
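# Souping sketch (not part of this change): each tensor becomes the elementwise mean
# of the N sources, accumulated in float32 for numerical stability and then cast back
# to its original dtype:
#   w_souped = (w_1 + w_2 + ... + w_N) / N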
# Now copy souped_dir to dest_path
print("\nCopying souped model files to destination...")
if is_s3_path(dest_path):
dest_bucket, dest_prefix = parse_s3_path(dest_path)
copy_local_to_s3(souped_dir, dest_bucket, dest_prefix)
else:
copy_local_to_local(souped_dir, dest_path)
if final_architecture is None:
raise ValueError("Unable to determine the architecture of the prepared checkpoint.")
hf_model_id = HF_MODEL_IDS[final_architecture]
hf_base_url = f"https://huggingface.co/{hf_model_id}/resolve/main"
print(f"Using HuggingFace model: {hf_model_id}")
# Download tokenizer files from Hugging Face
print("\nDownloading tokenizer files from Hugging Face...")

View File

@@ -36,6 +36,46 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
def prepare_lora_model(model: torch.nn.Module, model_cfg) -> torch.nn.Module:
"""Wrap the model with a LoRA adapter according to the configuration."""
try:
from peft import LoraConfig, get_peft_model
except ImportError as exc: # pragma: no cover - optional dependency guard
raise ImportError("LoRA training requires the `peft` package. Install it with `pip install peft`.") from exc
lora_kwargs = dict(
r=model_cfg.lora_rank,
lora_alpha=model_cfg.lora_alpha,
lora_dropout=model_cfg.lora_dropout,
target_modules=model_cfg.lora_target_modules,
bias="none",
task_type="CAUSAL_LM",
)
if model_cfg.lora_modules_to_save:
lora_kwargs["modules_to_save"] = model_cfg.lora_modules_to_save
lora_config = LoraConfig(**lora_kwargs)
model = get_peft_model(model, lora_config)
if hasattr(model, "config"):
model.config.base_model_name_or_path = model_cfg.name
base_model = getattr(model, "base_model", None)
if base_model is not None:
inner_model = getattr(base_model, "model", None)
if inner_model is not None and hasattr(inner_model, "config"):
inner_model.config._name_or_path = model_cfg.name
if hasattr(model, "print_trainable_parameters"):
model.print_trainable_parameters()
return model
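# Illustrative config sketch (not part of this change): the model_cfg fields consumed
# above map to config entries like the following; the values shown are examples only.
#   model:
#     use_lora: true
#     lora_rank: 16
#     lora_alpha: 32
#     lora_dropout: 0.05
#     lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
#     lora_modules_to_save: []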
def is_lora_checkpoint(checkpoint_dir: str) -> bool:
"""Detect whether a checkpoint directory contains LoRA adapter weights."""
return os.path.exists(os.path.join(checkpoint_dir, "adapter_config.json"))
class QwenDataCollator:
"""Data collator for vision-language models that handles numpy arrays."""
@@ -140,9 +180,29 @@ def load_checkpoint(
lr_scheduler: Any,
checkpoint_dir: str,
device: torch.device,
*,
base_model_path: Optional[str] = None,
use_lora: bool = False,
) -> tuple[torch.nn.Module, Dict[str, Any]]:
"""Load model, optimizer, scheduler, and training state from checkpoint."""
model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
checkpoint_has_lora = is_lora_checkpoint(checkpoint_dir)
if checkpoint_has_lora or use_lora:
if base_model_path is None:
raise ValueError("base_model_path must be provided when loading LoRA checkpoints.")
try:
from peft import PeftModel
except ImportError as exc: # pragma: no cover - optional dependency guard
raise ImportError("Loading a LoRA checkpoint requires the `peft` package. Install it with `pip install peft`.") from exc
base_model = model_class.from_pretrained(base_model_path, **init_kwargs)
model = PeftModel.from_pretrained(base_model, checkpoint_dir, is_trainable=True)
if hasattr(model, "config"):
model.config.base_model_name_or_path = base_model_path
else:
model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
model.to(device)
optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device))
@@ -280,6 +340,15 @@ def main():
else:
raise NotImplementedError()
if config.model.use_lora:
logger.info("Applying LoRA adapters as specified in the config.")
model = prepare_lora_model(model, config.model)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
trainable_ratio = (trainable_params / total_params * 100) if total_params else 0.0
logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_ratio:.2f}%)")
# Enable gradient checkpointing if configured
if config.training.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)
@@ -370,15 +439,19 @@ def main():
logger.info("Model compilation complete")
# Set up optimizer
trainable_named_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
if not trainable_named_params:
raise ValueError("No trainable parameters found. Check model fine-tuning configuration.")
if config.training.optim == "adamw_torch":
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"params": [p for n, p in trainable_named_params if not any(nd in n for nd in no_decay)],
"weight_decay": config.training.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"params": [p for n, p in trainable_named_params if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
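# Note (not part of this change): with LoRA enabled, only adapter parameters (and any
# modules_to_save) have requires_grad=True, so filtering on trainable_named_params keeps
# the frozen base weights out of the optimizer's parameter groups and state.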
@@ -389,11 +462,14 @@ def main():
eps=float(config.training.adam_epsilon),
)
elif config.training.optim == "muon":
if config.model.use_lora:
raise NotImplementedError("LoRA training is not currently supported with the Muon optimizer in this loop.")
# Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head)
hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [p for n, p in model.named_parameters() if "lm_head" in n]
hidden_matrix_params = [p for n, p in trainable_named_params if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
embed_params = [p for n, p in trainable_named_params if "embed" in n]
scalar_params = [p for n, p in trainable_named_params if p.ndim < 2]
head_params = [p for n, p in trainable_named_params if "lm_head" in n]
# Create Adam groups with different learning rates
adam_groups = [
@@ -447,7 +523,16 @@ def main():
best_metric = float("inf") if not config.training.greater_is_better else -float("inf")
if found_resumable_checkpoint:
model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device)
model, state = load_checkpoint(
model_class,
model_init_kwargs,
optimizer,
lr_scheduler,
found_resumable_checkpoint,
device,
base_model_path=config.model.name,
use_lora=config.model.use_lora,
)
global_step = state["global_step"]
best_metric = state["best_metric"]
samples_seen = state["samples_seen"]