mirror of https://github.com/allenai/olmocr.git
synced 2025-11-06 04:52:56 +00:00

Getting ready to launch a new training run

This commit is contained in:
parent 1686790ac8
commit 0ddaf9023d
@@ -28,21 +28,20 @@ generate:
 train_data:
   seed: 1337
   sources:
-    - name: openai_batch_data_v2
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
-      backend:
-        - openai
-      size: 100_000
+    - name: openai_batch_data_v5_1_eval # TODO This is just for testing the job, once ready change to a real train dataset
+      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl

 valid_data:
   sources:
-    - name: openai_batch_data_eval_mini
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
-      backend:
-        - openai
-      size: 100_000
+    - name: openai_batch_data_v5_1_eval
+      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl
+    - name: openai_batch_data_v5_1_iabooks_eval
+      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_eval/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.jsonl

 # Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
 hparams:
@@ -52,10 +51,10 @@ hparams:
   gradient_checkpointing: false
   clip_grad_norm: 1.0
   learning_rate: 3e-4
-  max_steps: 5000
+  max_steps: 2000
   pad_multiple_of: 16
   log_every_steps: 50
-  eval_every_steps: 500
+  eval_every_steps: 100
   optim: adamw_torch
   lr_scheduler: cosine
   weight_decay: 0.01
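For orientation, these hyperparameters correspond closely to standard Hugging Face TrainingArguments fields. The sketch below is only illustrative: the argument names on the right are real transformers ones, but the mapping itself is an assumption, not something shown in this diff (pad_multiple_of, for example, belongs to padding/collation rather than TrainingArguments).

# Illustrative sketch only; the repo wires hparams into its own config code elsewhere.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",              # placeholder; the real run saves under config.save.path
    max_steps=2000,                # hparams.max_steps (was 5000)
    learning_rate=3e-4,            # hparams.learning_rate
    lr_scheduler_type="cosine",    # hparams.lr_scheduler
    optim="adamw_torch",           # hparams.optim
    weight_decay=0.01,             # hparams.weight_decay
    max_grad_norm=1.0,             # hparams.clip_grad_norm
    gradient_checkpointing=False,  # hparams.gradient_checkpointing
    logging_steps=50,              # hparams.log_every_steps
    evaluation_strategy="steps",   # newer transformers versions call this eval_strategy
    eval_steps=100,                # hparams.eval_every_steps (was 500)
)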
@@ -75,10 +75,8 @@ class AwsConfig:
 @dataclass
 class SourceConfig:
     name: str = field(help="The name of the source")
-    size: int = field(help="Limit size for the source")
     query_glob_path: str = field(help="The s3 bucket pointing to the inputs sent to OpenAI to generate the silver data")
     response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")
-    backend: List[str] = field(help="The data generation backend to use to train the model")


 @dataclass
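As a quick sanity check that the schema and the YAML stay in sync after dropping backend and size, here is a stripped-down sketch: plain dataclass fields stand in for the repo's field(help=...) helper, and the values are copied from one of the valid_data entries in the config hunk above.

from dataclasses import dataclass

@dataclass
class SourceConfig:
    name: str
    query_glob_path: str
    response_glob_path: str

# One valid_data entry from the YAML above, expressed as a SourceConfig:
source = SourceConfig(
    name="openai_batch_data_v5_1_eval",
    query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
    response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl",
)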
@@ -15,7 +15,7 @@ from tqdm import tqdm
 import accelerate
 import torch
 import torch.distributed
-from datasets import DatasetDict
+from datasets import DatasetDict, concatenate_datasets
 from datasets.utils import disable_progress_bars
 from datasets.utils.logging import set_verbosity
 from peft import LoraConfig, get_peft_model # pyright: ignore
@@ -49,7 +49,7 @@ from .utils import (
 )


-from pdelfin.train.dataloader import make_dataset
+from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len


@@ -113,13 +113,6 @@ def run_train(config: TrainConfig):

     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)

-    dataset = make_dataset(
-        train_data_config=config.train_data,
-        valid_data_config=config.valid_data,
-        num_proc=config.num_proc,
-        logger=logger,
-    )
-
     model = Qwen2VLForConditionalGeneration.from_pretrained(
         config.model.name_or_path, torch_dtype=torch.bfloat16,
         _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
@@ -139,10 +132,28 @@ def run_train(config: TrainConfig):
     log_trainable_parameters(model=model, logger=logger)

     # Do final filtering, and prep for running model forward()
-    filtered_dataset = DatasetDict(**{split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset})
-    formatted_dataset = filtered_dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-    print(formatted_dataset)
-    print("---------------")
+    # Training sets get all concatenated and shuffled
+    train_dataset = (
+        concatenate_datasets(
+            [
+                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+                for source in config.train_data.sources
+            ]
+        )
+        .filter(partial(filter_by_max_seq_len, processor=processor))
+        .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    )
+
+    # Validation sets get put into a datasetdict so each can report a loss separately
+    valid_dataset = DatasetDict(
+        **{
+            source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+            .filter(partial(filter_by_max_seq_len, processor=processor))
+            .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+            for source in config.valid_data.sources
+        }
+    )

     save_path = join_path("", config.save.path, run_name.run)

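A note on the filter/with_transform pattern used for both datasets above: in Hugging Face datasets, .filter(...) runs eagerly and materializes the surviving rows, while .with_transform(...) only registers a function that is applied lazily whenever rows are accessed, so the vision preprocessing happens per batch at training time rather than up front. A minimal sketch of the same pattern, using the helper names and S3 globs from this diff (the processor checkpoint is assumed for illustration):

from functools import partial

from transformers import AutoProcessor
from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # checkpoint name assumed

ds = build_batch_query_response_vision_dataset(
    "s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
    "s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl",
)
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor))  # eager: drops over-length rows once
ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))  # lazy: runs on access
example = ds[0]  # the transform executes here, not at construction time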
@@ -192,8 +203,8 @@ def run_train(config: TrainConfig):
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=formatted_dataset["train"],
-        eval_dataset=formatted_dataset["validation"], # pyright: ignore
+        train_dataset=train_dataset,
+        eval_dataset=valid_dataset,
         tokenizer=processor.tokenizer,
         #Collator is not needed as we are doing batch size 1 for now...
         #data_collator=collator,
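Passing a DatasetDict (a dict of datasets) as eval_dataset is what makes the per-source validation losses work: recent transformers versions accept a mapping and evaluate each entry separately, prefixing its metrics with the key. A small sketch of the expected behavior, with dataset names taken from the config above (exact metric names depend on the transformers version):

# With eval_dataset given as a mapping, Trainer.evaluate() loops over the entries
# and logs each one under an "eval_<name>_" prefix, e.g.
#   eval_openai_batch_data_v5_1_eval_loss
#   eval_openai_batch_data_v5_1_iabooks_eval_loss
metrics = trainer.evaluate()
print(metrics)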
@@ -218,7 +229,6 @@ def run_train(config: TrainConfig):

     logger.info("Saved best model to %s", best_dir)

-
     # Uncomment to test speed of data loader
     # train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
     # for entry in tqdm(train_dataloader):
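One small follow-up: the commented-out speed test above still refers to formatted_dataset["train"], which no longer exists after this commit. If re-enabled, it would presumably need the new train_dataset built earlier in run_train; a sketch under that assumption:

from torch.utils.data import DataLoader
from tqdm import tqdm

# Rough data-loader throughput check (mirrors the commented-out lines, with the updated name assumed):
train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False)
for entry in tqdm(train_dataloader):
    pass  # iterate only to measure loading/preprocessing speed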
@@ -19,7 +19,7 @@ run_name=$(basename "$0" .sh)
 # --cluster 'ai2/allennlp-cirrascale' \
 # --priority high \

-CLUSTER='jupiter'
+CLUSTER='pluto'

 gantry run \
   --description "${run_name}"\