Merge branch 'jakep/lorafix'

Jake Poznanski 2025-10-30 16:20:53 +00:00
commit 3d2c977ac5
3 changed files with 360 additions and 142 deletions

View File

@@ -28,7 +28,7 @@ dataset:
   train:
     - name: finetuning_data
-      root_dir: /root/test-berkshire-data
+      root_dir: /root/test-berkshire-data/train
       pipeline: &basic_pipeline
         - name: FrontMatterParser
           front_matter_class: PageResponse
@@ -50,7 +50,7 @@ dataset:
   eval:
     - name: eval_finetuning_data
-      root_dir: /root/test-berkshire-data
+      root_dir: /root/test-berkshire-data/test
       pipeline: *basic_pipeline

 # Training configuration
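The config now points the train and eval datasets at separate train/ and test/ subdirectories of the same data root. A minimal pre-flight sketch of the layout this assumes (the check itself is not part of the commit; the root path is the one from the config):

import os

# Assumed layout after this change (illustrative, not from the commit):
#   /root/test-berkshire-data/train/  -> training documents
#   /root/test-berkshire-data/test/   -> held-out eval documents
root = "/root/test-berkshire-data"
for split in ("train", "test"):
    split_dir = os.path.join(root, split)
    if not os.path.isdir(split_dir):
        raise FileNotFoundError(f"Expected split directory missing: {split_dir}")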

View File

@@ -36,12 +36,15 @@ import json
 import os
 import shutil
 import tempfile
+from typing import Optional

 import boto3
 import requests
 import torch
+from botocore.exceptions import ClientError
 from smart_open import smart_open
 from tqdm import tqdm
+from transformers import AutoConfig, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration

 try:
     from safetensors.torch import load_file, save_file
@@ -59,6 +62,12 @@ TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json
 # Supported model architectures
 SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]

+# Map architectures to corresponding model classes
+MODEL_CLASS_MAP = {
+    "Qwen2VLForConditionalGeneration": Qwen2VLForConditionalGeneration,
+    "Qwen2_5_VLForConditionalGeneration": Qwen2_5_VLForConditionalGeneration,
+}
+
 # Files to exclude from copying (training-related files)
 # Supports exact matches and glob patterns
 EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin", "*.pt", "*.pth"}
@@ -79,6 +88,42 @@ def is_s3_path(path: str) -> bool:
     return path.startswith("s3://")


+def join_path(base: str, *parts: str) -> str:
+    """Join paths for local and S3-style URIs."""
+    if not parts:
+        return base
+    if is_s3_path(base):
+        cleaned = [base.rstrip("/")]
+        for part in parts:
+            cleaned.append(part.strip("/"))
+        return "/".join(segment for segment in cleaned if segment)
+    return os.path.join(base, *parts)
+
+
+def load_json_if_exists(path: str) -> Optional[dict]:
+    """Load JSON from a path if it exists, otherwise return None."""
+    try:
+        with smart_open(path, "r") as handle:
+            return json.load(handle)
+    except FileNotFoundError:
+        return None
+    except ClientError as exc:
+        error_code = exc.response.get("Error", {}).get("Code")
+        if error_code in {"NoSuchKey", "404"}:
+            return None
+        raise
+    except OSError as exc:
+        if "No such file" in str(exc):
+            return None
+        raise
+
+
+def load_adapter_config(source_path: str) -> Optional[dict]:
+    """Return the LoRA adapter configuration if present for the given source."""
+    adapter_config_path = join_path(source_path, "adapter_config.json")
+    return load_json_if_exists(adapter_config_path)
+
+
 def download_file_from_hf(filename: str, destination_dir: str, hf_base_url: str) -> None:
     """Download a file from Hugging Face model repository."""
     url = f"{hf_base_url}/{filename}"
@@ -254,6 +299,61 @@ def get_weight_files(dir_path: str) -> list[str]:
     return weight_files


+def merge_lora_adapter_checkpoint(adapter_dir: str, base_model_name: str, output_dir: str) -> str:
+    """Merge a LoRA adapter into its base model and save the merged weights.
+
+    Returns:
+        The detected architecture string of the merged model.
+    """
+    try:
+        from peft import PeftModel
+    except ImportError as exc:  # pragma: no cover - optional dependency guard
+        raise ImportError("Merging LoRA adapters requires the `peft` package. Install it with `pip install peft`.") from exc
+
+    print(f"Merging LoRA adapter from {adapter_dir} with base model '{base_model_name}'...")
+
+    base_config = AutoConfig.from_pretrained(base_model_name, trust_remote_code=True)
+    architecture = None
+    for arch in base_config.architectures or []:
+        if arch in MODEL_CLASS_MAP:
+            architecture = arch
+            break
+
+    if architecture is None:
+        raise ValueError(
+            f"Base model '{base_model_name}' uses an unsupported architecture: {base_config.architectures}. "
+            f"Supported architectures: {SUPPORTED_ARCHITECTURES}"
+        )
+
+    model_class = MODEL_CLASS_MAP[architecture]
+    base_model = model_class.from_pretrained(
+        base_model_name,
+        trust_remote_code=True,
+        torch_dtype="auto",
+    )
+
+    lora_model = PeftModel.from_pretrained(base_model, adapter_dir, is_trainable=False)
+    merged_model = lora_model.merge_and_unload()
+    merged_model = merged_model.to("cpu")
+
+    if hasattr(merged_model, "config"):
+        merged_model.config._name_or_path = base_model_name
+        merged_model.config.base_model_name_or_path = base_model_name
+
+    os.makedirs(output_dir, exist_ok=True)
+    merged_model.save_pretrained(output_dir)
+    print(f"✓ Saved merged model to {output_dir}")
+
+    # Explicit cleanup
+    del merged_model
+    del lora_model
+    del base_model
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return architecture
+
+
 def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
     """Prepare OlmOCR checkpoint(s) for deployment, with support for souping."""
     print(f"Preparing {'souped ' if len(sources) > 1 else ''}checkpoint from {len(sources)} source(s) to {dest_path}")
@@ -261,43 +361,69 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
     sources = [source.rstrip("/") for source in sources]
     dest_path = dest_path.rstrip("/")

-    # Detect architectures
-    architectures = []
+    source_infos = []
     for source in sources:
-        config_path = f"{source}/config.json" if is_s3_path(source) else os.path.join(source, "config.json")
-        arch = detect_checkpoint_architecture(config_path)
-        architectures.append(arch)
-
-    # Check all same
-    if len(set(architectures)) > 1:
-        raise ValueError("All sources must have the same architecture")
-    architecture = architectures[0]
-
-    # Get the appropriate HF model ID and base URL
-    hf_model_id = HF_MODEL_IDS[architecture]
-    hf_base_url = f"https://huggingface.co/{hf_model_id}/resolve/main"
-    print(f"Using HuggingFace model: {hf_model_id}")
+        adapter_config = load_adapter_config(source)
+        if adapter_config is not None:
+            source_infos.append({"path": source, "is_lora": True, "adapter_config": adapter_config})
+        else:
+            config_path = join_path(source, "config.json")
+            arch = detect_checkpoint_architecture(config_path)
+            source_infos.append({"path": source, "is_lora": False, "architecture": arch})
+
+    num_lora_sources = sum(1 for info in source_infos if info["is_lora"])
+    final_architecture: Optional[str] = None
+
+    if num_lora_sources > 0:
+        if len(source_infos) > 1:
+            raise ValueError("LoRA adapter checkpoints can only be processed individually, not during souping.")
+
+        source_info = source_infos[0]
+        source_path = source_info["path"]
+        adapter_config = source_info["adapter_config"]
+        base_model_name = adapter_config.get("base_model_name_or_path")
+        if not base_model_name:
+            raise ValueError("adapter_config.json is missing 'base_model_name_or_path'; cannot merge LoRA adapter.")
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            adapter_local_dir = os.path.join(temp_dir, "adapter")
+            print("\nDownloading LoRA adapter locally for merging...")
+            if is_s3_path(source_path):
+                bucket, prefix = parse_s3_path(source_path)
+                copy_s3_to_local(bucket, prefix, adapter_local_dir)
+            else:
+                copy_local_to_local(source_path, adapter_local_dir)
+
+            merged_dir = os.path.join(temp_dir, "merged")
+            final_architecture = merge_lora_adapter_checkpoint(adapter_local_dir, base_model_name, merged_dir)
+
+            print("\nCopying merged model files to destination...")
+            if is_s3_path(dest_path):
+                dest_bucket, dest_prefix = parse_s3_path(dest_path)
+                copy_local_to_s3(merged_dir, dest_bucket, dest_prefix)
+            else:
+                copy_local_to_local(merged_dir, dest_path)
+    else:
+        architectures = [info["architecture"] for info in source_infos]
+        if len(set(architectures)) > 1:
+            raise ValueError("All sources must have the same architecture")
+        final_architecture = architectures[0]

     if len(sources) == 1:
         source_path = sources[0]
-        # Single checkpoint: copy as before
         print("\nCopying model files...")
         if is_s3_path(source_path) and is_s3_path(dest_path):
-            # S3 to S3
             source_bucket, source_prefix = parse_s3_path(source_path)
             dest_bucket, dest_prefix = parse_s3_path(dest_path)
             copy_s3_to_s3(source_bucket, source_prefix, dest_bucket, dest_prefix)
         elif is_s3_path(source_path) and not is_s3_path(dest_path):
-            # S3 to local
             source_bucket, source_prefix = parse_s3_path(source_path)
             copy_s3_to_local(source_bucket, source_prefix, dest_path)
         elif not is_s3_path(source_path) and is_s3_path(dest_path):
-            # Local to S3
             dest_bucket, dest_prefix = parse_s3_path(dest_path)
             copy_local_to_s3(source_path, dest_bucket, dest_prefix)
         else:
-            # Local to local
             copy_local_to_local(source_path, dest_path)
     else:
         # Souping multiple checkpoints
@@ -402,6 +528,13 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
         else:
             copy_local_to_local(souped_dir, dest_path)

+    if final_architecture is None:
+        raise ValueError("Unable to determine the architecture of the prepared checkpoint.")
+
+    hf_model_id = HF_MODEL_IDS[final_architecture]
+    hf_base_url = f"https://huggingface.co/{hf_model_id}/resolve/main"
+    print(f"Using HuggingFace model: {hf_model_id}")
+
     # Download tokenizer files from Hugging Face
     print("\nDownloading tokenizer files from Hugging Face...")

View File

@@ -36,6 +36,46 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)


+def prepare_lora_model(model: torch.nn.Module, model_cfg) -> torch.nn.Module:
+    """Wrap the model with a LoRA adapter according to the configuration."""
+    try:
+        from peft import LoraConfig, get_peft_model
+    except ImportError as exc:  # pragma: no cover - optional dependency guard
+        raise ImportError("LoRA training requires the `peft` package. Install it with `pip install peft`.") from exc
+
+    lora_kwargs = dict(
+        r=model_cfg.lora_rank,
+        lora_alpha=model_cfg.lora_alpha,
+        lora_dropout=model_cfg.lora_dropout,
+        target_modules=model_cfg.lora_target_modules,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    if model_cfg.lora_modules_to_save:
+        lora_kwargs["modules_to_save"] = model_cfg.lora_modules_to_save
+
+    lora_config = LoraConfig(**lora_kwargs)
+    model = get_peft_model(model, lora_config)
+
+    if hasattr(model, "config"):
+        model.config.base_model_name_or_path = model_cfg.name
+
+    base_model = getattr(model, "base_model", None)
+    if base_model is not None:
+        inner_model = getattr(base_model, "model", None)
+        if inner_model is not None and hasattr(inner_model, "config"):
+            inner_model.config._name_or_path = model_cfg.name
+
+    if hasattr(model, "print_trainable_parameters"):
+        model.print_trainable_parameters()
+
+    return model
+
+
+def is_lora_checkpoint(checkpoint_dir: str) -> bool:
+    """Detect whether a checkpoint directory contains LoRA adapter weights."""
+    return os.path.exists(os.path.join(checkpoint_dir, "adapter_config.json"))
+
+
 class QwenDataCollator:
     """Data collator for vision-language models that handles numpy arrays."""
@@ -140,9 +180,29 @@ def load_checkpoint(
     lr_scheduler: Any,
     checkpoint_dir: str,
     device: torch.device,
+    *,
+    base_model_path: Optional[str] = None,
+    use_lora: bool = False,
 ) -> tuple[torch.nn.Module, Dict[str, Any]]:
     """Load model, optimizer, scheduler, and training state from checkpoint."""
-    model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
+    checkpoint_has_lora = is_lora_checkpoint(checkpoint_dir)
+
+    if checkpoint_has_lora or use_lora:
+        if base_model_path is None:
+            raise ValueError("base_model_path must be provided when loading LoRA checkpoints.")
+        try:
+            from peft import PeftModel
+        except ImportError as exc:  # pragma: no cover - optional dependency guard
+            raise ImportError("Loading a LoRA checkpoint requires the `peft` package. Install it with `pip install peft`.") from exc
+
+        base_model = model_class.from_pretrained(base_model_path, **init_kwargs)
+        model = PeftModel.from_pretrained(base_model, checkpoint_dir, is_trainable=True)
+        if hasattr(model, "config"):
+            model.config.base_model_name_or_path = base_model_path
+    else:
+        model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
+
     model.to(device)
     optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device))
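The is_lora_checkpoint check works because saving a PEFT-wrapped model writes only the adapter files rather than full model weights. A sketch of what a resumable LoRA checkpoint directory is assumed to contain (file names follow peft's defaults; the directory is a placeholder):

# checkpoints/checkpoint-500/
#   adapter_config.json        <- triggers the LoRA branch in load_checkpoint
#   adapter_model.safetensors  <- adapter weights only; the base model is not duplicated
#   optimizer.pt, scheduler.pt <- training state saved alongside, as before
model.save_pretrained("checkpoints/checkpoint-500")  # a PeftModel saves just the adapter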
@@ -280,6 +340,15 @@ def main():
     else:
         raise NotImplementedError()

+    if config.model.use_lora:
+        logger.info("Applying LoRA adapters as specified in the config.")
+        model = prepare_lora_model(model, config.model)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    trainable_ratio = (trainable_params / total_params * 100) if total_params else 0.0
+    logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_ratio:.2f}%)")
+
     # Enable gradient checkpointing if configured
     if config.training.gradient_checkpointing:
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)
@@ -370,15 +439,19 @@ def main():
         logger.info("Model compilation complete")

     # Set up optimizer
+    trainable_named_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
+    if not trainable_named_params:
+        raise ValueError("No trainable parameters found. Check model fine-tuning configuration.")
+
     if config.training.optim == "adamw_torch":
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
             {
-                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "params": [p for n, p in trainable_named_params if not any(nd in n for nd in no_decay)],
                 "weight_decay": config.training.weight_decay,
             },
             {
-                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+                "params": [p for n, p in trainable_named_params if any(nd in n for nd in no_decay)],
                 "weight_decay": 0.0,
             },
         ]
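Filtering on requires_grad matters here because, with LoRA applied, the base weights are frozen and only adapter parameters (plus any modules_to_save) remain trainable. An illustrative check, not part of the commit:

# After get_peft_model, trainable names look like
# "...layers.0.self_attn.q_proj.lora_A.default.weight"; frozen base weights
# are excluded from the optimizer groups built above.
trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
print(f"{len(trainable_names)} trainable tensors, e.g. {trainable_names[:3]}")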
@@ -389,11 +462,14 @@ def main():
             eps=float(config.training.adam_epsilon),
         )
     elif config.training.optim == "muon":
+        if config.model.use_lora:
+            raise NotImplementedError("LoRA training is not currently supported with the Muon optimizer in this loop.")
+
         # Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head)
-        hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
-        embed_params = [p for n, p in model.named_parameters() if "embed" in n]
-        scalar_params = [p for p in model.parameters() if p.ndim < 2]
-        head_params = [p for n, p in model.named_parameters() if "lm_head" in n]
+        hidden_matrix_params = [p for n, p in trainable_named_params if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
+        embed_params = [p for n, p in trainable_named_params if "embed" in n]
+        scalar_params = [p for n, p in trainable_named_params if p.ndim < 2]
+        head_params = [p for n, p in trainable_named_params if "lm_head" in n]

         # Create Adam groups with different learning rates
         adam_groups = [
@@ -447,7 +523,16 @@ def main():
     best_metric = float("inf") if not config.training.greater_is_better else -float("inf")

     if found_resumable_checkpoint:
-        model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device)
+        model, state = load_checkpoint(
+            model_class,
+            model_init_kwargs,
+            optimizer,
+            lr_scheduler,
+            found_resumable_checkpoint,
+            device,
+            base_model_path=config.model.name,
+            use_lora=config.model.use_lora,
+        )
         global_step = state["global_step"]
         best_metric = state["best_metric"]
         samples_seen = state["samples_seen"]