Hoping to get a basic hf Trainer to run

Jake Poznanski 2024-09-20 15:53:11 -07:00
parent 55035b02c9
commit 256d77c232
4 changed files with 204 additions and 5 deletions

.gitignore

@@ -1,3 +1,6 @@
# ml stuff
wandb/
# build artifacts
.eggs/


@@ -0,0 +1,124 @@
model:
  name_or_path: Qwen/Qwen2-VL-2B-Instruct
  arch: causal

wandb:
  project: refine
  entity: pdf-qwen2vl

# TODO This is not used
format:
  instruction_template: "Original:"
  response_template: "Rewritten:"
  # Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30
  chat_template: |
    {% for message in messages %}
    {{'<|im_start|>' + message['role'] + '\n' + message['content']}}
    {% if loop.last %}
    {{ '<|im_end|>'}}
    {% else %}
    {{ '<|im_end|>\n' }}
    {% endif %}
    {% endfor %}

generate:
  max_length: 4096

train_data:
  seed: 1337
  sources:
    - name: fw-edu-all
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/fw-edu-all/*.json.gz
      backend:
        - openai
      size: 100_000
    - name: dclm
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dclm/*.zstd
      backend:
        - openai
      size: 100_000
    - name: dolma-v17
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v17/*.zstd
      backend:
        - openai
      size: 100_000
    - name: dolma-v1-small
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v1-small/*.zstd
      backend:
        - openai
      size: 100_000

valid_data:
  sources:
    - name: fw-edu-10k
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/fw-edu-10k/valid/*.gz
      backend:
        - openai
      size: 1500
    - name: dolma-10k
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-10k/valid/*.gz
      backend:
        - openai
      size: 1500
    - name: dclm
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dclm/*.zstd
      backend:
        - openai
      size: 1500
    - name: dolma-v17
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v17/*.zstd
      backend:
        - openai
      size: 1500
    - name: dolma-v1-small
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v1-small/*.zstd
      backend:
        - openai
      size: 3000

# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
  batch_size: 2
  eval_batch_size: 2
  gradient_accumulation_steps: 4
  gradient_checkpointing: true
  clip_grad_norm: 1.0
  learning_rate: 3e-4
  max_steps: 10000
  pad_multiple_of: 16
  log_every_steps: 5
  eval_every_steps: 250
  optim: adamw_torch
  lr_scheduler: cosine
  weight_decay: 0.01
  warmup_ratio: 0.03

# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
lora:
  rank: 32
  alpha: 32
  dropout: 0.05
  task_type: causal_lm
  target_modules:
    - q_proj
    - k_proj
    - v_proj
    - o_proj
    - gate_proj
    - up_proj
    - down_proj

save:
  path: s3://ai2-tylerm-experimental/experiments/rephrase/v1/models/lucas
  save_every_steps: 500

max_workers: 1
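For reference, the chat_template above is a plain Jinja2 template that wraps each message in Qwen's <|im_start|>/<|im_end|> markers. The sketch below is not part of the commit: the config path is a hypothetical placeholder, and the trim_blocks/lstrip_blocks settings mirror how Hugging Face tokenizers evaluate chat templates.

import yaml
from jinja2 import Environment

# Hypothetical location for the config file shown above.
with open("configs/qwen2vl-rephrase.yaml") as f:
    config = yaml.safe_load(f)

# Mimic the whitespace handling used when tokenizers render chat templates.
env = Environment(trim_blocks=True, lstrip_blocks=True)
template = env.from_string(config["format"]["chat_template"])

messages = [
    {"role": "user", "content": "Original: some noisy source text"},
    {"role": "assistant", "content": "Rewritten: a cleaned-up version"},
]
print(template.render(messages=messages))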
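The lora: section lines up with the fields of a PEFT LoraConfig. A minimal sketch of that mapping, using the same hypothetical config path as above; wrapping the actual model happens elsewhere in the training code.

import yaml
from peft import LoraConfig, get_peft_model

with open("configs/qwen2vl-rephrase.yaml") as f:  # hypothetical path, as above
    lora_cfg = yaml.safe_load(f)["lora"]

peft_config = LoraConfig(
    r=lora_cfg["rank"],                       # 32
    lora_alpha=lora_cfg["alpha"],             # 32
    lora_dropout=lora_cfg["dropout"],         # 0.05
    task_type=lora_cfg["task_type"].upper(),  # "causal_lm" -> "CAUSAL_LM"
    target_modules=lora_cfg["target_modules"],
)
# model = get_peft_model(model, peft_config)  # wrap the base model before training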


@@ -11,7 +11,7 @@ class ModelConfig:
     """Configuration for loading a model; includes model name and type."""

     name_or_path: str = field(
-        help="The model name or path to load; must be compatible with huggingface transformers.",
+        help="The model name or path to load; must be compatible with huggingface transformers."
     )
     arch: str = field(help="The model type to load; can be 'vllm', 'causal', or 'vllm'")
     dtype: str = field(help="The precision to use for the model", default="bfloat16")
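The field(help=...) calls in ModelConfig point at a small wrapper around dataclasses.field. The project's actual helper is not shown in this diff; a hypothetical stand-in that stores the help text in field metadata could look like this:

from dataclasses import MISSING, dataclass, field as dataclass_field, fields
from typing import Any


def field(*, help: str = "", default: Any = MISSING, **kwargs: Any) -> Any:
    """Hypothetical stand-in: keep the help string in the field metadata."""
    return dataclass_field(default=default, metadata={"help": help}, **kwargs)


@dataclass
class ExampleConfig:
    dtype: str = field(help="The precision to use for the model", default="bfloat16")


# The help strings can then be recovered later, e.g. for CLI or docs generation.
for f in fields(ExampleConfig):
    print(f"{f.name}: {f.metadata['help']}")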


@@ -13,6 +13,7 @@ import os
 import json
 import base64
 import logging
+import time
 from io import BytesIO
 from PIL import Image
 from functools import partial
@@ -62,7 +63,26 @@ from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training


def get_rank() -> int:
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0


def run_train(config: TrainConfig):
    if get_rank() == 0:
        logger_level = logging.INFO
    else:
        logger_level = logging.WARN
        disable_progress_bars()

    logger = get_logger(__name__, level=logger_level)
    set_verbosity(logger_level)

    run_name = RunName.get(config)

    accelerator = accelerate.Accelerator()

    train_ds = build_batch_query_response_vision_dataset(
        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
@@ -75,14 +95,66 @@ def run_train(config: TrainConfig):
    train_ds = train_ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))

    print(train_ds)
    print("---------------")

-    dataloader = DataLoader(train_ds, batch_size=1, shuffle=False)
+    train_dataloader = DataLoader(train_ds, batch_size=1, num_workers=2, shuffle=False)

    for batch in dataloader:
        print(batch)
        result = model.forward(**batch)

    with TemporaryDirectory() as output_dir:
        training_args = TrainingArguments(
            run_name=run_name.run,
            logging_steps=config.hparams.log_every_steps,
            output_dir=output_dir,
            eval_strategy="steps",
            report_to="wandb",
            # report_to=[], # disable logging to wandb, we will use a custom callback
            optim=config.hparams.optim,
            eval_steps=config.hparams.eval_every_steps,
            learning_rate=config.hparams.learning_rate,
            per_device_train_batch_size=config.hparams.batch_size,
            per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
            gradient_checkpointing=config.hparams.gradient_checkpointing,
            gradient_checkpointing_kwargs=(
                dict(use_reentrant=False)  # from this issue: https://github.com/huggingface/peft/issues/1142
                if config.hparams.gradient_checkpointing and config.lora is not None
                else {}
            ),
            gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
            max_steps=config.hparams.max_steps,
            weight_decay=config.hparams.weight_decay,
            dataloader_num_workers=config.max_workers,
            load_best_model_at_end=True,
            save_strategy="steps",
            ddp_find_unused_parameters=config.hparams.find_unused_parameters,
            save_steps=config.save.save_every_steps,
            warmup_steps=config.hparams.warmup_steps,
            warmup_ratio=config.hparams.warmup_ratio,
            bf16=accelerator.mixed_precision == "bf16",
            fp16=accelerator.mixed_precision == "fp16",
            label_names=["labels"],  # fix from https://github.com/huggingface/transformers/issues/22885
            max_grad_norm=config.hparams.clip_grad_norm,
            remove_unused_columns=False,
        )

        # Set the collator
        collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)

        #checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)

        # Initialize Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_ds,
            #eval_dataset=formatted_dataset["validation"],  # pyright: ignore
            tokenizer=processor.tokenizer,
            #data_collator=collator,
            #callbacks=[checkpoint_callback],
        )

        # Train the model
        trainer.train()  # pyright: ignore


def main():
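packing_collator and its pad_multiple_of / do_shrink arguments come from elsewhere in the package and are not shown in this diff. As a rough illustration of what such a collator typically does (pad every example in a batch to a shared length rounded up to a multiple of pad_multiple_of, with -100 on padded label positions so they are ignored by the loss), here is a hedged sketch; the key names and the do_shrink flag are assumptions:

from typing import Any

import torch


def packing_collator(
    examples: list[dict[str, Any]],
    pad_multiple_of: int = 16,
    do_shrink: bool = False,  # placeholder for the flag seen above; unused here
) -> dict[str, torch.Tensor]:
    max_len = max(len(e["input_ids"]) for e in examples)
    # Round the padded length up to the nearest multiple of pad_multiple_of.
    padded_len = ((max_len + pad_multiple_of - 1) // pad_multiple_of) * pad_multiple_of

    def pad(seq: list[int], value: int) -> list[int]:
        return seq + [value] * (padded_len - len(seq))

    return {
        "input_ids": torch.tensor([pad(e["input_ids"], 0) for e in examples]),
        "attention_mask": torch.tensor([pad(e["attention_mask"], 0) for e in examples]),
        # -100 so padded positions are ignored by the cross-entropy loss.
        "labels": torch.tensor([pad(e["labels"], -100) for e in examples]),
    }

Once the commented-out data_collator=collator line is enabled, the Trainer would call a collator like this on every batch.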