diff --git a/pdelfin/silver_data/buildsilver.py b/pdelfin/silver_data/buildsilver.py index 148810f..11fb505 100644 --- a/pdelfin/silver_data/buildsilver.py +++ b/pdelfin/silver_data/buildsilver.py @@ -22,7 +22,7 @@ def _build_prompt(base_text: str) -> str: return ( f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it. " f"Just return the plain text representation of this document as if you were reading it naturally.\n" - f"Turn equations into a LaTeX representation. Remove the headers and footers, but keep references and footnotes.\n" + f"Turn equations into a LaTeX representation, and tables into markdown format. Remove the headers and footers, but keep references and footnotes.\n" f"Read any natural handwriting.\n" f"This is likely one page out of several in the document, so be sure to preserve any sentences that come from the previous page, or continue onto the next page, exactly as they are.\n" f"If there is no text at all that you think you should read, just output [NO TEXT].\n" diff --git a/pdelfin/train/dataprep.py b/pdelfin/train/dataprep.py index 3785b1b..4cafba0 100644 --- a/pdelfin/train/dataprep.py +++ b/pdelfin/train/dataprep.py @@ -10,7 +10,7 @@ def filter_by_max_seq_len(example, max_seq_len=4500): return sizes[-1] <= max_seq_len -def prepare_data_for_qwen2_training(example, processor): +def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False): # Prepare messages messages = [ { @@ -71,13 +71,22 @@ def prepare_data_for_qwen2_training(example, processor): labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0] # Return as dict, including pixel_values - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "labels": labels_full, - "pixel_values": inputs.pixel_values, - "image_grid_thw": inputs["image_grid_thw"][0] - } + if add_batch_dim: + return { + "input_ids": input_ids[np.newaxis, ...], + "attention_mask": attention_mask[np.newaxis, ...], + "labels": labels_full[np.newaxis, ...], + "pixel_values": inputs.pixel_values[np.newaxis, ...], + "image_grid_thw": inputs["image_grid_thw"] + } + else: + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels_full, + "pixel_values": inputs.pixel_values, + "image_grid_thw": inputs["image_grid_thw"][0] + } def batch_prepare_data_for_qwen2_training(batch, processor): diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py index d84e273..73156e9 100644 --- a/pdelfin/train/train.py +++ b/pdelfin/train/train.py @@ -143,8 +143,8 @@ def run_train(config: TrainConfig): train_ds = dataset["train"].to_iterable_dataset(num_shards=64) validation_ds = dataset["validation"] - train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=train_ds.column_names).filter(filter_by_max_seq_len) - validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=validation_ds.column_names) + train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=train_ds.column_names).filter(filter_by_max_seq_len) + validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True)), remove_columns=validation_ds.column_names) print(train_ds) print(validation_ds) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index dbce291..fd52d2d 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -94,7 +94,7 @@ class TestBatchQueryResponseDataset(unittest.TestCase): processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") formatted_dataset = dataset.to_iterable_dataset(num_shards=64) - formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500) + formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500) for entry in formatted_dataset: print(entry)