import json import multiprocessing import os import random from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from hashlib import sha1 from logging import Logger from tempfile import TemporaryDirectory from typing import Dict, Generator, List, Optional, TypeVar from functools import partial import torch import torch.nn.functional as F from transformers import AutoProcessor from accelerate import Accelerator from accelerate.utils import PrecisionType from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset from .core.cli import to_native_types from .core.config import AwsConfig, TrainConfig, WandbConfig, DataConfig, SourceConfig from .core.loggers import get_logger from .core.paths import copy_dir, is_local from .core.state import BeakerState # from .tokenization import ModelTokenizer T = TypeVar("T") from pdelfin.train.dataloader import build_finetuning_dataset, list_dataset_files from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype: pt = PrecisionType(accelerator.mixed_precision) if pt == PrecisionType.FP16: return torch.float16 elif pt == PrecisionType.BF16: return torch.bfloat16 elif pt == PrecisionType.FP8: return torch.float8_e4m3fn return torch.float32 def get_rawdataset_from_source(data_config: DataConfig, source: SourceConfig) -> Dataset: return build_finetuning_dataset(source.response_glob_path, pdf_cache_location=data_config.cache_location) def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]: random.seed(config.train_data.seed) # Retrieve the two target lengths from the first source for comparison first_source = config.train_data.sources[0] target_longest_image_dim = first_source.target_longest_image_dim target_anchor_text_len = first_source.target_anchor_text_len # Verify that all sources have the same target lengths for source in config.train_data.sources: if source.target_longest_image_dim != target_longest_image_dim: raise ValueError(f"Inconsistent target_longest_image_dim found in source {source}") if source.target_anchor_text_len != target_anchor_text_len: raise ValueError(f"Inconsistent target_anchor_text_len found in source {source}") # Concatenate datasets first, unfortunately you can't apply the transform before concatenation due to the library train_dataset = concatenate_datasets( [ get_rawdataset_from_source(config.train_data, source) for source in config.train_data.sources ] ) # Apply the transform to the concatenated dataset train_dataset = train_dataset.with_transform( partial( batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len, ) ) # Validation sets get put into a datasetdict so each can report a loss separately valid_dataset = DatasetDict( **{ source.name: get_rawdataset_from_source(config.valid_data, source).with_transform( partial( batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=source.target_longest_image_dim, target_anchor_text_len=source.target_anchor_text_len, ) ) for source in config.valid_data.sources } ) return train_dataset, valid_dataset def setup_environment( aws_config: Optional[AwsConfig] = None, wandb_config: Optional[WandbConfig] = None, **kwargs: str ): multiprocessing.set_start_method("spawn", force=True) os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "false" if wandb_config: os.environ["WANDB_WATCH"] = "false" for key, value in to_native_types(wandb_config or {}).items(): if value is not None: os.environ[f"WANDB_{key.upper()}"] = str(value) for key, value in to_native_types(aws_config or {}).items(): if value is not None: os.environ[f"AWS_{key.upper()}"] = str(value) os.environ.update(kwargs) @dataclass class RunName: run: str group: str @classmethod def get(cls, config: TrainConfig, accelerator: Optional[Accelerator] = None) -> "RunName": job_rank = f"-{accelerator.process_index}" if accelerator else "" if beaker_job_id := BeakerState().job_id: job_id = f"-{beaker_job_id}" else: job_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") (config_hash := sha1()).update(json.dumps(to_native_types(config)).encode()) model_name = config.model.name_or_path.replace("/", "_") group_name = f"{model_name}-{config_hash.hexdigest()[:6]}" run_name = f"{group_name}{job_id}{job_rank}" return cls(group=group_name, run=run_name) @contextmanager def override_torch_threads(n: int): torch_num_threads = torch.get_num_threads() torch.set_num_threads(n) yield torch.set_num_threads(torch_num_threads) @contextmanager def temp_args(obj: T, **kwargs) -> Generator[T, None, None]: orig = {k: getattr(obj, k) for k in kwargs.keys()} for k, v in kwargs.items(): setattr(obj, k, v) yield obj for k, v in orig.items(): setattr(obj, k, v) def log_trainable_parameters(model: torch.nn.Module, logger: Optional[Logger] = None): """ Prints the number of trainable parameters in the model. """ trainable_params = 0 all_param = 0 for _, param in model.named_parameters(): all_param += param.numel() if param.requires_grad: trainable_params += param.numel() (logger or get_logger(__name__)).info( "trainable params: %s || all params: %s || trainable%%: %s", f"{trainable_params:,}", f"{all_param:,}", f"{trainable_params / all_param:.2%}", ) class TruncatingCollator: def __init__(self, max_length: int): self.max_length = max_length def __call__(self, batch: List[Dict]) -> Dict: # Assert that we are only handling batch size 1 for now assert len(batch) == 1, "Only batch size 1 is supported for now" truncated_input_ids = torch.tensor(batch[0]["input_ids"][:self.max_length]).unsqueeze(0) truncated_attention_mask = torch.tensor(batch[0]["attention_mask"][:self.max_length]).unsqueeze(0) truncated_labels = torch.tensor(batch[0]["labels"][:self.max_length]).unsqueeze(0) return { "input_ids": truncated_input_ids, "attention_mask": truncated_attention_mask, "labels": truncated_labels, "pixel_values": torch.tensor(batch[0]["pixel_values"]).unsqueeze(0), "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]).unsqueeze(0), } @contextmanager def get_local_dir(output_dir: str): with TemporaryDirectory() as tmp_dir: if is_local(output_dir): yield output_dir else: yield tmp_dir copy_dir(tmp_dir, output_dir)