import base64
import random
from io import BytesIO
from typing import Union

import numpy as np
import torch  # Make sure to import torch, as it's used by the DataCollator downstream
from PIL import Image

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text


def prepare_data_for_qwen2_training(
    example,
    processor,
    target_longest_image_dim: Union[int, list[int]],
    target_anchor_text_len: Union[int, list[int]],
):
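    """Prepare a single {local_pdf_path, page_num, response} example for
    Qwen2-VL fine-tuning.

    Renders the PDF page to an image, builds the anchor-text prompt, and
    tokenizes prompt and response, masking the prompt portion of the labels
    with -100 so that only response tokens contribute to the loss.
    """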
    # If lists are provided, sample one value per example for augmentation
    if isinstance(target_longest_image_dim, list):
        target_longest_image_dim = random.choice(target_longest_image_dim)

    if isinstance(target_anchor_text_len, list):
        target_anchor_text_len = random.choice(target_anchor_text_len)

    anchor_text = get_anchor_text(
        example["local_pdf_path"], example["page_num"],
        pdf_engine="pdfreport", target_length=target_anchor_text_len,
    )
    base64_page_image = render_pdf_to_base64png(
        example["local_pdf_path"], example["page_num"],
        target_longest_image_dim=target_longest_image_dim,
    )

    # Prepare messages
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": base64_page_image,
                },
                {"type": "text", "text": build_finetuning_prompt(anchor_text)},
            ],
        }
    ]

    # Apply chat template to get the prompt text
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Decode image from base64
    main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))

    # Process inputs using processor
    inputs = processor(
        text=[text],
        images=[main_image],
        padding=True,
        return_tensors="np",
    )

    # Get labels by tokenizing the output text
    labels = processor(
        text=[example["response"]],
        padding=True,
        return_tensors="np",
    )

    # Append "<|im_end|>\n" to the labels, because this is what the sequence
    # would look like if we passed the whole message stream through the processor
    im_end_tokens = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
    im_end_tokens = np.array(im_end_tokens, dtype=inputs.input_ids.dtype)  # Ensure correct dtype

    # Handle the case where labels["input_ids"] is empty
    if labels["input_ids"].shape[1] == 0:
        labels_input_ids_0 = np.array([], dtype=inputs.input_ids.dtype)
    else:
        labels_input_ids_0 = labels["input_ids"][0].astype(inputs.input_ids.dtype)

    labels["input_ids"] = np.concatenate([labels_input_ids_0, im_end_tokens])
    labels["input_ids"] = np.expand_dims(labels["input_ids"], axis=0)

    # Concatenate prompt input_ids and response labels into one sequence
    input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)

    # All positions participate in attention fully
    attention_mask = np.ones_like(input_ids)

    # Create labels, masking the prompt portion with -100
    labels_full = np.full_like(input_ids, fill_value=-100)
    labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]

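    # Resulting layout for one example:
    #   input_ids: [prompt tokens ..............][response tokens + <|im_end|>\n]
    #   labels:    [-100, -100, ..., -100       ][response tokens + <|im_end|>\n]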
    # TODO: Maybe cap the max length

    # Return as dict, including pixel_values
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels_full,
        "pixel_values": inputs.pixel_values,
        "image_grid_thw": inputs["image_grid_thw"][0],
    }


def batch_prepare_data_for_qwen2_training(
    batch,
    processor,
    target_longest_image_dim: list[int],
    target_anchor_text_len: list[int],
):
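    """Prepare a columnar batch dict (e.g. as produced by
    datasets.Dataset.map with batched=True) by running each row through the
    single-example preparer.
    """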
    # Process each example in the batch using the helper function
    processed_examples = []
    for i in range(len(batch["response"])):
        example = {
            "local_pdf_path": batch["local_pdf_path"][i],
            "page_num": batch["page_num"][i],
            "response": batch["response"][i],
        }
        processed_example = prepare_data_for_qwen2_training(
            example,
            processor,
            target_longest_image_dim=target_longest_image_dim,
            target_anchor_text_len=target_anchor_text_len,
        )
        processed_examples.append(processed_example)

    return {
        "input_ids": [x["input_ids"] for x in processed_examples],
        "attention_mask": [x["attention_mask"] for x in processed_examples],
        "labels": [x["labels"] for x in processed_examples],
        "pixel_values": [x["pixel_values"] for x in processed_examples],
        "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
    }
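

# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a Hugging Face dataset with "local_pdf_path", "page_num", and
# "response" columns; the JSONL path, model name, and size choices below are
# placeholder assumptions rather than values mandated by this code.
if __name__ == "__main__":
    from functools import partial

    from datasets import load_dataset
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    # Placeholder dataset path; replace with your own training manifest
    dataset = load_dataset("json", data_files="train_examples.jsonl", split="train")

    dataset = dataset.map(
        partial(
            batch_prepare_data_for_qwen2_training,
            processor=processor,
            target_longest_image_dim=[1024],
            target_anchor_text_len=[6000],
        ),
        batched=True,
        remove_columns=dataset.column_names,
    )
    print(dataset)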