import json
import multiprocessing
import os
import random
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from hashlib import sha1
from logging import Logger
from tempfile import TemporaryDirectory
from typing import Dict, Generator, List, Optional, TypeVar

from functools import partial

import torch
import torch.nn.functional as F
from transformers import AutoProcessor
from accelerate import Accelerator
from accelerate.utils import PrecisionType
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

from .core.cli import to_native_types
from .core.config import AwsConfig, TrainConfig, WandbConfig, DataConfig, SourceConfig
from .core.loggers import get_logger
from .core.paths import copy_dir, is_local
from .core.state import BeakerState
# from .tokenization import ModelTokenizer

T = TypeVar("T")

from pdelfin.train.dataloader import build_finetuning_dataset, list_dataset_files
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training


def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
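    """Map the Accelerator's mixed-precision setting to the corresponding torch dtype."""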
    pt = PrecisionType(accelerator.mixed_precision)
    if pt == PrecisionType.FP16:
        return torch.float16
    elif pt == PrecisionType.BF16:
        return torch.bfloat16
    elif pt == PrecisionType.FP8:
        return torch.float8_e4m3fn
    return torch.float32


def get_rawdataset_from_source(data_config: DataConfig, source: SourceConfig) -> Dataset:
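    """Load the raw fine-tuning dataset for a single source, using the shared PDF cache location."""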
    return build_finetuning_dataset(source.response_glob_path, pdf_cache_location=data_config.cache_location)


def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, DatasetDict]:
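    """Build the concatenated training dataset and a DatasetDict of per-source validation datasets."""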
    random.seed(config.train_data.seed)

    # Retrieve the two target lengths from the first source for comparison
    first_source = config.train_data.sources[0]
    target_longest_image_dim = first_source.target_longest_image_dim
    target_anchor_text_len = first_source.target_anchor_text_len

    # Verify that all sources have the same target lengths
    for source in config.train_data.sources:
        if source.target_longest_image_dim != target_longest_image_dim:
            raise ValueError(f"Inconsistent target_longest_image_dim found in source {source}")
        if source.target_anchor_text_len != target_anchor_text_len:
            raise ValueError(f"Inconsistent target_anchor_text_len found in source {source}")

    # Concatenate the datasets first; the datasets library does not let us apply the transform before concatenation
    train_dataset = concatenate_datasets(
        [
            get_rawdataset_from_source(config.train_data, source)
            for source in config.train_data.sources
        ]
    )

    # Apply the transform to the concatenated dataset
    train_dataset = train_dataset.with_transform(
        partial(
            batch_prepare_data_for_qwen2_training,
            processor=processor,
            target_longest_image_dim=target_longest_image_dim,
            target_anchor_text_len=target_anchor_text_len,
        )
    )

    # Validation sets go into a DatasetDict so that each source reports its loss separately
    valid_dataset = DatasetDict(
        **{
            source.name: get_rawdataset_from_source(config.valid_data, source).with_transform(
                partial(
                    batch_prepare_data_for_qwen2_training,
                    processor=processor,
                    target_longest_image_dim=source.target_longest_image_dim,
                    target_anchor_text_len=source.target_anchor_text_len,
                )
            )
            for source in config.valid_data.sources
        }
    )

    return train_dataset, valid_dataset


def setup_environment(
    aws_config: Optional[AwsConfig] = None, wandb_config: Optional[WandbConfig] = None, **kwargs: str
):
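    """Configure multiprocessing and tokenizer settings, and export WANDB_*/AWS_* variables from the given configs."""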
    multiprocessing.set_start_method("spawn", force=True)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "false"

    if wandb_config:
        os.environ["WANDB_WATCH"] = "false"

    for key, value in to_native_types(wandb_config or {}).items():
        if value is not None:
            os.environ[f"WANDB_{key.upper()}"] = str(value)

    for key, value in to_native_types(aws_config or {}).items():
        if value is not None:
            os.environ[f"AWS_{key.upper()}"] = str(value)

    os.environ.update(kwargs)


@dataclass
class RunName:
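    """Run and group names derived from the training config, Beaker job id, and process rank."""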
    run: str
    group: str

    @classmethod
    def get(cls, config: TrainConfig, accelerator: Optional[Accelerator] = None) -> "RunName":
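        """Build a RunName from the model name, a config hash, the Beaker job id (or a timestamp), and the process rank."""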
        job_rank = f"-{accelerator.process_index}" if accelerator else ""

        if beaker_job_id := BeakerState().job_id:
            job_id = f"-{beaker_job_id}"
        else:
            job_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

        (config_hash := sha1()).update(json.dumps(to_native_types(config)).encode())
        model_name = config.model.name_or_path.replace("/", "_")

        group_name = f"{model_name}-{config_hash.hexdigest()[:6]}"
        run_name = f"{group_name}{job_id}{job_rank}"
        return cls(group=group_name, run=run_name)


@contextmanager
def override_torch_threads(n: int):
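    """Run the enclosed block with torch's thread count set to ``n``, then restore the previous value."""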
    torch_num_threads = torch.get_num_threads()
    torch.set_num_threads(n)

    yield

    torch.set_num_threads(torch_num_threads)


@contextmanager
def temp_args(obj: T, **kwargs) -> Generator[T, None, None]:
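    """Temporarily override attributes on ``obj`` with ``kwargs``, restoring the original values after the block."""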
    orig = {k: getattr(obj, k) for k in kwargs.keys()}
    for k, v in kwargs.items():
        setattr(obj, k, v)

    yield obj

    for k, v in orig.items():
        setattr(obj, k, v)


def log_trainable_parameters(model: torch.nn.Module, logger: Optional[Logger] = None):
    """
    Logs the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    (logger or get_logger(__name__)).info(
        "trainable params: %s || all params: %s || trainable%%: %s",
        f"{trainable_params:,}",
        f"{all_param:,}",
        f"{trainable_params / all_param:.2%}",
    )


class TruncatingCollator:
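    """Collator that truncates each example to ``max_length`` tokens and stacks it into a batch of size 1."""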
    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(self, batch: List[Dict]) -> Dict:
        # Assert that we are only handling batch size 1 for now
        assert len(batch) == 1, "Only batch size 1 is supported for now"

        truncated_input_ids = torch.tensor(batch[0]["input_ids"][:self.max_length]).unsqueeze(0)
        truncated_attention_mask = torch.tensor(batch[0]["attention_mask"][:self.max_length]).unsqueeze(0)
        truncated_labels = torch.tensor(batch[0]["labels"][:self.max_length]).unsqueeze(0)

        return {
            "input_ids": truncated_input_ids,
            "attention_mask": truncated_attention_mask,
            "labels": truncated_labels,
            "pixel_values": torch.tensor(batch[0]["pixel_values"]).unsqueeze(0),
            "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]).unsqueeze(0),
        }


@contextmanager
def get_local_dir(output_dir: str):
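    """Yield a local directory to write to; if ``output_dir`` is remote, write to a temp dir and copy it over on exit."""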
    with TemporaryDirectory() as tmp_dir:
        if is_local(output_dir):
            yield output_dir
        else:
            yield tmp_dir
            copy_dir(tmp_dir, output_dir)