diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
index c9a6ecd..982a93c 100644
--- a/pdelfin/train/train.py
+++ b/pdelfin/train/train.py
@@ -116,9 +116,11 @@ def run_train(config: TrainConfig):
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
 
     accelerator = accelerate.Accelerator()
 
+    processor = AutoProcessor.from_pretrained(config.model.name_or_path)
+
     # Build and download the dataset on process 0
     if accelerator.is_main_process:
-        make_dataset(config)
+        make_dataset(config, processor)
 
     accelerator.wait_for_everyone()
 
@@ -128,8 +130,7 @@ def run_train(config: TrainConfig):
         config.model.name_or_path, torch_dtype=torch.bfloat16,
         _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
     )
-    processor = AutoProcessor.from_pretrained(config.model.name_or_path)
-
+
     if config.lora is not None:
         peft_config = LoraConfig(
             r=config.lora.rank,
diff --git a/pdelfin/train/utils.py b/pdelfin/train/utils.py
index 5ed1b96..5bc2d12 100644
--- a/pdelfin/train/utils.py
+++ b/pdelfin/train/utils.py
@@ -14,6 +14,7 @@ from functools import partial
 
 import torch
 import torch.nn.functional as F
+from transformers import AutoProcessor
 from accelerate import Accelerator
 from accelerate.utils import PrecisionType
 from datasets import Dataset, concatenate_datasets, DatasetDict
@@ -41,7 +42,7 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
             return torch.float8_e4m3fn
     return torch.float32
 
-def make_dataset(config: TrainConfig) -> tuple[Dataset, Dataset]:
+def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
     random.seed(config.train_data.seed)
 
     # Training sets get all concatenated and shuffled
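
For reference, a minimal sketch (not part of the patch) of how initialization is ordered after this change: the processor is loaded on every rank before any dataset work, and only the main process calls the updated make_dataset(config, processor). The standalone helper name below is an assumption for illustration; only the calls shown in the diff are taken from the patch.

# Illustrative sketch only: post-patch initialization order.
# prepare_data is a hypothetical wrapper mirroring the top of run_train.
import accelerate
from transformers import AutoProcessor

from pdelfin.train.utils import make_dataset  # signature after this patch: (config, processor)


def prepare_data(config):
    accelerator = accelerate.Accelerator()

    # Load the processor on every rank, before dataset construction,
    # so make_dataset can reuse it when building the dataset.
    processor = AutoProcessor.from_pretrained(config.model.name_or_path)

    # Only process 0 builds and downloads the dataset.
    if accelerator.is_main_process:
        make_dataset(config, processor)

    # All other ranks wait until the dataset artifacts exist.
    accelerator.wait_for_everyone()
    return processor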