mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-02 11:04:25 +00:00
Refactoring of train dataloaders
This commit is contained in:
parent
23d129fd2c
commit
3c1b7de293
@ -26,7 +26,7 @@ from pypdf.generic import RectangleObject
|
||||
from pdelfin.prompts._adv_anchor import mult
|
||||
|
||||
|
||||
def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"]) -> str:
|
||||
def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"], target_length: int=4000) -> str:
|
||||
assert page > 0, "Pages are 1-indexed in pdf-land"
|
||||
|
||||
if pdf_engine == "pdftotext":
|
||||
@ -54,7 +54,7 @@ def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftote
|
||||
|
||||
return best_option
|
||||
elif pdf_engine == "pdfreport":
|
||||
return _linearize_pdf_report(_pdf_report(local_pdf_path, page))
|
||||
return _linearize_pdf_report(_pdf_report(local_pdf_path, page), max_length=target_length)
|
||||
else:
|
||||
raise NotImplementedError("Unknown engine")
|
||||
|
||||
|
||||
@ -23,6 +23,9 @@ from pdelfin.s3_utils import parse_custom_id, get_s3_bytes, parse_s3_path
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Quiet logs from pypdf and smart open
|
||||
logging.getLogger("pypdf").setLevel(logging.ERROR)
|
||||
logging.getLogger("smart_open").setLevel(logging.ERROR)
|
||||
|
||||
def list_dataset_files(s3_glob_path: str):
|
||||
"""
|
||||
|
||||
@ -4,19 +4,15 @@ from PIL import Image
|
||||
import base64
|
||||
import torch # Make sure to import torch as it's used in the DataCollator
|
||||
|
||||
from pdelfin.prompts.anchor import get_anchor_text
|
||||
from pdelfin.prompts import build_finetuning_prompt
|
||||
|
||||
def filter_by_max_seq_len(example, processor, max_prompt_len: int=2200, max_response_len: int=2200):
|
||||
if len(processor.tokenizer.tokenize(example["input_prompt_text"])) > max_prompt_len:
|
||||
return False
|
||||
|
||||
if len(processor.tokenizer.tokenize(example["response"])) > max_response_len:
|
||||
return False
|
||||
|
||||
return True
|
||||
from pdelfin.data.renderpdf import render_pdf_to_base64png
|
||||
|
||||
|
||||
def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
|
||||
def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: int, target_anchor_text_len: int):
|
||||
anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
|
||||
base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)
|
||||
|
||||
# Prepare messages
|
||||
messages = [
|
||||
{
|
||||
@ -24,9 +20,9 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": example["input_prompt_image_base64"] # Placeholder
|
||||
"image": base64_page_image
|
||||
},
|
||||
{"type": "text", "text": build_finetuning_prompt(example["raw_page_text"])},
|
||||
{"type": "text", "text": build_finetuning_prompt(anchor_text)},
|
||||
],
|
||||
}
|
||||
]
|
||||
@ -36,14 +32,7 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
|
||||
)
|
||||
|
||||
# Decode image from base64
|
||||
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
|
||||
|
||||
# Right now, we are going to downsample to 1024 on the longest dimension, because
|
||||
# 2048 as we passed to OpenAI is too large for training
|
||||
width, height = main_image.size
|
||||
assert 1800 <= max(width, height) <= 2200, f"Image size {width}x{height} invalid"
|
||||
main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
|
||||
|
||||
main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))
|
||||
|
||||
# Process inputs using processor
|
||||
inputs = processor(
|
||||
@ -84,36 +73,30 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
|
||||
labels_full = np.full_like(input_ids, fill_value=-100)
|
||||
labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]
|
||||
|
||||
# TODO Maybe cap the max length
|
||||
|
||||
# Return as dict, including pixel_values
|
||||
if add_batch_dim:
|
||||
return {
|
||||
"input_ids": input_ids[np.newaxis, ...],
|
||||
"attention_mask": attention_mask[np.newaxis, ...],
|
||||
"labels": labels_full[np.newaxis, ...],
|
||||
"pixel_values": inputs.pixel_values[np.newaxis, ...],
|
||||
"image_grid_thw": inputs["image_grid_thw"]
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels_full,
|
||||
"pixel_values": inputs.pixel_values,
|
||||
"image_grid_thw": inputs["image_grid_thw"][0]
|
||||
}
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels_full,
|
||||
"pixel_values": inputs.pixel_values,
|
||||
"image_grid_thw": inputs["image_grid_thw"][0]
|
||||
}
|
||||
|
||||
|
||||
def batch_prepare_data_for_qwen2_training(batch, processor):
|
||||
def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: int, target_anchor_text_len: int):
|
||||
# Process each example in the batch using the helper function
|
||||
processed_examples = []
|
||||
for i in range(len(batch["input_prompt_image_base64"])):
|
||||
for i in range(len(batch["response"])):
|
||||
example = {
|
||||
"input_prompt_image_base64": batch["input_prompt_image_base64"][i],
|
||||
"input_prompt_text": batch["input_prompt_text"][i],
|
||||
"raw_page_text": batch["raw_page_text"][i],
|
||||
"local_pdf_path": batch["local_pdf_path"][i],
|
||||
"page_num": batch["page_num"][i],
|
||||
"response": batch["response"][i]
|
||||
}
|
||||
processed_example = prepare_data_for_qwen2_training(example, processor)
|
||||
processed_example = prepare_data_for_qwen2_training(example, processor,
|
||||
target_longest_image_dim=target_longest_image_dim,
|
||||
target_anchor_text_len=target_anchor_text_len)
|
||||
processed_examples.append(processed_example)
|
||||
|
||||
return {
|
||||
@ -124,96 +107,3 @@ def batch_prepare_data_for_qwen2_training(batch, processor):
|
||||
"image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
|
||||
}
|
||||
|
||||
|
||||
def prepare_data_for_qwen2_inference(example, processor):
|
||||
# Prepare messages
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": example["input_prompt_image_base64"] # Placeholder
|
||||
},
|
||||
{"type": "text", "text": example["input_prompt_text"]},
|
||||
],
|
||||
}
|
||||
]
|
||||
# Apply chat template to get the text
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Decode image from base64
|
||||
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
|
||||
|
||||
# Right now, we are going to downsample to 1024 on the longest dimension, because
|
||||
# 2048 as we passed to OpenAI is too large for training
|
||||
width, height = main_image.size
|
||||
if 1800 <= max(width, height) <= 2200:
|
||||
main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
|
||||
|
||||
|
||||
# Process inputs using processor
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=[main_image],
|
||||
padding=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
input_ids = inputs["input_ids"][0]
|
||||
|
||||
# All columns will participate in attention fully
|
||||
attention_mask = np.ones_like(input_ids)
|
||||
|
||||
# Return as dict, including pixel_values
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": inputs.pixel_values,
|
||||
"image_grid_thw": inputs["image_grid_thw"][0]
|
||||
}
|
||||
|
||||
|
||||
def batch_prepare_data_for_qwen2_inference(batch, processor):
|
||||
# Process each example in the batch using the helper function
|
||||
processed_examples = []
|
||||
for i in range(len(batch["input_prompt_image_base64"])):
|
||||
example = {
|
||||
"input_prompt_image_base64": batch["input_prompt_image_base64"][i],
|
||||
"input_prompt_text": batch["input_prompt_text"][i],
|
||||
"raw_page_text": batch["raw_page_text"][i],
|
||||
}
|
||||
processed_example = prepare_data_for_qwen2_inference(example, processor)
|
||||
processed_examples.append(processed_example)
|
||||
|
||||
return {
|
||||
"input_ids": [x["input_ids"] for x in processed_examples],
|
||||
"attention_mask": [x["attention_mask"] for x in processed_examples],
|
||||
"pixel_values": [x["pixel_values"] for x in processed_examples],
|
||||
"image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
|
||||
}
|
||||
|
||||
# Define a custom data collator
|
||||
class DataCollatorForVisionLanguageModeling:
|
||||
def __init__(self, processor):
|
||||
self.processor = processor
|
||||
|
||||
def __call__(self, features):
|
||||
input_ids = [f['input_ids'] for f in features]
|
||||
attention_mask = [f['attention_mask'] for f in features]
|
||||
labels = [f['labels'] for f in features]
|
||||
pixel_values = [f['pixel_values'] for f in features]
|
||||
|
||||
# Pad input_ids, attention_mask, labels
|
||||
batch = self.processor.pad(
|
||||
{"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels},
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
|
||||
# Stack pixel_values
|
||||
batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])
|
||||
|
||||
return batch
|
||||
|
||||
@ -29,6 +29,7 @@ dependencies = [
|
||||
"ftfy",
|
||||
"bleach",
|
||||
"markdown2",
|
||||
"filelock",
|
||||
]
|
||||
license = {file = "LICENSE"}
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ from pdelfin.train.dataloader import (
|
||||
list_dataset_files
|
||||
)
|
||||
|
||||
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
|
||||
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
|
||||
|
||||
|
||||
class TestBatchQueryResponseDataset(unittest.TestCase):
|
||||
@ -31,28 +31,27 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
|
||||
print(ds)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
||||
from pdelfin.train.dataprep import filter_by_max_seq_len
|
||||
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))
|
||||
|
||||
ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
|
||||
|
||||
print(ds[0])
|
||||
|
||||
def testPlotSequenceLengthHistogram(self):
|
||||
import plotly.express as px
|
||||
|
||||
ds = build_batch_query_response_vision_dataset(
|
||||
query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
|
||||
ds = build_finetuning_dataset(
|
||||
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
|
||||
)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
||||
|
||||
ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
||||
|
||||
initial_len = len(ds)
|
||||
|
||||
from pdelfin.train.dataprep import filter_by_max_seq_len
|
||||
print("Filtering on max sequence length")
|
||||
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=2200, max_response_len=2200))
|
||||
|
||||
formatted_dataset = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
|
||||
train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=30, shuffle=False)
|
||||
train_dataloader = DataLoader(ds, batch_size=1, num_workers=30, shuffle=False)
|
||||
|
||||
max_seen_len = 0
|
||||
steps = 0
|
||||
@ -81,43 +80,3 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
|
||||
)
|
||||
|
||||
fig.write_image("sequence_lengths_histogram.png")
|
||||
|
||||
def testExtractBatch(self):
|
||||
query_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl", first_n_files=3)
|
||||
query_data = query_data["train"]
|
||||
query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)
|
||||
|
||||
print(query_data)
|
||||
print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
|
||||
|
||||
def testExtractResponse(self):
|
||||
response_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", first_n_files=3)
|
||||
response_data = response_data["train"]
|
||||
|
||||
response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names)
|
||||
|
||||
print(response_data)
|
||||
print(response_data[0])
|
||||
|
||||
def testPyArrowDirectJson(self):
|
||||
query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl"
|
||||
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json"
|
||||
|
||||
all_files = list_dataset_files(query_glob_path)
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.json as paj
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.fs as fs
|
||||
|
||||
s3 = fs.S3FileSystem()
|
||||
|
||||
block_size = 200 * 1024**2
|
||||
|
||||
for file in all_files:
|
||||
with s3.open_input_stream(file.replace("s3://", "")) as f:
|
||||
table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size))
|
||||
|
||||
print(f"file {file} rows {table.num_rows}")
|
||||
print(table.schema)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user