Prepping to train

Jake Poznanski 2024-10-16 13:18:24 -07:00
parent 9d647b13b8
commit cbc667ce78
5 changed files with 56 additions and 24 deletions

View File

@@ -46,7 +46,7 @@ hparams:
   batch_size: 1
   eval_batch_size: 1
   gradient_accumulation_steps: 4
-  gradient_checkpointing: false
+  gradient_checkpointing: true
   clip_grad_norm: 1.0
   learning_rate: 1e-4
   max_steps: 10000
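
The only hyperparameter change is enabling gradient checkpointing, which recomputes activations during the backward pass instead of caching them, trading extra compute for a large reduction in activation memory. A minimal sketch of the equivalent setting, assuming these hparams are forwarded to transformers.TrainingArguments (output_dir is a placeholder, not from the config):

# Hedged sketch: assumes the YAML hparams above map onto TrainingArguments.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",              # placeholder path, not from the config
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,   # recompute activations in backward to save memory
    max_grad_norm=1.0,             # clip_grad_norm above
    learning_rate=1e-4,
    max_steps=10000,
)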

View File

@@ -1,15 +1,16 @@
 from transformers import (
-    AutoProcessor
+    AutoProcessor,
+    DataCollatorForSeq2Seq
 )
 from pdelfin.train.core.cli import make_cli
 from pdelfin.train.core.config import TrainConfig
 from tqdm import tqdm
 from .utils import (
     make_dataset
 )
 from torch.utils.data import DataLoader

 def main():
     train_config = make_cli(TrainConfig)  # pyright: ignore
@@ -19,14 +20,29 @@ def main():
     print("Training dataset........")
     print(train_dataset)
     print(train_dataset[0])
+    print("\n\n")
     print("Validation dataset........")
     print(valid_dataset)
     print(valid_dataset[list(valid_dataset.keys())[0]][0])
+    print("\n\n")
     print("Datasets loaded into hugging face cache directory")
+
+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer=processor.tokenizer,  # use the processor's tokenizer
+        max_length=4096,
+        padding=False,
+    )
+
+    train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=data_collator)
+
+    max_seen_len = 0
+    for entry in tqdm(train_dataloader):
+        num_input_tokens = entry["input_ids"].shape[1]
+        max_seen_len = max(max_seen_len, num_input_tokens)
+
+    print(max_seen_len)

 if __name__ == "__main__":
     main()
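
Because the DataLoader above uses batch_size=1, padding=False is safe here: each batch contains a single sequence, so entry["input_ids"] has shape (1, seq_len) and shape[1] is the unpadded token count. For reference, a self-contained sketch of what DataCollatorForSeq2Seq does once a batch holds more than one example; the model name is an assumption, any tokenizer with a pad token works:

# Hedged sketch of DataCollatorForSeq2Seq's padding behavior; the model
# name is illustrative, not taken from this repo's config.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

features = [
    {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "labels": [4, 5]},
]
batch = collator(features)
print(batch["input_ids"].shape)  # (2, 3): the shorter row is padded with pad_token_id
print(batch["labels"][1])        # ends in -100, so padding is ignored by the loss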

View File

@@ -26,7 +26,8 @@ from transformers import (
     TrainerCallback,
     TrainingArguments,
     Qwen2VLForConditionalGeneration,
-    AutoProcessor
+    AutoProcessor,
+    DataCollatorForSeq2Seq
 )
 from transformers.integrations import WandbCallback
 from transformers.trainer_callback import TrainerControl, TrainerState
@@ -168,12 +169,16 @@ def run_train(config: TrainConfig):
         label_names=["labels"],  # fix from https://github.com/huggingface/transformers/issues/22885
         max_grad_norm=config.hparams.clip_grad_norm,
         remove_unused_columns=False,
-        eval_on_start=True,
+        #eval_on_start=True,
         metric_for_best_model=config.valid_data.metric_for_best_model,
     )

     # Set the collator
-    collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer=processor.tokenizer,  # use the processor's tokenizer
+        max_length=config.generate.max_length,
+        padding=False,
+    )

     checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)

     # Initialize Trainer
@@ -183,16 +188,10 @@ def run_train(config: TrainConfig):
         train_dataset=train_dataset,
         eval_dataset=valid_dataset,
         tokenizer=processor.tokenizer,
-        #Collator is not needed as we are doing batch size 1 for now...
-        #data_collator=collator,
+        data_collator=data_collator,
         callbacks=[checkpoint_callback],
     )

-    # Could not get this to work
-    # if get_rank() == 0:
-    #     # this is a hack to add script and peft config to wandb config
-    #     update_wandb_config(config, trainer, model)
-
     # Train the model
     trainer.train()  # pyright: ignore
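
This swaps the custom packing_collator (and its pad_multiple_of setting) for the stock DataCollatorForSeq2Seq with padding disabled, which is fine while the effective batch size is 1. If padded batches come back later, the stock collator can approximate the multiple-of behavior; a hedged sketch, not the repo's code:

# Sketch only: stock collator settings that roughly mirror a
# pad-to-multiple-of-N collator for padded batches.
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(
    tokenizer=processor.tokenizer,  # as in run_train above
    padding="longest",        # pad each batch to its longest sequence...
    pad_to_multiple_of=8,     # ...then round the length up to a multiple of 8
    label_pad_token_id=-100,  # padded label positions are ignored by the loss
)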

View File

@@ -48,21 +48,37 @@ def get_rawdataset_from_source(data_config: DataConfig, source: SourceConfig) ->
 def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
     random.seed(config.train_data.seed)

     # Training sets get all concatenated and shuffled
+    # Retrieve the two target lengths from the first source for comparison
+    first_source = config.train_data.sources[0]
+    target_longest_image_dim = first_source.target_longest_image_dim
+    target_anchor_text_len = first_source.target_anchor_text_len
+
+    # Verify that all sources have the same target lengths
+    for source in config.train_data.sources:
+        if source.target_longest_image_dim != target_longest_image_dim:
+            raise ValueError(f"Inconsistent target_longest_image_dim found in source {source}")
+        if source.target_anchor_text_len != target_anchor_text_len:
+            raise ValueError(f"Inconsistent target_anchor_text_len found in source {source}")
+
+    # Concatenate datasets first, unfortunately you can't apply the transform before concatenation due to the library
     train_dataset = concatenate_datasets(
         [
-            get_rawdataset_from_source(config.train_data, source).with_transform(
-                partial(
-                    batch_prepare_data_for_qwen2_training,
-                    processor=processor,
-                    target_longest_image_dim=source.target_longest_image_dim,
-                    target_anchor_text_len=source.target_anchor_text_len,
-                )
-            )
+            get_rawdataset_from_source(config.train_data, source)
             for source in config.train_data.sources
         ]
     )

+    # Apply the transform to the concatenated dataset
+    train_dataset = train_dataset.with_transform(
+        partial(
+            batch_prepare_data_for_qwen2_training,
+            processor=processor,
+            target_longest_image_dim=target_longest_image_dim,
+            target_anchor_text_len=target_anchor_text_len,
+        )
+    )

     # Validation sets get put into a datasetdict so each can report a loss separately
     valid_dataset = DatasetDict(
         **{
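
The refactor follows from a datasets library constraint: with_transform attaches a lazy format transform, and per-source transforms do not survive concatenate_datasets, hence the up-front check that every source shares the same target lengths so a single shared transform can be attached after the merge. A toy, self-contained sketch of the concatenate-then-transform pattern (add_one stands in for batch_prepare_data_for_qwen2_training):

# Hedged, self-contained sketch of the concatenate-then-transform pattern.
from datasets import Dataset, concatenate_datasets

def add_one(batch):
    # format transforms receive a batch dict and run lazily, on access
    return {"x": [v + 1 for v in batch["x"]]}

a = Dataset.from_dict({"x": [1, 2]})
b = Dataset.from_dict({"x": [3, 4]})

merged = concatenate_datasets([a, b]).with_transform(add_one)
print(merged[0])  # {'x': 2}: the transform ran at access time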

View File

@@ -24,6 +24,7 @@ dependencies = [
     "pypdf",
     "pymupdf",
     "pypdfium2",
+    "cryptography",
    "lingua-language-detector",
     "Pillow",
     "ftfy",