mirror of
https://github.com/allenai/olmocr.git
synced 2025-06-27 04:00:02 +00:00
Okay, reasonably happy with the dataprep pipeline
This commit is contained in:
parent
a47afe5c8d
commit
4eddb1b45f
@ -2,7 +2,7 @@ import numpy as np
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import base64
|
import base64
|
||||||
|
import torch # Make sure to import torch as it's used in the DataCollator
|
||||||
|
|
||||||
def prepare_data_for_qwen2_training(example, processor):
|
def prepare_data_for_qwen2_training(example, processor):
|
||||||
# Prepare messages
|
# Prepare messages
|
||||||
@ -40,6 +40,12 @@ def prepare_data_for_qwen2_training(example, processor):
|
|||||||
padding=True,
|
padding=True,
|
||||||
return_tensors="np"
|
return_tensors="np"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Append "<|im_end|>\n" to the labels, because this is what it would look like
|
||||||
|
# if we passed the whole message stream in there
|
||||||
|
im_end_tokens = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
|
||||||
|
labels['input_ids'] = np.concatenate([labels['input_ids'][0], im_end_tokens])
|
||||||
|
labels['input_ids'] = np.expand_dims(labels['input_ids'], axis=0)
|
||||||
|
|
||||||
# Concatenate input_ids and labels
|
# Concatenate input_ids and labels
|
||||||
input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)
|
input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)
|
||||||
@ -79,4 +85,4 @@ class DataCollatorForVisionLanguageModeling:
|
|||||||
# Stack pixel_values
|
# Stack pixel_values
|
||||||
batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])
|
batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
@ -42,7 +42,7 @@ class TestDataprep(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
text = processor.apply_chat_template(full_messages, tokenize=False, add_generation_prompt=True)
|
text = processor.apply_chat_template(full_messages, tokenize=False, add_generation_prompt=False)
|
||||||
|
|
||||||
# Decode image from base64
|
# Decode image from base64
|
||||||
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
|
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
|
||||||
@ -63,3 +63,30 @@ class TestDataprep(unittest.TestCase):
|
|||||||
print(training_inputs)
|
print(training_inputs)
|
||||||
print(training_inputs["input_ids"].shape)
|
print(training_inputs["input_ids"].shape)
|
||||||
|
|
||||||
|
print("Original tokenization")
|
||||||
|
print(processor.tokenizer.decode(inference_inputs["input_ids"][0]))
|
||||||
|
print("\n\n")
|
||||||
|
|
||||||
|
print("Assembled tokenization")
|
||||||
|
print(processor.tokenizer.decode(training_inputs["input_ids"]))
|
||||||
|
print("\n\n")
|
||||||
|
|
||||||
|
# Make sure that the token streams are the same
|
||||||
|
self.assertEqual(processor.tokenizer.decode(inference_inputs["input_ids"][0]),
|
||||||
|
processor.tokenizer.decode(training_inputs["input_ids"]))
|
||||||
|
|
||||||
|
# Make sure that the labels are masked with -100s properly
|
||||||
|
# Only the last assistant generation itself should be unmasked (not -100), so that it alone contributes to the loss
|
||||||
|
|
||||||
|
# Find the positions where labels are not -100
|
||||||
|
non_masked_positions = training_inputs['labels'] != -100
|
||||||
|
|
||||||
|
# Extract the tokens at those positions
|
||||||
|
label_tokens = training_inputs['input_ids'][non_masked_positions]
|
||||||
|
|
||||||
|
# Decode those tokens
|
||||||
|
decoded_labels = processor.tokenizer.decode(label_tokens)
|
||||||
|
assistant_response_with_end = example["response"] + "<|im_end|>\n"
|
||||||
|
|
||||||
|
# Assert that the decoded labels match the assistant's response with <|im_end|>\n
|
||||||
|
self.assertEqual(decoded_labels, assistant_response_with_end)
|
Loading…
x
Reference in New Issue
Block a user