Refactoring of train dataloaders

Jake Poznanski 2024-10-16 18:26:25 +00:00
parent 23d129fd2c
commit 3c1b7de293
5 changed files with 41 additions and 188 deletions

View File: pdelfin/prompts/anchor.py

@@ -26,7 +26,7 @@ from pypdf.generic import RectangleObject
 from pdelfin.prompts._adv_anchor import mult

-def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"]) -> str:
+def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"], target_length: int=4000) -> str:
     assert page > 0, "Pages are 1-indexed in pdf-land"

     if pdf_engine == "pdftotext":
@@ -54,7 +54,7 @@ def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftote
         return best_option
     elif pdf_engine == "pdfreport":
-        return _linearize_pdf_report(_pdf_report(local_pdf_path, page))
+        return _linearize_pdf_report(_pdf_report(local_pdf_path, page), max_length=target_length)
     else:
         raise NotImplementedError("Unknown engine")
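
For reference, a minimal usage sketch of the new knob (the PDF path here is hypothetical; per the diff, only the "pdfreport" engine consumes target_length, which caps the linearized page report and defaults to 4000):

    from pdelfin.prompts.anchor import get_anchor_text

    # "example.pdf" is a hypothetical local file; pages are 1-indexed.
    anchor_text = get_anchor_text(
        "example.pdf",
        page=1,
        pdf_engine="pdfreport",   # the only engine that uses target_length
        target_length=6000,       # cap on the linearized report text
    )
    print(anchor_text[:200])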

View File: pdelfin/train/dataloader.py

@@ -23,6 +23,9 @@ from pdelfin.s3_utils import parse_custom_id, get_s3_bytes, parse_s3_path
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+# Quiet logs from pypdf and smart_open
+logging.getLogger("pypdf").setLevel(logging.ERROR)
+logging.getLogger("smart_open").setLevel(logging.ERROR)

 def list_dataset_files(s3_glob_path: str):
     """

View File: pdelfin/train/dataprep.py

@@ -4,19 +4,15 @@ from PIL import Image
 import base64

 import torch  # Make sure to import torch as it's used in the DataCollator

 from pdelfin.prompts.anchor import get_anchor_text
 from pdelfin.prompts import build_finetuning_prompt

-def filter_by_max_seq_len(example, processor, max_prompt_len: int=2200, max_response_len: int=2200):
-    if len(processor.tokenizer.tokenize(example["input_prompt_text"])) > max_prompt_len:
-        return False
-    if len(processor.tokenizer.tokenize(example["response"])) > max_response_len:
-        return False
-    return True
+from pdelfin.data.renderpdf import render_pdf_to_base64png

-def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
+def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: int, target_anchor_text_len: int):
+    anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
+    base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)
+
     # Prepare messages
     messages = [
         {
@@ -24,9 +20,9 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
             "content": [
                 {
                     "type": "image",
-                    "image": example["input_prompt_image_base64"]  # Placeholder
+                    "image": base64_page_image
                 },
-                {"type": "text", "text": build_finetuning_prompt(example["raw_page_text"])},
+                {"type": "text", "text": build_finetuning_prompt(anchor_text)},
             ],
         }
     ]
@@ -36,14 +32,7 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
     )

     # Decode image from base64
-    main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
-
-    # Right now, we are going to downsample to 1024 on the longest dimension, because
-    # 2048 as we passed to OpenAI is too large for training
-    width, height = main_image.size
-    assert 1800 <= max(width, height) <= 2200, f"Image size {width}x{height} invalid"
-    main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
+    main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))

     # Process inputs using processor
     inputs = processor(
@@ -84,36 +73,30 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
     labels_full = np.full_like(input_ids, fill_value=-100)
     labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]

     # TODO Maybe cap the max length

     # Return as dict, including pixel_values
-    if add_batch_dim:
-        return {
-            "input_ids": input_ids[np.newaxis, ...],
-            "attention_mask": attention_mask[np.newaxis, ...],
-            "labels": labels_full[np.newaxis, ...],
-            "pixel_values": inputs.pixel_values[np.newaxis, ...],
-            "image_grid_thw": inputs["image_grid_thw"]
-        }
-    else:
-        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "labels": labels_full,
-            "pixel_values": inputs.pixel_values,
-            "image_grid_thw": inputs["image_grid_thw"][0]
-        }
+    return {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "labels": labels_full,
+        "pixel_values": inputs.pixel_values,
+        "image_grid_thw": inputs["image_grid_thw"][0]
+    }

-def batch_prepare_data_for_qwen2_training(batch, processor):
+def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: int, target_anchor_text_len: int):
     # Process each example in the batch using the helper function
     processed_examples = []
-    for i in range(len(batch["input_prompt_image_base64"])):
+    for i in range(len(batch["response"])):
         example = {
-            "input_prompt_image_base64": batch["input_prompt_image_base64"][i],
-            "input_prompt_text": batch["input_prompt_text"][i],
-            "raw_page_text": batch["raw_page_text"][i],
+            "local_pdf_path": batch["local_pdf_path"][i],
+            "page_num": batch["page_num"][i],
             "response": batch["response"][i]
         }
-        processed_example = prepare_data_for_qwen2_training(example, processor)
+        processed_example = prepare_data_for_qwen2_training(example, processor,
+                                                            target_longest_image_dim=target_longest_image_dim,
+                                                            target_anchor_text_len=target_anchor_text_len)
         processed_examples.append(processed_example)

     return {
@@ -124,96 +107,3 @@ def batch_prepare_data_for_qwen2_training(batch, processor):
         "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
     }

-def prepare_data_for_qwen2_inference(example, processor):
-    # Prepare messages
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": "image",
-                    "image": example["input_prompt_image_base64"]  # Placeholder
-                },
-                {"type": "text", "text": example["input_prompt_text"]},
-            ],
-        }
-    ]
-
-    # Apply chat template to get the text
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-
-    # Decode image from base64
-    main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
-
-    # Right now, we are going to downsample to 1024 on the longest dimension, because
-    # 2048 as we passed to OpenAI is too large for training
-    width, height = main_image.size
-    if 1800 <= max(width, height) <= 2200:
-        main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
-
-    # Process inputs using processor
-    inputs = processor(
-        text=[text],
-        images=[main_image],
-        padding=True,
-        return_tensors="np",
-    )
-
-    input_ids = inputs["input_ids"][0]
-
-    # All columns will participate in attention fully
-    attention_mask = np.ones_like(input_ids)
-
-    # Return as dict, including pixel_values
-    return {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "pixel_values": inputs.pixel_values,
-        "image_grid_thw": inputs["image_grid_thw"][0]
-    }
-
-def batch_prepare_data_for_qwen2_inference(batch, processor):
-    # Process each example in the batch using the helper function
-    processed_examples = []
-    for i in range(len(batch["input_prompt_image_base64"])):
-        example = {
-            "input_prompt_image_base64": batch["input_prompt_image_base64"][i],
-            "input_prompt_text": batch["input_prompt_text"][i],
-            "raw_page_text": batch["raw_page_text"][i],
-        }
-        processed_example = prepare_data_for_qwen2_inference(example, processor)
-        processed_examples.append(processed_example)
-
-    return {
-        "input_ids": [x["input_ids"] for x in processed_examples],
-        "attention_mask": [x["attention_mask"] for x in processed_examples],
-        "pixel_values": [x["pixel_values"] for x in processed_examples],
-        "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
-    }
-
-# Define a custom data collator
-class DataCollatorForVisionLanguageModeling:
-    def __init__(self, processor):
-        self.processor = processor
-
-    def __call__(self, features):
-        input_ids = [f['input_ids'] for f in features]
-        attention_mask = [f['attention_mask'] for f in features]
-        labels = [f['labels'] for f in features]
-        pixel_values = [f['pixel_values'] for f in features]
-
-        # Pad input_ids, attention_mask, labels
-        batch = self.processor.pad(
-            {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels},
-            return_tensors="pt",
-            padding=True,
-        )
-
-        # Stack pixel_values
-        batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])
-
-        return batch
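
Aside: in prepare_data_for_qwen2_training above, labels are initialized to -100 (the ignore index of HuggingFace's causal-LM cross-entropy) and only the response tokens are filled in, so loss is computed on the model's output rather than on the prompt. A toy illustration of that masking, with made-up token ids:

    import numpy as np

    prompt_ids = np.array([101, 7, 42, 9])       # tokenized prompt (made up)
    response_ids = np.array([55, 66, 77, 102])   # tokenized response (made up)

    input_ids = np.concatenate([prompt_ids, response_ids])

    # -100 is ignored by the loss, masking out the prompt positions.
    labels = np.full_like(input_ids, fill_value=-100)
    labels[len(prompt_ids):] = response_ids

    print(labels)  # [-100 -100 -100 -100   55   66   77  102]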

View File: pyproject.toml

@@ -29,6 +29,7 @@ dependencies = [
     "ftfy",
     "bleach",
     "markdown2",
+    "filelock",
 ]

 license = {file = "LICENSE"}
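
filelock is a new dependency; the diff does not show its call site, but it is presumably there to serialize access to shared on-disk files (for example, PDFs cached by multiple dataloader workers). A minimal sketch of the library's lock API, with a hypothetical lock path:

    from filelock import FileLock

    # Hypothetical lock guarding a shared cached file.
    lock = FileLock("/tmp/cache/example.pdf.lock")

    with lock:  # blocks until no other process holds the lock
        # read or write /tmp/cache/example.pdf safely here
        pass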

View File

@@ -12,7 +12,7 @@ from pdelfin.train.dataloader import (
     list_dataset_files
 )

-from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
+from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training

 class TestBatchQueryResponseDataset(unittest.TestCase):
@@ -31,28 +31,27 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
         print(ds)

         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
-        from pdelfin.train.dataprep import filter_by_max_seq_len
-        ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))
+        ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))

         print(ds[0])

     def testPlotSequenceLengthHistogram(self):
         import plotly.express as px

-        ds = build_batch_query_response_vision_dataset(
-            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
+        ds = build_finetuning_dataset(
             response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
         )
+        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+        ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))

-        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
-        initial_len = len(ds)
-
-        from pdelfin.train.dataprep import filter_by_max_seq_len
-        print("Filtering on max sequence length")
-        ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=2200, max_response_len=2200))
-        formatted_dataset = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-        train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=30, shuffle=False)
+        train_dataloader = DataLoader(ds, batch_size=1, num_workers=30, shuffle=False)

         max_seen_len = 0
         steps = 0
@@ -81,43 +80,3 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
         )
         fig.write_image("sequence_lengths_histogram.png")

-    def testExtractBatch(self):
-        query_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl", first_n_files=3)
-        query_data = query_data["train"]
-        query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)
-
-        print(query_data)
-        print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
-
-    def testExtractResponse(self):
-        response_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", first_n_files=3)
-        response_data = response_data["train"]
-        response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names)
-
-        print(response_data)
-        print(response_data[0])
-
-    def testPyArrowDirectJson(self):
-        query_glob_path = "s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl"
-        response_glob_path = "s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json"
-
-        all_files = list_dataset_files(query_glob_path)
-
-        import pyarrow as pa
-        import pyarrow.json as paj
-        import pyarrow.compute as pc
-        import pyarrow.fs as fs
-
-        s3 = fs.S3FileSystem()
-        block_size = 200 * 1024**2
-
-        for file in all_files:
-            with s3.open_input_stream(file.replace("s3://", "")) as f:
-                table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size))
-
-            print(f"file {file} rows {table.num_rows}")
-            print(table.schema)
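
End-to-end, the refactored pipeline needs only a dataset with local_pdf_path, page_num, and response columns; the page image and anchor text are derived on the fly inside the transform. A minimal, self-contained wiring sketch mirroring the updated tests (the dataset contents and PDF path are hypothetical; in practice ds comes from build_finetuning_dataset):

    from functools import partial

    from datasets import Dataset
    from torch.utils.data import DataLoader
    from transformers import AutoProcessor

    from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training

    # Hypothetical single-example dataset standing in for build_finetuning_dataset(...).
    ds = Dataset.from_dict({
        "local_pdf_path": ["example.pdf"],
        "page_num": [1],
        "response": ["Plain text transcription of the page..."],
    })

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    ds = ds.with_transform(partial(
        batch_prepare_data_for_qwen2_training,
        processor=processor,
        target_longest_image_dim=1024,  # longest side of the rendered page, in pixels
        target_anchor_text_len=6000,    # cap on the pdfreport anchor text
    ))

    train_dataloader = DataLoader(ds, batch_size=1, num_workers=0, shuffle=False)
    for batch in train_dataloader:
        print(batch.keys())  # input_ids, attention_mask, labels, pixel_values, image_grid_thw
        break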