Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-03 03:25:22 +00:00

Commit 0ddaf9023d (parent 1686790ac8)

Getting ready to launch a new training run
@@ -28,21 +28,20 @@ generate:
 train_data:
   seed: 1337
   sources:
-    - name: openai_batch_data_v2
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
-      backend:
-        - openai
-      size: 100_000
+    - name: openai_batch_data_v5_1_eval # TODO This is just for testing the job, once ready change to a real train dataset
+      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl

 valid_data:
   sources:
-    - name: openai_batch_data_eval_mini
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
-      backend:
-        - openai
-      size: 100_000
+    - name: openai_batch_data_v5_1_eval
+      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl
+    - name: openai_batch_data_v5_1_iabooks_eval
+      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_eval/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.jsonl

 # Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
 hparams:
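Before launching the run, it can be worth sanity-checking that the new v5_1 glob paths actually match files. A minimal sketch, assuming fsspec with the s3fs package installed and AWS credentials configured (this helper is not part of the repo):

import fsspec  # s3:// support comes from s3fs

fs = fsspec.filesystem("s3")
for glob_path in [
    "s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
    "s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl",
]:
    # Zero matches here would mean the trainer has nothing to read.
    print(glob_path, "->", len(fs.glob(glob_path)), "files")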
@@ -52,10 +51,10 @@ hparams:
   gradient_checkpointing: false
   clip_grad_norm: 1.0
   learning_rate: 3e-4
-  max_steps: 5000
+  max_steps: 2000
   pad_multiple_of: 16
   log_every_steps: 50
-  eval_every_steps: 500
+  eval_every_steps: 100
   optim: adamw_torch
   lr_scheduler: cosine
   weight_decay: 0.01
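With lr_scheduler: cosine, the learning rate warms up and then decays along a cosine curve over max_steps, so halving max_steps to 2000 also compresses the decay. A minimal sketch of the resulting curve, assuming HF's get_cosine_schedule_with_warmup (the warmup step count below is illustrative, not taken from the config):

import torch
from transformers import get_cosine_schedule_with_warmup

# Dummy one-parameter optimizer, just to trace the schedule.
optimizer = torch.optim.AdamW(
    [torch.nn.Parameter(torch.zeros(1))], lr=3e-4, weight_decay=0.01
)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=2000
)
for step in range(2000):
    optimizer.step()
    scheduler.step()
    if step % 500 == 0:
        print(step, scheduler.get_last_lr()[0])  # 3e-4 decaying toward 0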
@@ -75,10 +75,8 @@ class AwsConfig:
 @dataclass
 class SourceConfig:
     name: str = field(help="The name of the source")
-    size: int = field(help="Limit size for the source")
     query_glob_path: str = field(help="The s3 bucket pointing to the inputs sent to OpenAI to generate the silver data")
     response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")
-    backend: List[str] = field(help="The data generation backend to use to train the model")


 @dataclass
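For reference, here is how a YAML source entry maps onto the trimmed-down SourceConfig. This is a simplified sketch using standard-library dataclasses and PyYAML, not the repo's own field(help=...) wrapper or config loader:

from dataclasses import dataclass

import yaml  # PyYAML

@dataclass
class SourceConfig:
    name: str
    query_glob_path: str
    response_glob_path: str

raw = yaml.safe_load("""
sources:
  - name: openai_batch_data_v5_1_eval
    query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
    response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl
""")
sources = [SourceConfig(**entry) for entry in raw["sources"]]
print(sources[0].name)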
@@ -15,7 +15,7 @@ from tqdm import tqdm
 import accelerate
 import torch
 import torch.distributed
-from datasets import DatasetDict
+from datasets import DatasetDict, concatenate_datasets
 from datasets.utils import disable_progress_bars
 from datasets.utils.logging import set_verbosity
 from peft import LoraConfig, get_peft_model  # pyright: ignore
@@ -49,7 +49,7 @@ from .utils import (
 )


-from pdelfin.train.dataloader import make_dataset
+from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len
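The new loader takes a query glob and a response glob directly. The underlying contract is presumably the OpenAI Batch API file format, where input and output .jsonl records share a "custom_id" key that lets each query be re-joined with its response. A hedged sketch of that join (function and variable names here are illustrative, not the repo's):

import json

def join_queries_and_responses(query_lines, response_lines):
    # Index responses by custom_id, then pair each query with its response.
    responses = {}
    for line in response_lines:
        rec = json.loads(line)
        responses[rec["custom_id"]] = rec
    for line in query_lines:
        query = json.loads(line)
        if query["custom_id"] in responses:
            yield query, responses[query["custom_id"]]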
@@ -113,13 +113,6 @@ def run_train(config: TrainConfig):
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)

-    dataset = make_dataset(
-        train_data_config=config.train_data,
-        valid_data_config=config.valid_data,
-        num_proc=config.num_proc,
-        logger=logger,
-    )
-
     model = Qwen2VLForConditionalGeneration.from_pretrained(
         config.model.name_or_path, torch_dtype=torch.bfloat16,
         _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
@@ -139,15 +132,33 @@ def run_train(config: TrainConfig):
     log_trainable_parameters(model=model, logger=logger)

-    # Do final filtering, and prep for running model forward()
-    filtered_dataset = DatasetDict(**{split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset})
-    formatted_dataset = filtered_dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-    print(formatted_dataset)
-    print("---------------")
+    # Training sets get all concatenated and shuffled
+    train_dataset = (
+        concatenate_datasets(
+            [
+                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+                for source in config.train_data.sources
+            ]
+        )
+        .filter(partial(filter_by_max_seq_len, processor=processor))
+        .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    )
+
+    # Validation sets get put into a datasetdict so each can report a loss separately
+    valid_dataset = DatasetDict(
+        **{
+            source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+            .filter(partial(filter_by_max_seq_len, processor=processor))
+            .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+            for source in config.valid_data.sources
+        }
+    )

     save_path = join_path("", config.save.path, run_name.run)
     save_config(config, join_path("", save_path, "config.yaml"))  # pyright: ignore

     with TemporaryDirectory() as output_dir:
         training_args = TrainingArguments(
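Building valid_dataset as a DatasetDict leans on the Trainer's support for dict-valued eval sets (available in recent transformers releases): when it is handed to the Trainer below, each key becomes its own metric prefix, so the two eval sources report losses separately. Roughly (key names follow the config's source names):

metrics = trainer.evaluate()
# metrics would then contain per-source entries such as:
#   eval_openai_batch_data_v5_1_eval_loss
#   eval_openai_batch_data_v5_1_iabooks_eval_loss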
@@ -192,8 +203,8 @@ def run_train(config: TrainConfig):
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=formatted_dataset["train"],
-        eval_dataset=formatted_dataset["validation"],  # pyright: ignore
+        train_dataset=train_dataset,
+        eval_dataset=valid_dataset,
         tokenizer=processor.tokenizer,
         #Collator is not needed as we are doing batch size 1 for now...
         #data_collator=collator,
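If the batch size ever moves past 1, the commented-out data_collator will be needed to pad variable-length examples. A hypothetical sketch (field names and pad values are assumptions about what the dataprep transform emits, not the repo's actual contract):

from torch.nn.utils.rnn import pad_sequence

def collate(examples, pad_token_id=0):
    # Assumes each example holds 1-D tensors under these keys.
    return {
        "input_ids": pad_sequence([e["input_ids"] for e in examples],
                                  batch_first=True, padding_value=pad_token_id),
        "attention_mask": pad_sequence([e["attention_mask"] for e in examples],
                                       batch_first=True, padding_value=0),
        "labels": pad_sequence([e["labels"] for e in examples],
                               batch_first=True, padding_value=-100),  # -100 is ignored by the loss
    }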
@@ -215,9 +226,8 @@ def run_train(config: TrainConfig):
             logger.info("LoRA adapters merged successfully.")

         model.save_pretrained(best_dir)
-
-        logger.info("Saved best model to %s", best_dir)
+        logger.info("Saved best model to %s", best_dir)

     # Uncomment to test speed of data loader
     # train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
@@ -232,4 +242,4 @@ def main():


 if __name__ == "__main__":
-    main()
+    main()
@@ -19,7 +19,7 @@ run_name=$(basename "$0" .sh)
 # --cluster 'ai2/allennlp-cirrascale' \
 # --priority high \

-CLUSTER='jupiter'
+CLUSTER='pluto'

 gantry run \
     --description "${run_name}"\