Commit 4557a5b296 (parent e973de7ba9)
Jake Poznanski, 2024-10-07 13:03:31 -07:00
2 changed files with 6 additions and 4 deletions

The change moves AutoProcessor construction in run_train ahead of the dataset build and threads the processor through to make_dataset, so the dataset pipeline can use it.

File 1 of 2:

@@ -116,9 +116,11 @@ def run_train(config: TrainConfig):
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
     accelerator = accelerate.Accelerator()
+    processor = AutoProcessor.from_pretrained(config.model.name_or_path)
+
     # Build and download the dataset on process 0
     if accelerator.is_main_process:
-        make_dataset(config)
+        make_dataset(config, processor)

     accelerator.wait_for_everyone()
@@ -128,8 +130,7 @@ def run_train(config: TrainConfig):
         config.model.name_or_path, torch_dtype=torch.bfloat16,
         _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
     )
-    processor = AutoProcessor.from_pretrained(config.model.name_or_path)

     if config.lora is not None:
         peft_config = LoraConfig(
             r=config.lora.rank,
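
Read together, the two hunks above move the AutoProcessor construction from after model init to before the dataset build, so run_train can hand the processor to make_dataset. Below is a minimal, self-contained sketch of the resulting control flow; run_train_sketch and make_dataset_stub are hypothetical stand-ins for the repo's run_train and make_dataset, not code from this commit.

import accelerate
from transformers import AutoProcessor


def make_dataset_stub(config, processor) -> None:
    # Stand-in for the repo's make_dataset: the real function presumably
    # downloads the source data and preprocesses it with the processor.
    _ = (config, processor)


def run_train_sketch(config) -> None:
    accelerator = accelerate.Accelerator()
    # Created before the dataset build (previously it was created later,
    # alongside the model), so it can be passed into make_dataset.
    processor = AutoProcessor.from_pretrained(config.model.name_or_path)

    # Build and download the dataset on process 0 only...
    if accelerator.is_main_process:
        make_dataset_stub(config, processor)

    # ...then hold every rank at a barrier until the build has finished.
    accelerator.wait_for_everyone()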

File 2 of 2:

@@ -14,6 +14,7 @@ from functools import partial
 import torch
 import torch.nn.functional as F
+from transformers import AutoProcessor
 from accelerate import Accelerator
 from accelerate.utils import PrecisionType
 from datasets import Dataset, concatenate_datasets, DatasetDict
@@ -41,7 +42,7 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
         return torch.float8_e4m3fn
     return torch.float32

-def make_dataset(config: TrainConfig) -> tuple[Dataset, Dataset]:
+def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
     random.seed(config.train_data.seed)

     # Training sets get all concatenated and shuffled
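
The hunk above only changes make_dataset's signature; its body is not part of this diff, so how the processor is used there is not shown. One plausible use, sketched below purely as an assumption, is pre-tokenizing examples while the dataset is built on process 0. pretokenize and the "text" column are hypothetical, not from the repo.

from datasets import Dataset


def pretokenize(dataset: Dataset, processor) -> Dataset:
    # Hypothetical: encode every example once, up front, so training workers
    # can read ready-made input_ids instead of tokenizing on the fly.
    def _encode(example):
        # Assumes the processor wraps a tokenizer and examples carry a "text"
        # column; a multimodal dataset would call the processor itself.
        enc = processor.tokenizer(example["text"], truncation=True)
        return {"input_ids": enc["input_ids"]}

    return dataset.map(_encode)

A dataset built this way on the main process could then be persisted with Dataset.save_to_disk() and reloaded by the other ranks once wait_for_everyone() releases them, matching the process-0 build pattern in the first file.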