More fixes

Jake Poznanski 2024-09-26 23:10:07 +00:00
parent d098a87ed2
commit c00e40d1c4
4 changed files with 21 additions and 12 deletions


@@ -22,7 +22,7 @@ def _build_prompt(base_text: str) -> str:
     return (
         f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it. "
         f"Just return the plain text representation of this document as if you were reading it naturally.\n"
-        f"Turn equations into a LaTeX representation. Remove the headers and footers, but keep references and footnotes.\n"
+        f"Turn equations into a LaTeX representation, and tables into markdown format. Remove the headers and footers, but keep references and footnotes.\n"
         f"Read any natural handwriting.\n"
         f"This is likely one page out of several in the document, so be sure to preserve any sentences that come from the previous page, or continue onto the next page, exactly as they are.\n"
         f"If there is no text at all that you think you should read, just output [NO TEXT].\n"


@@ -10,7 +10,7 @@ def filter_by_max_seq_len(example, max_seq_len=4500):
     return sizes[-1] <= max_seq_len

-def prepare_data_for_qwen2_training(example, processor):
+def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
     # Prepare messages
     messages = [
         {
@@ -71,13 +71,22 @@ def prepare_data_for_qwen2_training(example, processor):
     labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]

     # Return as dict, including pixel_values
-    return {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "labels": labels_full,
-        "pixel_values": inputs.pixel_values,
-        "image_grid_thw": inputs["image_grid_thw"][0]
-    }
+    if add_batch_dim:
+        return {
+            "input_ids": input_ids[np.newaxis, ...],
+            "attention_mask": attention_mask[np.newaxis, ...],
+            "labels": labels_full[np.newaxis, ...],
+            "pixel_values": inputs.pixel_values[np.newaxis, ...],
+            "image_grid_thw": inputs["image_grid_thw"]
+        }
+    else:
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels_full,
+            "pixel_values": inputs.pixel_values,
+            "image_grid_thw": inputs["image_grid_thw"][0]
+        }

 def batch_prepare_data_for_qwen2_training(batch, processor):
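The new add_batch_dim path prepends a length-1 axis to each array with np.newaxis, so a single mapped example already carries the leading batch dimension, and image_grid_thw is passed through without the [0] indexing. A standalone sketch of the shape change, with assumed shapes (not taken from the diff):

    import numpy as np

    # One tokenized example without a batch axis
    input_ids = np.arange(7)              # shape (7,)
    batched = input_ids[np.newaxis, ...]  # shape (1, 7): leading batch axis added
    print(input_ids.shape, batched.shape)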


@@ -143,8 +143,8 @@ def run_train(config: TrainConfig):
     train_ds = dataset["train"].to_iterable_dataset(num_shards=64)
     validation_ds = dataset["validation"]

-    train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=train_ds.column_names).filter(filter_by_max_seq_len)
-    validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=validation_ds.column_names)
+    train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=train_ds.column_names).filter(filter_by_max_seq_len)
+    validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=validation_ds.column_names)

     print(train_ds)
     print(validation_ds)
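With add_batch_dim=True bound through partial, each example yielded by the mapped iterable dataset should already have the extra leading axis. A quick hypothetical check (not part of the commit), assuming the mapped values keep their array shapes:

    # Inspect the first mapped training example; a (1, seq_len) input_ids shape is expected
    for example in train_ds.take(1):
        print(example["input_ids"].shape)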


@@ -94,7 +94,7 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

         formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
-        formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500)
+        formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500)

         for entry in formatted_dataset:
             print(entry)