Mirror of https://github.com/allenai/olmocr.git (synced 2025-11-16 18:39:29 +00:00)
Refactoring of train dataloaders
commit 3c1b7de293 (parent 23d129fd2c)
@@ -26,7 +26,7 @@ from pypdf.generic import RectangleObject
 from pdelfin.prompts._adv_anchor import mult
 
 
-def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"]) -> str:
+def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"], target_length: int=4000) -> str:
     assert page > 0, "Pages are 1-indexed in pdf-land"
 
     if pdf_engine == "pdftotext":
@@ -54,7 +54,7 @@ def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftote
         return best_option
     elif pdf_engine == "pdfreport":
-        return _linearize_pdf_report(_pdf_report(local_pdf_path, page))
+        return _linearize_pdf_report(_pdf_report(local_pdf_path, page), max_length=target_length)
     else:
         raise NotImplementedError("Unknown engine")
 
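The new target_length parameter lets callers cap how much anchor text is produced; as of this change only the pdfreport branch forwards it (as max_length) to _linearize_pdf_report. A minimal usage sketch, assuming a placeholder PDF path:

from pdelfin.prompts.anchor import get_anchor_text

# "page1.pdf" is a placeholder path, not a file from the repository.
anchor_text = get_anchor_text(
    "page1.pdf",
    page=1,                   # pages are 1-indexed in pdf-land
    pdf_engine="pdfreport",   # the engine that honors target_length
    target_length=6000,       # cap on the linearized report; default is 4000
)
print(len(anchor_text))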
@@ -23,6 +23,9 @@ from pdelfin.s3_utils import parse_custom_id, get_s3_bytes, parse_s3_path
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# Quiet logs from pypdf and smart open
+logging.getLogger("pypdf").setLevel(logging.ERROR)
+logging.getLogger("smart_open").setLevel(logging.ERROR)
 
 def list_dataset_files(s3_glob_path: str):
     """
@@ -4,19 +4,15 @@ from PIL import Image
 import base64
 import torch  # Make sure to import torch as it's used in the DataCollator
 
+from pdelfin.prompts.anchor import get_anchor_text
 from pdelfin.prompts import build_finetuning_prompt
+from pdelfin.data.renderpdf import render_pdf_to_base64png
 
-def filter_by_max_seq_len(example, processor, max_prompt_len: int=2200, max_response_len: int=2200):
-    if len(processor.tokenizer.tokenize(example["input_prompt_text"])) > max_prompt_len:
-        return False
-
-    if len(processor.tokenizer.tokenize(example["response"])) > max_response_len:
-        return False
-
-    return True
-
-
-def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
+def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: int, target_anchor_text_len: int):
+    anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
+    base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)
 
     # Prepare messages
     messages = [
         {
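After the refactor, prepare_data_for_qwen2_training no longer consumes a pre-rendered base64 image or pre-extracted page text; it only needs a local PDF path, a page number, and the target response, and renders the image and anchor text itself. A rough single-example sketch, with the PDF path and response text as stand-ins:

from transformers import AutoProcessor

from pdelfin.train.dataprep import prepare_data_for_qwen2_training

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

example = {
    "local_pdf_path": "page1.pdf",  # placeholder path
    "page_num": 1,
    "response": "Example transcription of the page.",  # placeholder label text
}

out = prepare_data_for_qwen2_training(
    example,
    processor,
    target_longest_image_dim=1024,
    target_anchor_text_len=6000,
)
print(out["input_ids"].shape, out["labels"].shape)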
@@ -24,9 +20,9 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
             "content": [
                 {
                     "type": "image",
-                    "image": example["input_prompt_image_base64"]  # Placeholder
+                    "image": base64_page_image
                 },
-                {"type": "text", "text": build_finetuning_prompt(example["raw_page_text"])},
+                {"type": "text", "text": build_finetuning_prompt(anchor_text)},
             ],
         }
     ]
@@ -36,14 +32,7 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
     )
 
     # Decode image from base64
-    main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
-
-    # Right now, we are going to downsample to 1024 on the longest dimension, because
-    # 2048 as we passed to OpenAI is too large for training
-    width, height = main_image.size
-    assert 1800 <= max(width, height) <= 2200, f"Image size {width}x{height} invalid"
-    main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
+    main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))
 
     # Process inputs using processor
     inputs = processor(
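With the size assert and manual LANCZOS downsample gone, the page image now arrives at the requested resolution directly from render_pdf_to_base64png. A rough sketch of the new flow, assuming a placeholder PDF path:

import base64
from io import BytesIO

from PIL import Image

from pdelfin.data.renderpdf import render_pdf_to_base64png

# "page1.pdf" is a placeholder; target_longest_image_dim mirrors what the
# training config passes through prepare_data_for_qwen2_training.
b64 = render_pdf_to_base64png("page1.pdf", 1, target_longest_image_dim=1024)
image = Image.open(BytesIO(base64.b64decode(b64)))
print(image.size)  # longest side should come out near 1024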
@@ -84,16 +73,9 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
     labels_full = np.full_like(input_ids, fill_value=-100)
     labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]
 
+    # TODO Maybe cap the max length
 
     # Return as dict, including pixel_values
-    if add_batch_dim:
-        return {
-            "input_ids": input_ids[np.newaxis, ...],
-            "attention_mask": attention_mask[np.newaxis, ...],
-            "labels": labels_full[np.newaxis, ...],
-            "pixel_values": inputs.pixel_values[np.newaxis, ...],
-            "image_grid_thw": inputs["image_grid_thw"]
-        }
-    else:
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
@@ -103,17 +85,18 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
     }
 
 
-def batch_prepare_data_for_qwen2_training(batch, processor):
+def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: int, target_anchor_text_len: int):
     # Process each example in the batch using the helper function
     processed_examples = []
-    for i in range(len(batch["input_prompt_image_base64"])):
+    for i in range(len(batch["response"])):
         example = {
-            "input_prompt_image_base64": batch["input_prompt_image_base64"][i],
-            "input_prompt_text": batch["input_prompt_text"][i],
-            "raw_page_text": batch["raw_page_text"][i],
+            "local_pdf_path": batch["local_pdf_path"][i],
+            "page_num": batch["page_num"][i],
             "response": batch["response"][i]
         }
-        processed_example = prepare_data_for_qwen2_training(example, processor)
+        processed_example = prepare_data_for_qwen2_training(example, processor,
+                                                            target_longest_image_dim=target_longest_image_dim,
+                                                            target_anchor_text_len=target_anchor_text_len)
         processed_examples.append(processed_example)
 
     return {
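The batch helper now expects rows carrying local_pdf_path, page_num, and response, and renders images and anchor text on the fly. A sketch of how it is typically attached to a Hugging Face dataset, mirroring the updated test below; the dataset variable ds is a stand-in for whatever build_finetuning_dataset returns:

from functools import partial

from transformers import AutoProcessor

from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# ds is assumed to be a datasets.Dataset whose rows contain
# "local_pdf_path", "page_num", and "response".
ds = ds.with_transform(partial(
    batch_prepare_data_for_qwen2_training,
    processor=processor,
    target_longest_image_dim=1024,
    target_anchor_text_len=6000,
))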
@@ -124,96 +107,3 @@ def batch_prepare_data_for_qwen2_training(batch, processor):
         "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
     }
-
-
-def prepare_data_for_qwen2_inference(example, processor):
-    # Prepare messages
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": "image",
-                    "image": example["input_prompt_image_base64"]  # Placeholder
-                },
-                {"type": "text", "text": example["input_prompt_text"]},
-            ],
-        }
-    ]
-    # Apply chat template to get the text
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-
-    # Decode image from base64
-    main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
-
-    # Right now, we are going to downsample to 1024 on the longest dimension, because
-    # 2048 as we passed to OpenAI is too large for training
-    width, height = main_image.size
-    if 1800 <= max(width, height) <= 2200:
-        main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
-
-
-    # Process inputs using processor
-    inputs = processor(
-        text=[text],
-        images=[main_image],
-        padding=True,
-        return_tensors="np",
-    )
-
-    input_ids = inputs["input_ids"][0]
-
-    # All columns will participate in attention fully
-    attention_mask = np.ones_like(input_ids)
-
-    # Return as dict, including pixel_values
-    return {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "pixel_values": inputs.pixel_values,
-        "image_grid_thw": inputs["image_grid_thw"][0]
-    }
-
-
-def batch_prepare_data_for_qwen2_inference(batch, processor):
-    # Process each example in the batch using the helper function
-    processed_examples = []
-    for i in range(len(batch["input_prompt_image_base64"])):
-        example = {
-            "input_prompt_image_base64": batch["input_prompt_image_base64"][i],
-            "input_prompt_text": batch["input_prompt_text"][i],
-            "raw_page_text": batch["raw_page_text"][i],
-        }
-        processed_example = prepare_data_for_qwen2_inference(example, processor)
-        processed_examples.append(processed_example)
-
-    return {
-        "input_ids": [x["input_ids"] for x in processed_examples],
-        "attention_mask": [x["attention_mask"] for x in processed_examples],
-        "pixel_values": [x["pixel_values"] for x in processed_examples],
-        "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
-    }
-
-# Define a custom data collator
-class DataCollatorForVisionLanguageModeling:
-    def __init__(self, processor):
-        self.processor = processor
-
-    def __call__(self, features):
-        input_ids = [f['input_ids'] for f in features]
-        attention_mask = [f['attention_mask'] for f in features]
-        labels = [f['labels'] for f in features]
-        pixel_values = [f['pixel_values'] for f in features]
-
-        # Pad input_ids, attention_mask, labels
-        batch = self.processor.pad(
-            {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels},
-            return_tensors="pt",
-            padding=True,
-        )
-
-        # Stack pixel_values
-        batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])
-
-        return batch
@@ -29,6 +29,7 @@ dependencies = [
     "ftfy",
     "bleach",
     "markdown2",
+    "filelock",
 ]
 license = {file = "LICENSE"}
 
@@ -12,7 +12,7 @@ from pdelfin.train.dataloader import (
     list_dataset_files
 )
 
-from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
+from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
 
 
 class TestBatchQueryResponseDataset(unittest.TestCase):
@@ -31,28 +31,27 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
         print(ds)
 
         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
-        from pdelfin.train.dataprep import filter_by_max_seq_len
-        ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))
+        ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
 
         print(ds[0])
 
     def testPlotSequenceLengthHistogram(self):
         import plotly.express as px
 
-        ds = build_batch_query_response_vision_dataset(
-            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
+        ds = build_finetuning_dataset(
             response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
         )
 
+        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+
+        ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
+
         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 
         initial_len = len(ds)
 
-        from pdelfin.train.dataprep import filter_by_max_seq_len
-        print("Filtering on max sequence length")
-        ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=2200, max_response_len=2200))
-
-        formatted_dataset = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-        train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=30, shuffle=False)
+        train_dataloader = DataLoader(ds, batch_size=1, num_workers=30, shuffle=False)
 
         max_seen_len = 0
         steps = 0
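Instead of pre-filtering with filter_by_max_seq_len and materializing a formatted dataset, the test now leans on datasets' lazy with_transform, so rendering and tokenization happen only when a DataLoader worker pulls a row. A compact sketch of the measurement loop that follows, with an illustrative step cap added that is not part of the original test:

max_seen_len = 0
for step, batch in enumerate(train_dataloader):
    # batch_size=1: the first (only) row carries this page's token ids
    max_seen_len = max(max_seen_len, len(batch["input_ids"][0]))
    if step >= 100:  # illustrative cap for a quick local run
        break
print("Longest sequence seen:", max_seen_len)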
@@ -81,43 +80,3 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
         )
 
         fig.write_image("sequence_lengths_histogram.png")
-
-    def testExtractBatch(self):
-        query_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl", first_n_files=3)
-        query_data = query_data["train"]
-        query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)
-
-        print(query_data)
-        print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
-
-    def testExtractResponse(self):
-        response_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", first_n_files=3)
-        response_data = response_data["train"]
-
-        response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names)
-
-        print(response_data)
-        print(response_data[0])
-
-    def testPyArrowDirectJson(self):
-        query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl"
-        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json"
-
-        all_files = list_dataset_files(query_glob_path)
-
-        import pyarrow as pa
-        import pyarrow.json as paj
-        import pyarrow.compute as pc
-        import pyarrow.fs as fs
-
-        s3 = fs.S3FileSystem()
-
-        block_size = 200 * 1024**2
-
-        for file in all_files:
-            with s3.open_input_stream(file.replace("s3://", "")) as f:
-                table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size))
-
-            print(f"file {file} rows {table.num_rows}")
-            print(table.schema)