mirror of
https://github.com/allenai/olmocr.git
synced 2025-08-15 04:11:59 +00:00
More fixes
This commit is contained in:
parent
d098a87ed2
commit
c00e40d1c4
@ -22,7 +22,7 @@ def _build_prompt(base_text: str) -> str:
|
||||
return (
|
||||
f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it. "
|
||||
f"Just return the plain text representation of this document as if you were reading it naturally.\n"
|
||||
f"Turn equations into a LaTeX representation. Remove the headers and footers, but keep references and footnotes.\n"
|
||||
f"Turn equations into a LaTeX representation, and tables into markdown format. Remove the headers and footers, but keep references and footnotes.\n"
|
||||
f"Read any natural handwriting.\n"
|
||||
f"This is likely one page out of several in the document, so be sure to preserve any sentences that come from the previous page, or continue onto the next page, exactly as they are.\n"
|
||||
f"If there is no text at all that you think you should read, just output [NO TEXT].\n"
|
||||
|
@ -10,7 +10,7 @@ def filter_by_max_seq_len(example, max_seq_len=4500):
|
||||
return sizes[-1] <= max_seq_len
|
||||
|
||||
|
||||
def prepare_data_for_qwen2_training(example, processor):
|
||||
def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
|
||||
# Prepare messages
|
||||
messages = [
|
||||
{
|
||||
@ -71,6 +71,15 @@ def prepare_data_for_qwen2_training(example, processor):
|
||||
labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]
|
||||
|
||||
# Return as dict, including pixel_values
|
||||
if add_batch_dim:
|
||||
return {
|
||||
"input_ids": input_ids[np.newaxis, ...],
|
||||
"attention_mask": attention_mask[np.newaxis, ...],
|
||||
"labels": labels_full[np.newaxis, ...],
|
||||
"pixel_values": inputs.pixel_values[np.newaxis, ...],
|
||||
"image_grid_thw": inputs["image_grid_thw"]
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
|
@ -143,8 +143,8 @@ def run_train(config: TrainConfig):
|
||||
train_ds = dataset["train"].to_iterable_dataset(num_shards=64)
|
||||
validation_ds = dataset["validation"]
|
||||
|
||||
train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=train_ds.column_names).filter(filter_by_max_seq_len)
|
||||
validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=validation_ds.column_names)
|
||||
train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=train_ds.column_names).filter(filter_by_max_seq_len)
|
||||
validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True)), remove_columns=validation_ds.column_names)
|
||||
|
||||
print(train_ds)
|
||||
print(validation_ds)
|
||||
|
@ -94,7 +94,7 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
||||
|
||||
formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
|
||||
formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500)
|
||||
formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500)
|
||||
|
||||
for entry in formatted_dataset:
|
||||
print(entry)
|
||||
|
Loading…
x
Reference in New Issue
Block a user