diff --git a/pdelfin/prompts/anchor.py b/pdelfin/prompts/anchor.py
index d994017..579ea1b 100644
--- a/pdelfin/prompts/anchor.py
+++ b/pdelfin/prompts/anchor.py
@@ -26,7 +26,7 @@ from pypdf.generic import RectangleObject
 from pdelfin.prompts._adv_anchor import mult
 
 
-def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"]) -> str:
+def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"], target_length: int=4000) -> str:
     assert page > 0, "Pages are 1-indexed in pdf-land"
 
     if pdf_engine == "pdftotext":
@@ -54,7 +54,7 @@ def get_anchor_text(local_pdf_path: str, page: int, pdf_engine: Literal["pdftote
 
         return best_option
     elif pdf_engine == "pdfreport":
-        return _linearize_pdf_report(_pdf_report(local_pdf_path, page))
+        return _linearize_pdf_report(_pdf_report(local_pdf_path, page), max_length=target_length)
     else:
         raise NotImplementedError("Unknown engine")
 
diff --git a/pdelfin/train/dataloader.py b/pdelfin/train/dataloader.py
index a320d63..0e39d25 100644
--- a/pdelfin/train/dataloader.py
+++ b/pdelfin/train/dataloader.py
@@ -23,6 +23,9 @@ from pdelfin.s3_utils import parse_custom_id, get_s3_bytes, parse_s3_path
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# Quiet logs from pypdf and smart open
+logging.getLogger("pypdf").setLevel(logging.ERROR)
+logging.getLogger("smart_open").setLevel(logging.ERROR)
 
 def list_dataset_files(s3_glob_path: str):
     """
diff --git a/pdelfin/train/dataprep.py b/pdelfin/train/dataprep.py
index de04da5..f3ed5e1 100644
--- a/pdelfin/train/dataprep.py
+++ b/pdelfin/train/dataprep.py
@@ -4,19 +4,15 @@ from PIL import Image
 import base64
 import torch  # Make sure to import torch as it's used in the DataCollator
+from pdelfin.prompts.anchor import get_anchor_text
 from pdelfin.prompts import build_finetuning_prompt
-
-def filter_by_max_seq_len(example, processor, max_prompt_len: int=2200, max_response_len: int=2200):
-    if len(processor.tokenizer.tokenize(example["input_prompt_text"])) > max_prompt_len:
-        return False
-
-    if len(processor.tokenizer.tokenize(example["response"])) > max_response_len:
-        return False
-
-    return True
+from pdelfin.data.renderpdf import render_pdf_to_base64png
 
 
-def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
+def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: int, target_anchor_text_len: int):
+    anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
+    base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)
+
     # Prepare messages
     messages = [
         {
@@ -24,9 +20,9 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
             "content": [
                 {
                     "type": "image",
-                    "image": example["input_prompt_image_base64"]  # Placeholder
+                    "image": base64_page_image
                 },
-                {"type": "text", "text": build_finetuning_prompt(example["raw_page_text"])},
+                {"type": "text", "text": build_finetuning_prompt(anchor_text)},
             ],
         }
     ]
@@ -36,14 +32,7 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
     )
 
     # Decode image from base64
-    main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
-
-    # Right now, we are going to downsample to 1024 on the longest dimension, because
-    # 2048 as we passed to OpenAI is too large for training
-    width, height = main_image.size
-    assert 1800 <= max(width, height) <= 2200, f"Image size {width}x{height} invalid"
-    main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
-
+    main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))
 
     # Process inputs using processor
     inputs = processor(
@@ -84,36 +73,30 @@
     labels_full = np.full_like(input_ids, fill_value=-100)
     labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]
 
+    # TODO Maybe cap the max length
+
     # Return as dict, including pixel_values
-    if add_batch_dim:
-        return {
-            "input_ids": input_ids[np.newaxis, ...],
-            "attention_mask": attention_mask[np.newaxis, ...],
-            "labels": labels_full[np.newaxis, ...],
-            "pixel_values": inputs.pixel_values[np.newaxis, ...],
-            "image_grid_thw": inputs["image_grid_thw"]
-        }
-    else:
-        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "labels": labels_full,
-            "pixel_values": inputs.pixel_values,
-            "image_grid_thw": inputs["image_grid_thw"][0]
-        }
+    return {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "labels": labels_full,
+        "pixel_values": inputs.pixel_values,
+        "image_grid_thw": inputs["image_grid_thw"][0]
+    }
 
 
-def batch_prepare_data_for_qwen2_training(batch, processor):
+def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: int, target_anchor_text_len: int):
     # Process each example in the batch using the helper function
     processed_examples = []
-    for i in range(len(batch["input_prompt_image_base64"])):
+    for i in range(len(batch["response"])):
         example = {
-            "input_prompt_image_base64": batch["input_prompt_image_base64"][i],
-            "input_prompt_text": batch["input_prompt_text"][i],
-            "raw_page_text": batch["raw_page_text"][i],
+            "local_pdf_path": batch["local_pdf_path"][i],
+            "page_num": batch["page_num"][i],
             "response": batch["response"][i]
         }
-        processed_example = prepare_data_for_qwen2_training(example, processor)
+        processed_example = prepare_data_for_qwen2_training(example, processor,
+                                                            target_longest_image_dim=target_longest_image_dim,
+                                                            target_anchor_text_len=target_anchor_text_len)
         processed_examples.append(processed_example)
 
     return {
@@ -124,96 +107,3 @@
         "input_ids": [x["input_ids"] for x in processed_examples],
         "attention_mask": [x["attention_mask"] for x in processed_examples],
         "pixel_values": [x["pixel_values"] for x in processed_examples],
         "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
     }
-
-
-def prepare_data_for_qwen2_inference(example, processor):
-    # Prepare messages
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": "image",
-                    "image": example["input_prompt_image_base64"]  # Placeholder
-                },
-                {"type": "text", "text": example["input_prompt_text"]},
-            ],
-        }
-    ]
-    # Apply chat template to get the text
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-
-    # Decode image from base64
-    main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
-
-    # Right now, we are going to downsample to 1024 on the longest dimension, because
-    # 2048 as we passed to OpenAI is too large for training
-    width, height = main_image.size
-    if 1800 <= max(width, height) <= 2200:
-        main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
-
-
-    # Process inputs using processor
-    inputs = processor(
-        text=[text],
-        images=[main_image],
-        padding=True,
-        return_tensors="np",
-    )
-
-    input_ids = inputs["input_ids"][0]
-
-    # All columns will participate in attention fully
-    attention_mask = np.ones_like(input_ids)
-
-    # Return as dict, including pixel_values
-    return {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "pixel_values": inputs.pixel_values,
-        "image_grid_thw": inputs["image_grid_thw"][0]
-    }
-
-
-def batch_prepare_data_for_qwen2_inference(batch, processor):
-    # Process each example in the batch using the helper function
-    processed_examples = []
-    for i in range(len(batch["input_prompt_image_base64"])):
-        example = {
-            "input_prompt_image_base64": batch["input_prompt_image_base64"][i],
-            "input_prompt_text": batch["input_prompt_text"][i],
-            "raw_page_text": batch["raw_page_text"][i],
-        }
-        processed_example = prepare_data_for_qwen2_inference(example, processor)
-        processed_examples.append(processed_example)
-
-    return {
-        "input_ids": [x["input_ids"] for x in processed_examples],
-        "attention_mask": [x["attention_mask"] for x in processed_examples],
-        "pixel_values": [x["pixel_values"] for x in processed_examples],
-        "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
-    }
-
-# Define a custom data collator
-class DataCollatorForVisionLanguageModeling:
-    def __init__(self, processor):
-        self.processor = processor
-
-    def __call__(self, features):
-        input_ids = [f['input_ids'] for f in features]
-        attention_mask = [f['attention_mask'] for f in features]
-        labels = [f['labels'] for f in features]
-        pixel_values = [f['pixel_values'] for f in features]
-
-        # Pad input_ids, attention_mask, labels
-        batch = self.processor.pad(
-            {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels},
-            return_tensors="pt",
-            padding=True,
-        )
-
-        # Stack pixel_values
-        batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])
-
-        return batch
diff --git a/pyproject.toml b/pyproject.toml
index 45b0570..dd3ea9c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "ftfy",
     "bleach",
     "markdown2",
+    "filelock",
 ]
 
 license = {file = "LICENSE"}
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index b633ab2..a74aee6 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -12,7 +12,7 @@ from pdelfin.train.dataloader import (
     list_dataset_files
 )
 
-from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
+from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
 
 
 class TestBatchQueryResponseDataset(unittest.TestCase):
@@ -31,28 +31,27 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
         print(ds)
 
         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
-        from pdelfin.train.dataprep import filter_by_max_seq_len
-        ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))
+
+        ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
 
         print(ds[0])
 
     def testPlotSequenceLengthHistogram(self):
         import plotly.express as px
 
-        ds = build_batch_query_response_vision_dataset(
-            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
+        ds = build_finetuning_dataset(
             response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
         )
+
+        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+
+        ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
+
         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 
         initial_len = len(ds)
-        from pdelfin.train.dataprep import filter_by_max_seq_len
-        print("Filtering on max sequence length")
-        ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=2200, max_response_len=2200))
-
-        formatted_dataset = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-        train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=30, shuffle=False)
+        train_dataloader = DataLoader(ds, batch_size=1, num_workers=30, shuffle=False)
 
         max_seen_len = 0
         steps = 0
@@ -81,43 +80,3 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
         )
 
         fig.write_image("sequence_lengths_histogram.png")
-
-    def testExtractBatch(self):
-        query_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl", first_n_files=3)
-        query_data = query_data["train"]
-        query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)
-
-        print(query_data)
-        print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
-
-    def testExtractResponse(self):
-        response_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", first_n_files=3)
-        response_data = response_data["train"]
-
-        response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names)
-
-        print(response_data)
-        print(response_data[0])
-
-    def testPyArrowDirectJson(self):
-        query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl"
-        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json"
-
-        all_files = list_dataset_files(query_glob_path)
-
-        import pyarrow as pa
-        import pyarrow.json as paj
-        import pyarrow.compute as pc
-        import pyarrow.fs as fs
-
-        s3 = fs.S3FileSystem()
-
-        block_size = 200 * 1024**2
-
-        for file in all_files:
-            with s3.open_input_stream(file.replace("s3://", "")) as f:
-                table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size))
-
-                print(f"file {file} rows {table.num_rows}")
-                print(table.schema)
-
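
A minimal usage sketch of the reworked data pipeline, mirroring the updated test: it assumes build_finetuning_dataset is exposed by pdelfin.train.dataloader and yields rows with the local_pdf_path, page_num, and response fields that batch_prepare_data_for_qwen2_training now reads; the S3 glob and the 1024/6000 values are taken from the test and are illustrative only.

from functools import partial

from torch.utils.data import DataLoader
from transformers import AutoProcessor

from pdelfin.train.dataloader import build_finetuning_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training

# Rows are expected to carry local_pdf_path, page_num, and response;
# page rendering and anchor-text extraction now happen inside the transform.
ds = build_finetuning_dataset(
    response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",  # glob from the test; illustrative
)

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Image size and anchor-text budget are explicit knobs of the transform.
ds = ds.with_transform(partial(
    batch_prepare_data_for_qwen2_training,
    processor=processor,
    target_longest_image_dim=1024,
    target_anchor_text_len=6000,
))

train_dataloader = DataLoader(ds, batch_size=1, num_workers=4, shuffle=False)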