Prepping data to be in a trainable format

Jake Poznanski 2024-09-20 09:25:54 -07:00
parent dc86a99a97
commit fcb67ebd61
2 changed files with 91 additions and 40 deletions

pdelfin/train/dataprep.py Normal file (82 additions)

@@ -0,0 +1,82 @@
import base64
from io import BytesIO

import numpy as np
import torch
from PIL import Image


def prepare_data_for_qwen2_training(example, processor):
    # Prepare messages in the chat format the processor expects
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": example["input_prompt_image_base64"]  # Placeholder
                },
                {"type": "text", "text": example["input_prompt_text"]},
            ],
        }
    ]

    # Apply chat template to get the prompt text
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Decode image from base64
    main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))

    # Process prompt text and image using the processor
    inputs = processor(
        text=[text],
        images=[main_image],
        padding=True,
        return_tensors="np",
    )

    # Get labels by tokenizing the output text
    labels = processor(
        text=[example["response"]],
        padding=True,
        return_tensors="np",
    )

    # Concatenate prompt and response into a single sequence
    input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)
    attention_mask = np.concatenate([inputs.attention_mask[0], labels.attention_mask[0]], axis=0)

    # Create labels, masking the input portion with -100 so that only the
    # response tokens contribute to the loss
    labels_full = np.full_like(input_ids, fill_value=-100)
    labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]

    # Return as dict, including pixel_values
    return {
        "input_ids": input_ids.tolist(),
        "attention_mask": attention_mask.tolist(),
        "labels": labels_full.tolist(),
        "pixel_values": inputs.pixel_values[0],
    }


# Define a custom data collator that pads the text fields and stacks images
class DataCollatorForVisionLanguageModeling:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, features):
        input_ids = [f["input_ids"] for f in features]
        attention_mask = [f["attention_mask"] for f in features]
        labels = [f["labels"] for f in features]
        pixel_values = [f["pixel_values"] for f in features]

        # Pad input_ids and attention_mask via the underlying tokenizer
        # (the processor itself exposes no pad method)
        batch = self.processor.tokenizer.pad(
            {"input_ids": input_ids, "attention_mask": attention_mask},
            padding=True,
            return_tensors="pt",
        )

        # Pad labels by hand with -100 so padded positions are ignored by the
        # loss; this assumes the tokenizer's default right-padding
        max_len = batch["input_ids"].shape[1]
        batch["labels"] = torch.tensor(
            [seq + [-100] * (max_len - len(seq)) for seq in labels], dtype=torch.long
        )

        # Stack pixel_values into a single tensor
        batch["pixel_values"] = torch.stack([torch.tensor(pv) for pv in pixel_values])
        return batch
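
For reference, a minimal usage sketch (not part of this commit) of how the two pieces above fit together, assuming the dataset comes from build_batch_query_response_vision_dataset and carries the columns that prepare_data_for_qwen2_training reads:

from functools import partial

from transformers import AutoProcessor

from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import (
    DataCollatorForVisionLanguageModeling,
    prepare_data_for_qwen2_training,
)

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Dataset whose rows carry input_prompt_text, input_prompt_image_base64, response
train_ds = build_batch_query_response_vision_dataset(
    query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
    response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
)

# Map every row into tokenized model inputs, dropping the raw columns
train_ds = train_ds.map(
    partial(prepare_data_for_qwen2_training, processor=processor),
    remove_columns=train_ds.column_names,
)

# Collate two prepared examples into a single padded batch of tensors
collator = DataCollatorForVisionLanguageModeling(processor)
batch = collator([train_ds[0], train_ds[1]])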

Second changed file (9 additions, 40 deletions):

@@ -10,6 +10,7 @@
 # Step 5. Move over from interactive session to gantry launch script
 
 import os
+import json
 import base64
 import logging
 from io import BytesIO
@@ -19,6 +20,7 @@ from logging import Logger
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional
+from tqdm import tqdm
 
 import accelerate
 import torch
@@ -56,58 +58,25 @@ from .utils import (
 from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
+from pdelfin.train.dataprep import prepare_data_for_qwen2_training
 
 
 def run_train(config: TrainConfig):
     train_ds = build_batch_query_response_vision_dataset(
-        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
-        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
+        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
+        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
     )
 
     model = Qwen2VLForConditionalGeneration.from_pretrained(
-        "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
+        "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
     )
     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
 
-    for entry in train_ds:
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image",
-                        "image": entry["input_prompt_image_base64"]
-                    },
-                    {"type": "text", "text": entry["input_prompt_text"]},
-                ],
-            }
-        ]
-
-        # Preparation for inference
-        text = processor.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-        main_image = Image.open(BytesIO(base64.b64decode(entry["input_prompt_image_base64"])))
-        inputs = processor(
-            text=[text],
-            images=[main_image],
-            #videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
-        )
-        #inputs = inputs.to("cuda")
-
-        # Inference: Generation of the output
-        generated_ids = model.generate(**inputs, max_new_tokens=128)
-        generated_ids_trimmed = [
-            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
-        output_text = processor.batch_decode(
-            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-        print(output_text)
+    train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor),
+                            remove_columns=train_ds.column_names)
+
+    print(train_ds)
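
One plausible next step (not in this commit, only a sketch): handing the mapped dataset and the collator to a standard Hugging Face Trainer. The output_dir value here is a hypothetical placeholder.

from transformers import Trainer, TrainingArguments

from pdelfin.train.dataprep import DataCollatorForVisionLanguageModeling

training_args = TrainingArguments(
    output_dir="qwen2-vl-finetune",  # hypothetical output path
    per_device_train_batch_size=1,
    bf16=True,                       # matches the torch.bfloat16 model dtype above
    remove_unused_columns=False,     # keep pixel_values so the collator receives them
)

trainer = Trainer(
    model=model,             # the Qwen2VLForConditionalGeneration loaded in run_train
    args=training_args,
    train_dataset=train_ds,  # the dataset mapped with prepare_data_for_qwen2_training
    data_collator=DataCollatorForVisionLanguageModeling(processor),
)
trainer.train()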