From cbc667ce7892dad87b78898d79b2197f9f1c1b8f Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Wed, 16 Oct 2024 13:18:24 -0700
Subject: [PATCH] Prepping to train

---
 pdelfin/train/config/qwen2vl-7b-lora.yaml |  2 +-
 pdelfin/train/loaddataset.py              | 22 +++++++++++++--
 pdelfin/train/train.py                    | 21 +++++++-------
 pdelfin/train/utils.py                    | 34 +++++++++++++++++------
 pyproject.toml                            |  1 +
 5 files changed, 56 insertions(+), 24 deletions(-)

diff --git a/pdelfin/train/config/qwen2vl-7b-lora.yaml b/pdelfin/train/config/qwen2vl-7b-lora.yaml
index 625ce2f..82fb0fd 100644
--- a/pdelfin/train/config/qwen2vl-7b-lora.yaml
+++ b/pdelfin/train/config/qwen2vl-7b-lora.yaml
@@ -46,7 +46,7 @@ hparams:
   batch_size: 1
   eval_batch_size: 1
   gradient_accumulation_steps: 4
-  gradient_checkpointing: false
+  gradient_checkpointing: true
   clip_grad_norm: 1.0
   learning_rate: 1e-4
   max_steps: 10000
diff --git a/pdelfin/train/loaddataset.py b/pdelfin/train/loaddataset.py
index bbcd3f0..c3a6917 100644
--- a/pdelfin/train/loaddataset.py
+++ b/pdelfin/train/loaddataset.py
@@ -1,15 +1,16 @@
 from transformers import (
-    AutoProcessor
+    AutoProcessor,
+    DataCollatorForSeq2Seq
 )

 from pdelfin.train.core.cli import make_cli
 from pdelfin.train.core.config import TrainConfig
-
+from tqdm import tqdm

 from .utils import (
     make_dataset
 )
-
+from torch.utils.data import DataLoader

 def main():
     train_config = make_cli(TrainConfig)  # pyright: ignore
@@ -19,14 +20,29 @@ def main():

     print("Training dataset........")
     print(train_dataset)
+    print(train_dataset[0])
     print("\n\n")

     print("Validation dataset........")
     print(valid_dataset)
+    print(valid_dataset[list(valid_dataset.keys())[0]][0])
     print("\n\n")

     print("Datasets loaded into hugging face cache directory")

+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer=processor.tokenizer,  # use the processor's tokenizer
+        max_length=4096,
+        padding=False,
+    )
+
+    train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=data_collator)
+    max_seen_len = 0
+    for entry in tqdm(train_dataloader):
+        num_input_tokens = entry["input_ids"].shape[1]
+        max_seen_len = max(max_seen_len, num_input_tokens)
+
+    print(max_seen_len)

 if __name__ == "__main__":
     main()
diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
index 6285e1f..4a0bebb 100644
--- a/pdelfin/train/train.py
+++ b/pdelfin/train/train.py
@@ -26,7 +26,8 @@ from transformers import (
     TrainerCallback,
     TrainingArguments,
     Qwen2VLForConditionalGeneration,
-    AutoProcessor
+    AutoProcessor,
+    DataCollatorForSeq2Seq
 )
 from transformers.integrations import WandbCallback
 from transformers.trainer_callback import TrainerControl, TrainerState
@@ -168,12 +169,16 @@ def run_train(config: TrainConfig):
         label_names=["labels"],  # fix from https://github.com/huggingface/transformers/issues/22885
         max_grad_norm=config.hparams.clip_grad_norm,
         remove_unused_columns=False,
-        eval_on_start=True,
+        #eval_on_start=True,
         metric_for_best_model=config.valid_data.metric_for_best_model,
     )

-    # Set the collator
-    collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer=processor.tokenizer,  # use the processor's tokenizer
+        max_length=config.generate.max_length,
+        padding=False,
+    )
+
     checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)

     # Initialize Trainer
@@ -183,16 +188,10 @@ def run_train(config: TrainConfig):
         train_dataset=train_dataset,
         eval_dataset=valid_dataset,
         tokenizer=processor.tokenizer,
-        #Collator is not needed as we are doing batch size 1 for now...
-        #data_collator=collator,
+        data_collator=data_collator,
         callbacks=[checkpoint_callback],
     )

-    # Could not get this to work
-    # if get_rank() == 0:
-    #     # this is a hack to add script and peft config to wandb config
-    #     update_wandb_config(config, trainer, model)
-
     # Train the model
     trainer.train()  # pyright: ignore

diff --git a/pdelfin/train/utils.py b/pdelfin/train/utils.py
index c2bb93e..1ea2097 100644
--- a/pdelfin/train/utils.py
+++ b/pdelfin/train/utils.py
@@ -48,21 +48,37 @@ def get_rawdataset_from_source(data_config: DataConfig, source: SourceConfig) ->
 def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
     random.seed(config.train_data.seed)

-    # Training sets get all concatenated and shuffled
+    # Retrieve the two target lengths from the first source for comparison
+    first_source = config.train_data.sources[0]
+    target_longest_image_dim = first_source.target_longest_image_dim
+    target_anchor_text_len = first_source.target_anchor_text_len
+
+    # Verify that all sources have the same target lengths
+    for source in config.train_data.sources:
+        if source.target_longest_image_dim != target_longest_image_dim:
+            raise ValueError(f"Inconsistent target_longest_image_dim found in source {source}")
+        if source.target_anchor_text_len != target_anchor_text_len:
+            raise ValueError(f"Inconsistent target_anchor_text_len found in source {source}")
+
+
+    # Concatenate datasets first, unfortunately you can't apply the transform before concatenation due to the library
     train_dataset = concatenate_datasets(
         [
-            get_rawdataset_from_source(config.train_data, source).with_transform(
-                partial(
-                    batch_prepare_data_for_qwen2_training,
-                    processor=processor,
-                    target_longest_image_dim=source.target_longest_image_dim,
-                    target_anchor_text_len=source.target_anchor_text_len,
-                )
-            )
+            get_rawdataset_from_source(config.train_data, source)
             for source in config.train_data.sources
         ]
     )

+    # Apply the transform to the concatenated dataset
+    train_dataset = train_dataset.with_transform(
+        partial(
+            batch_prepare_data_for_qwen2_training,
+            processor=processor,
+            target_longest_image_dim=target_longest_image_dim,
+            target_anchor_text_len=target_anchor_text_len,
+        )
+    )
+
     # Validation sets get put into a datasetdict so each can report a loss separately
     valid_dataset = DatasetDict(
         **{
diff --git a/pyproject.toml b/pyproject.toml
index dd3ea9c..96e35af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
     "pypdf",
     "pymupdf",
     "pypdfium2",
+    "cryptography",
     "lingua-language-detector",
     "Pillow",
     "ftfy",
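
Note on the collator change above: the standalone sketch below (not part of the patch) illustrates how DataCollatorForSeq2Seq pads variable-length input_ids and labels into a single batch. The gpt2 tokenizer and the toy feature dicts are assumptions chosen only for illustration; the patch itself passes padding=False with batch_size 1, so no padding actually happens during training, whereas this sketch uses padding=True to make the padding behavior visible.

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

# Any tokenizer with a pad token works for this sketch; gpt2 is just an example.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    max_length=4096,          # same cap as in loaddataset.py above
    padding=True,             # pad each batch to its longest example
    label_pad_token_id=-100,  # padded label positions are ignored by the loss
)

# Two variable-length examples: labels are padded with -100, inputs with the pad token.
features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [4, 5, 6]},
    {"input_ids": [1, 2],    "attention_mask": [1, 1],    "labels": [7, 8]},
]
batch = collator(features)
print(batch["input_ids"].shape, batch["labels"].shape)  # torch.Size([2, 3]) torch.Size([2, 3])

Because -100 is the default ignore index of the cross-entropy loss, the padded label positions contribute nothing to training, which is what makes this collator a reasonable fit for the Trainer wiring in train.py.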