Getting ready to launch a new training run

Jake Poznanski 2024-10-02 23:04:56 +00:00
parent 1686790ac8
commit 0ddaf9023d
4 changed files with 44 additions and 37 deletions


@@ -28,21 +28,20 @@ generate:
 train_data:
   seed: 1337
   sources:
-    - name: openai_batch_data_v2
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
-      backend:
-        - openai
-      size: 100_000
+    - name: openai_batch_data_v5_1_eval # TODO This is just for testing the job, once ready change to a real train dataset
+      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl
 
 valid_data:
   sources:
-    - name: openai_batch_data_eval_mini
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
-      backend:
-        - openai
-      size: 100_000
+    - name: openai_batch_data_v5_1_eval
+      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl
+    - name: openai_batch_data_v5_1_iabooks_eval
+      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_eval/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.jsonl
 
 # Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
 hparams:
@@ -52,10 +51,10 @@ hparams:
   gradient_checkpointing: false
   clip_grad_norm: 1.0
   learning_rate: 3e-4
-  max_steps: 5000
+  max_steps: 2000
   pad_multiple_of: 16
   log_every_steps: 50
-  eval_every_steps: 500
+  eval_every_steps: 100
   optim: adamw_torch
   lr_scheduler: cosine
   weight_decay: 0.01
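
Read as a standard Hugging Face fine-tune, these hparams correspond fairly directly to stock transformers TrainingArguments. A minimal sketch of that mapping, under the assumption that the repo's trainer wires the YAML through roughly like this (the parameter names below are standard transformers options, not this project's own code):

    from transformers import TrainingArguments

    # Hypothetical mapping of the YAML hparams above onto stock TrainingArguments.
    training_args = TrainingArguments(
        output_dir="outputs",        # placeholder; the real run derives this from the save config
        learning_rate=3e-4,
        max_steps=2000,
        weight_decay=0.01,
        max_grad_norm=1.0,           # clip_grad_norm in the YAML
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        logging_steps=50,            # log_every_steps
        eval_strategy="steps",
        eval_steps=100,              # eval_every_steps
        gradient_checkpointing=False,
        # pad_multiple_of has no TrainingArguments equivalent; it belongs on the
        # tokenizer/collator side as pad_to_multiple_of.
    )

With max_steps cut to 2000 and eval_every_steps to 100, the run now evaluates 20 times over its lifetime rather than 10.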


@@ -75,10 +75,8 @@ class AwsConfig:
 @dataclass
 class SourceConfig:
     name: str = field(help="The name of the source")
-    size: int = field(help="Limit size for the source")
     query_glob_path: str = field(help="The s3 bucket pointing to the inputs sent to OpenAI to generate the silver data")
     response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")
-    backend: List[str] = field(help="The data generation backend to use to train the model")
 
 @dataclass
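
With size and backend gone, a source is just a name plus the two S3 globs. A self-contained sketch of the trimmed dataclass for reference; the field(help=...) helper below is a stand-in for the project's own config field wrapper, which this sketch does not import:

    from dataclasses import dataclass, field as dc_field

    def field(help: str = ""):
        # Stand-in for the repo's field(help=...) helper so this sketch runs on
        # its own; the real helper lives in the project's config utilities.
        return dc_field(default=None, metadata={"help": help})

    @dataclass
    class SourceConfig:
        name: str = field(help="The name of the source")
        query_glob_path: str = field(help="The s3 bucket pointing to the inputs sent to OpenAI to generate the silver data")
        response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")

    # Each entry under sources: in the YAML above becomes one of these, e.g.
    example = SourceConfig(
        name="openai_batch_data_v5_1_eval",
        query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl",
    )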


@@ -15,7 +15,7 @@ from tqdm import tqdm
 import accelerate
 import torch
 import torch.distributed
-from datasets import DatasetDict
+from datasets import DatasetDict, concatenate_datasets
 from datasets.utils import disable_progress_bars
 from datasets.utils.logging import set_verbosity
 from peft import LoraConfig, get_peft_model  # pyright: ignore
@@ -49,7 +49,7 @@ from .utils import (
 )
 
-from pdelfin.train.dataloader import make_dataset
+from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len
@@ -113,13 +113,6 @@ def run_train(config: TrainConfig):
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
 
-    dataset = make_dataset(
-        train_data_config=config.train_data,
-        valid_data_config=config.valid_data,
-        num_proc=config.num_proc,
-        logger=logger,
-    )
-
     model = Qwen2VLForConditionalGeneration.from_pretrained(
         config.model.name_or_path, torch_dtype=torch.bfloat16,
         _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
@@ -139,10 +132,28 @@ def run_train(config: TrainConfig):
     log_trainable_parameters(model=model, logger=logger)
 
     # Do final filtering, and prep for running model forward()
-    filtered_dataset = DatasetDict(**{split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset})
-    formatted_dataset = filtered_dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-    print(formatted_dataset)
-    print("---------------")
+    # Training sets get all concatenated and shuffled
+    train_dataset = (
+        concatenate_datasets(
+            [
+                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+                for source in config.train_data.sources
+            ]
+        )
+        .filter(partial(filter_by_max_seq_len, processor=processor))
+        .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    )
+
+    # Validation sets get put into a datasetdict so each can report a loss separately
+    valid_dataset = DatasetDict(
+        **{
+            source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+            .filter(partial(filter_by_max_seq_len, processor=processor))
+            .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+            for source in config.valid_data.sources
+        }
+    )
 
     save_path = join_path("", config.save.path, run_name.run)
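
The training side leans on two datasets library behaviours: concatenate_datasets stacks same-schema Datasets row-wise into one, and with_transform applies the prep function lazily at access time instead of materializing a processed copy. A toy sketch of both (made-up columns, nothing from this repo):

    from datasets import Dataset, concatenate_datasets

    a = Dataset.from_dict({"text": ["p1", "p2"]})
    b = Dataset.from_dict({"text": ["q1"]})

    combined = concatenate_datasets([a, b])
    print(len(combined))  # 3 rows; the schemas must match

    # The transform runs only when rows are read, mirroring the lazy prep above.
    lazy = combined.with_transform(lambda batch: {"text": [t.upper() for t in batch["text"]]})
    print(lazy[:2]["text"])  # ['P1', 'P2']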
@@ -192,8 +203,8 @@ def run_train(config: TrainConfig):
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=formatted_dataset["train"],
-        eval_dataset=formatted_dataset["validation"],  # pyright: ignore
+        train_dataset=train_dataset,
+        eval_dataset=valid_dataset,
         tokenizer=processor.tokenizer,
         #Collator is not needed as we are doing batch size 1 for now...
         #data_collator=collator,
@@ -218,7 +229,6 @@ def run_train(config: TrainConfig):
     logger.info("Saved best model to %s", best_dir)
 
-    # Uncomment to test speed of data loader
     # train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
     # for entry in tqdm(train_dataloader):
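
The split between one concatenated train_dataset and a DatasetDict of validation sets relies on a transformers Trainer feature: when eval_dataset is a dict, the Trainer evaluates each entry separately and prefixes its metrics with the dict key, so every source reports its own eval_<name>_loss. A toy, self-contained sketch of that behaviour (tiny regression model and invented source names, not this repo's Qwen2-VL setup):

    import torch
    from torch import nn
    from datasets import Dataset, DatasetDict
    from transformers import Trainer, TrainingArguments

    class TinyRegressor(nn.Module):
        # Minimal model that returns a loss so Trainer can evaluate it.
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 1)

        def forward(self, x=None, labels=None):
            preds = self.linear(x).squeeze(-1)
            return {"loss": nn.functional.mse_loss(preds, labels), "logits": preds}

    def toy_split(n):
        return Dataset.from_dict({"x": [[float(i)] for i in range(n)],
                                  "labels": [2.0 * i for i in range(n)]})

    # Two named validation sets, mirroring the per-source DatasetDict above.
    valid_sets = DatasetDict(pdf_pages=toy_split(16), iabooks=toy_split(16))

    trainer = Trainer(
        model=TinyRegressor(),
        args=TrainingArguments(output_dir="tmp_eval_demo", report_to="none"),
        eval_dataset=valid_sets,
    )
    metrics = trainer.evaluate()
    # Expect separate keys such as eval_pdf_pages_loss and eval_iabooks_loss.
    print(metrics)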


@@ -19,7 +19,7 @@ run_name=$(basename "$0" .sh)
 # --cluster 'ai2/allennlp-cirrascale' \
 # --priority high \
 
-CLUSTER='jupiter'
+CLUSTER='pluto'
 
 gantry run \
     --description "${run_name}"\