mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-17 18:26:46 +00:00
Prepping data to be in a trainable format
This commit is contained in:
parent
dc86a99a97
commit
fcb67ebd61
82
pdelfin/train/dataprep.py
Normal file
82
pdelfin/train/dataprep.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import numpy as np
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
|
import base64
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_data_for_qwen2_training(example, processor):
|
||||||
|
# Prepare messages
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image": example["input_prompt_image_base64"] # Placeholder
|
||||||
|
},
|
||||||
|
{"type": "text", "text": example["input_prompt_text"]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
# Apply chat template to get the text
|
||||||
|
text = processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode image from base64
|
||||||
|
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
|
||||||
|
|
||||||
|
# Process inputs using processor
|
||||||
|
inputs = processor(
|
||||||
|
text=[text],
|
||||||
|
images=[main_image],
|
||||||
|
padding=True,
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get labels by tokenizing the output text
|
||||||
|
labels = processor(
|
||||||
|
text=[example["response"]],
|
||||||
|
padding=True,
|
||||||
|
return_tensors="np"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Concatenate input_ids and labels
|
||||||
|
input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)
|
||||||
|
attention_mask = np.concatenate([inputs.attention_mask[0], labels.attention_mask[0]], axis=0)
|
||||||
|
|
||||||
|
# Create labels, masking the input portion with -100
|
||||||
|
labels_full = np.full_like(input_ids, fill_value=-100)
|
||||||
|
labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]
|
||||||
|
|
||||||
|
# Return as dict, including pixel_values
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids.tolist(),
|
||||||
|
"attention_mask": attention_mask.tolist(),
|
||||||
|
"labels": labels_full.tolist(),
|
||||||
|
"pixel_values": inputs.pixel_values[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Define a custom data collator
|
||||||
|
class DataCollatorForVisionLanguageModeling:
|
||||||
|
def __init__(self, processor):
|
||||||
|
self.processor = processor
|
||||||
|
|
||||||
|
def __call__(self, features):
|
||||||
|
input_ids = [f['input_ids'] for f in features]
|
||||||
|
attention_mask = [f['attention_mask'] for f in features]
|
||||||
|
labels = [f['labels'] for f in features]
|
||||||
|
pixel_values = [f['pixel_values'] for f in features]
|
||||||
|
|
||||||
|
# Pad input_ids, attention_mask, labels
|
||||||
|
batch = self.processor.pad(
|
||||||
|
{"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels},
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stack pixel_values
|
||||||
|
batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])
|
||||||
|
|
||||||
|
return batch
|
||||||
@ -10,6 +10,7 @@
|
|||||||
# Step 5. Move over from interactive session to gantry launch script
|
# Step 5. Move over from interactive session to gantry launch script
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -19,6 +20,7 @@ from logging import Logger
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
import accelerate
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
@ -56,58 +58,25 @@ from .utils import (
|
|||||||
|
|
||||||
|
|
||||||
from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
|
from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
|
||||||
|
from pdelfin.train.dataprep import prepare_data_for_qwen2_training
|
||||||
|
|
||||||
|
|
||||||
def run_train(config: TrainConfig):
|
def run_train(config: TrainConfig):
|
||||||
train_ds = build_batch_query_response_vision_dataset(
|
train_ds = build_batch_query_response_vision_dataset(
|
||||||
query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
|
query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
|
||||||
response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
|
response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
|
||||||
)
|
)
|
||||||
|
|
||||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||||
"Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
|
"Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
|
||||||
)
|
)
|
||||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||||
|
|
||||||
for entry in train_ds:
|
train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor),
|
||||||
messages = [
|
remove_columns=train_ds.column_names)
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"image": entry["input_prompt_image_base64"]
|
|
||||||
},
|
|
||||||
{"type": "text", "text": entry["input_prompt_text"]},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Preparation for inference
|
print(train_ds)
|
||||||
text = processor.apply_chat_template(
|
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
|
||||||
)
|
|
||||||
|
|
||||||
main_image = Image.open(BytesIO(base64.b64decode(entry["input_prompt_image_base64"])))
|
|
||||||
|
|
||||||
inputs = processor(
|
|
||||||
text=[text],
|
|
||||||
images=[main_image],
|
|
||||||
#videos=video_inputs,
|
|
||||||
padding=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
#inputs = inputs.to("cuda")
|
|
||||||
|
|
||||||
# Inference: Generation of the output
|
|
||||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
|
||||||
generated_ids_trimmed = [
|
|
||||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
|
||||||
]
|
|
||||||
output_text = processor.batch_decode(
|
|
||||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
|
||||||
)
|
|
||||||
print(output_text)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user