Hoping to get a basic hf Trainer to run

parent 55035b02c9
commit 256d77c232
.gitignore (vendored, 3 changed lines)

@@ -1,3 +1,6 @@
# ml stuff
wandb/

# build artifacts

.eggs/
pdelfin/train/config/qwen2vl-2b.yaml (new file, 124 lines)

@@ -0,0 +1,124 @@
model:
  name_or_path: Qwen/Qwen2-VL-2B-Instruct
  arch: causal

wandb:
  project: refine
  entity: pdf-qwen2vl

# TODO This is not used
format:
  instruction_template: "Original:"
  response_template: "Rewritten:"
  # Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30
  chat_template: |
    {% for message in messages %}
    {{'<|im_start|>' + message['role'] + '\n' + message['content']}}
    {% if loop.last %}
    {{ '<|im_end|>'}}
    {% else %}
    {{ '<|im_end|>\n' }}
    {% endif %}
    {% endfor %}

generate:
  max_length: 4096

train_data:
  seed: 1337
  sources:
    - name: fw-edu-all
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/fw-edu-all/*.json.gz
      backend:
        - openai
      size: 100_000
    - name: dclm
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dclm/*.zstd
      backend:
        - openai
      size: 100_000
    - name: dolma-v17
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v17/*.zstd
      backend:
        - openai
      size: 100_000
    - name: dolma-v1-small
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v1-small/*.zstd
      backend:
        - openai
      size: 100_000

valid_data:
  sources:
    - name: fw-edu-10k
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/fw-edu-10k/valid/*.gz
      backend:
        - openai
      size: 1500
    - name: dolma-10k
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-10k/valid/*.gz
      backend:
        - openai
      size: 1500
    - name: dclm
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dclm/*.zstd
      backend:
        - openai
      size: 1500
    - name: dolma-v17
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v17/*.zstd
      backend:
        - openai
      size: 1500
    - name: dolma-v1-small
      paths:
        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v1-small/*.zstd
      backend:
        - openai
      size: 3000

# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
  batch_size: 2
  eval_batch_size: 2
  gradient_accumulation_steps: 4
  gradient_checkpointing: true
  clip_grad_norm: 1.0
  learning_rate: 3e-4
  max_steps: 10000
  pad_multiple_of: 16
  log_every_steps: 5
  eval_every_steps: 250
  optim: adamw_torch
  lr_scheduler: cosine
  weight_decay: 0.01
  warmup_ratio: 0.03

# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
lora:
  rank: 32
  alpha: 32
  dropout: 0.05
  task_type: causal_lm
  target_modules:
    - q_proj
    - k_proj
    - v_proj
    - o_proj
    - gate_proj
    - up_proj
    - down_proj

save:
  path: s3://ai2-tylerm-experimental/experiments/rephrase/v1/models/lucas
  save_every_steps: 500

max_workers: 1
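Editor's note on the format.chat_template above (marked "not used" for now): it is a plain ChatML-style Jinja template. A minimal sketch, not part of this commit, of rendering it for a toy conversation; it assumes only pyyaml and jinja2, and mirrors the trim_blocks/lstrip_blocks settings that transformers enables when compiling chat templates:

# Sketch: render the chat_template from the YAML above for a toy message list.
import yaml
from jinja2 import Environment

with open("pdelfin/train/config/qwen2vl-2b.yaml") as f:
    cfg = yaml.safe_load(f)

env = Environment(trim_blocks=True, lstrip_blocks=True)
template = env.from_string(cfg["format"]["chat_template"])

messages = [
    {"role": "user", "content": "Original: sum text wit typos"},
    {"role": "assistant", "content": "Rewritten: some text with typos"},
]

# Prints ChatML-style turns: <|im_start|>{role}\n{content}<|im_end|>
print(template.render(messages=messages))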
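Similarly, the lora: block maps directly onto a peft LoraConfig. A hedged sketch (not the repo's own builder; the keyword names are peft's, the values are copied from the YAML above):

# Sketch: build a peft LoraConfig from the `lora` section of the config.
from peft import LoraConfig, TaskType, get_peft_model

lora_cfg = LoraConfig(
    r=32,                          # lora.rank
    lora_alpha=32,                 # lora.alpha
    lora_dropout=0.05,             # lora.dropout
    task_type=TaskType.CAUSAL_LM,  # lora.task_type: causal_lm
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
# model = get_peft_model(model, lora_cfg)  # wrap the base model before training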
@@ -11,7 +11,7 @@ class ModelConfig:
    """Configuration for loading a model; includes model name and type."""

    name_or_path: str = field(
        help="The model name or path to load; must be compatible with huggingface transformers.",
        help="The model name or path to load; must be compatible with huggingface transformers."
    )
    arch: str = field(help="The model type to load; can be 'vllm', 'causal', or 'vllm'")
    dtype: str = field(help="The precision to use for the model", default="bfloat16")
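For context, a minimal sketch of what a ModelConfig like the one in the YAML above (name_or_path: Qwen/Qwen2-VL-2B-Instruct, arch: causal, dtype: bfloat16) typically resolves to. This is an illustration using the standard transformers API, not the repo's actual loader:

# Sketch: load the model and processor named by ModelConfig at the requested dtype.
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

name_or_path = "Qwen/Qwen2-VL-2B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(
    name_or_path,
    torch_dtype=torch.bfloat16,  # ModelConfig.dtype
)
processor = AutoProcessor.from_pretrained(name_or_path)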
@@ -13,6 +13,7 @@ import os
import json
import base64
import logging
import time
from io import BytesIO
from PIL import Image
from functools import partial
@@ -62,7 +63,26 @@ from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training


def get_rank() -> int:
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0


def run_train(config: TrainConfig):
    if get_rank() == 0:
        logger_level = logging.INFO
    else:
        logger_level = logging.WARN
        disable_progress_bars()

    logger = get_logger(__name__, level=logger_level)
    set_verbosity(logger_level)

    run_name = RunName.get(config)

    accelerator = accelerate.Accelerator()

    train_ds = build_batch_query_response_vision_dataset(
        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
@@ -75,14 +95,66 @@ def run_train(config: TrainConfig):

    train_ds = train_ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
    print(train_ds)
    print("---------------")

    dataloader = DataLoader(train_ds, batch_size=1, shuffle=False)
    train_dataloader = DataLoader(train_ds, batch_size=1, num_workers=2, shuffle=False)

    for batch in dataloader:
        print(batch)

        result = model.forward(**batch)
    with TemporaryDirectory() as output_dir:

        training_args = TrainingArguments(
            run_name=run_name.run,
            logging_steps=config.hparams.log_every_steps,
            output_dir=output_dir,
            eval_strategy="steps",
            report_to="wandb",
            # report_to=[], # disable logging to wandb, we will use a custom callback
            optim=config.hparams.optim,
            eval_steps=config.hparams.eval_every_steps,
            learning_rate=config.hparams.learning_rate,
            per_device_train_batch_size=config.hparams.batch_size,
            per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
            gradient_checkpointing=config.hparams.gradient_checkpointing,
            gradient_checkpointing_kwargs=(
                dict(use_reentrant=False)  # from this issue: https://github.com/huggingface/peft/issues/1142
                if config.hparams.gradient_checkpointing and config.lora is not None
                else {}
            ),
            gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
            max_steps=config.hparams.max_steps,
            weight_decay=config.hparams.weight_decay,
            dataloader_num_workers=config.max_workers,
            load_best_model_at_end=True,
            save_strategy="steps",
            ddp_find_unused_parameters=config.hparams.find_unused_parameters,
            save_steps=config.save.save_every_steps,
            warmup_steps=config.hparams.warmup_steps,
            warmup_ratio=config.hparams.warmup_ratio,
            bf16=accelerator.mixed_precision == "bf16",
            fp16=accelerator.mixed_precision == "fp16",
            label_names=["labels"],  # fix from https://github.com/huggingface/transformers/issues/22885
            max_grad_norm=config.hparams.clip_grad_norm,
            remove_unused_columns=False,
        )

        # Set the collator
        collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
        #checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)

        # Initialize Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_ds,
            #eval_dataset=formatted_dataset["validation"], # pyright: ignore
            tokenizer=processor.tokenizer,
            #data_collator=collator,
            #callbacks=[checkpoint_callback],
        )

        # Train the model
        trainer.train()  # pyright: ignore


def main():
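Editor's note on the collator built above with pad_multiple_of=config.hparams.pad_multiple_of (still commented out in the Trainer call): the underlying idea is to pad every batch to a sequence length that is a multiple of 16. A rough sketch of that padding scheme, not the repo's packing_collator; pad_token_id and the field names are assumptions:

# Sketch: pad a batch so the sequence length is a multiple of `pad_multiple_of`,
# which keeps tensor shapes friendly for tensor cores.
import torch

def pad_to_multiple_collator(examples, pad_multiple_of=16, pad_token_id=0):
    max_len = max(len(ex["input_ids"]) for ex in examples)
    # Round max_len up to the nearest multiple of pad_multiple_of.
    max_len = ((max_len + pad_multiple_of - 1) // pad_multiple_of) * pad_multiple_of

    input_ids, attention_mask, labels = [], [], []
    for ex in examples:
        pad = max_len - len(ex["input_ids"])
        input_ids.append(list(ex["input_ids"]) + [pad_token_id] * pad)
        attention_mask.append([1] * len(ex["input_ids"]) + [0] * pad)
        labels.append(list(ex["labels"]) + [-100] * pad)  # -100 is ignored by the loss

    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "labels": torch.tensor(labels),
    }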