diff --git a/pdelfin/train/dataprep.py b/pdelfin/train/dataprep.py
new file mode 100644
index 0000000..250afb4
--- /dev/null
+++ b/pdelfin/train/dataprep.py
@@ -0,0 +1,83 @@
+import numpy as np
+import torch
+from io import BytesIO
+from PIL import Image
+import base64
+
+
+def prepare_data_for_qwen2_training(example, processor):
+    # Prepare messages
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": example["input_prompt_image_base64"]  # Placeholder
+                },
+                {"type": "text", "text": example["input_prompt_text"]},
+            ],
+        }
+    ]
+    # Apply chat template to get the text
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    # Decode image from base64
+    main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
+
+    # Process inputs using processor
+    inputs = processor(
+        text=[text],
+        images=[main_image],
+        padding=True,
+        return_tensors="np",
+    )
+
+    # Get labels by tokenizing the output text
+    labels = processor(
+        text=[example["response"]],
+        padding=True,
+        return_tensors="np"
+    )
+
+    # Concatenate input_ids and labels
+    input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)
+    attention_mask = np.concatenate([inputs.attention_mask[0], labels.attention_mask[0]], axis=0)
+
+    # Create labels, masking the input portion with -100
+    labels_full = np.full_like(input_ids, fill_value=-100)
+    labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]
+
+    # Return as dict, including pixel_values
+    return {
+        "input_ids": input_ids.tolist(),
+        "attention_mask": attention_mask.tolist(),
+        "labels": labels_full.tolist(),
+        "pixel_values": inputs.pixel_values[0]
+    }
+
+
+# Define a custom data collator
+class DataCollatorForVisionLanguageModeling:
+    def __init__(self, processor):
+        self.processor = processor
+
+    def __call__(self, features):
+        input_ids = [f['input_ids'] for f in features]
+        attention_mask = [f['attention_mask'] for f in features]
+        labels = [f['labels'] for f in features]
+        pixel_values = [f['pixel_values'] for f in features]
+
+        # Pad input_ids, attention_mask, labels
+        batch = self.processor.pad(
+            {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels},
+            return_tensors="pt",
+            padding=True,
+        )
+
+        # Stack pixel_values
+        batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])
+
+        return batch
\ No newline at end of file
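Note (not part of the patch): a minimal sketch of the label-masking scheme that prepare_data_for_qwen2_training builds, using made-up token ids. Prompt positions are filled with -100 so the cross-entropy loss only counts the response tokens.

    import numpy as np

    # Hypothetical token ids standing in for the processor's output.
    prompt_ids = np.array([151644, 872, 198, 77091])    # "prompt" portion
    response_ids = np.array([3838, 374, 151645])        # "response" portion

    input_ids = np.concatenate([prompt_ids, response_ids], axis=0)
    labels = np.full_like(input_ids, fill_value=-100)   # -100 is ignored by the loss
    labels[len(prompt_ids):] = response_ids

    print(input_ids.tolist())  # [151644, 872, 198, 77091, 3838, 374, 151645]
    print(labels.tolist())     # [-100, -100, -100, -100, 3838, 374, 151645]
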
diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
index ff5b9b0..a2d6417 100644
--- a/pdelfin/train/train.py
+++ b/pdelfin/train/train.py
@@ -10,6 +10,7 @@
 # Step 5. Move over from interactive session to gantry launch script
 
 import os
+import json
 import base64
 import logging
 from io import BytesIO
@@ -19,6 +20,7 @@ from logging import Logger
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional
+from tqdm import tqdm
 
 import accelerate
 import torch
@@ -56,58 +58,25 @@ from .utils import (
 
 from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
+from pdelfin.train.dataprep import prepare_data_for_qwen2_training
 
 
 def run_train(config: TrainConfig):
     train_ds = build_batch_query_response_vision_dataset(
-        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
-        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
+        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
+        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
     )
 
     model = Qwen2VLForConditionalGeneration.from_pretrained(
-        "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
+        "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
     )
     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
 
-    for entry in train_ds:
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image",
-                        "image": entry["input_prompt_image_base64"]
-                    },
-                    {"type": "text", "text": entry["input_prompt_text"]},
-                ],
-            }
-        ]
+    train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor),
+                            remove_columns=train_ds.column_names)
 
-        # Preparation for inference
-        text = processor.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-
-        main_image = Image.open(BytesIO(base64.b64decode(entry["input_prompt_image_base64"])))
+    print(train_ds)
 
-        inputs = processor(
-            text=[text],
-            images=[main_image],
-            #videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
-        )
-        #inputs = inputs.to("cuda")
-
-        # Inference: Generation of the output
-        generated_ids = model.generate(**inputs, max_new_tokens=128)
-        generated_ids_trimmed = [
-            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
-        output_text = processor.batch_decode(
-            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-        print(output_text)
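For reference, one way the mapped dataset and the collator from dataprep.py could be handed to a Hugging Face Trainer; this is a sketch under assumptions (output_dir, batch size, and epoch count are placeholders, and model/processor/train_ds are the objects built in run_train), not something this diff sets up.

    from transformers import Trainer, TrainingArguments

    from pdelfin.train.dataprep import DataCollatorForVisionLanguageModeling

    training_args = TrainingArguments(
        output_dir="qwen2-vl-sft",        # placeholder
        per_device_train_batch_size=1,
        num_train_epochs=1,
        bf16=True,
        remove_unused_columns=False,      # keep pixel_values around for the collator
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        data_collator=DataCollatorForVisionLanguageModeling(processor),
    )
    trainer.train()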