mirror of https://github.com/allenai/olmocr.git
Okay, reasonably happy with the dataprep pipeline
commit 4eddb1b45f
parent a47afe5c8d
@@ -2,7 +2,7 @@ import numpy as np
 from io import BytesIO
 from PIL import Image
 import base64
-
+import torch # Make sure to import torch as it's used in the DataCollator

 def prepare_data_for_qwen2_training(example, processor):
     # Prepare messages
@@ -40,6 +40,12 @@ def prepare_data_for_qwen2_training(example, processor):
         padding=True,
         return_tensors="np"
     )

+    # Append an "<|im_end|>\n" to the labels, because this is what it would look like
+    # if we passed the whole message stream in there
+    im_end_tokens = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
+    labels['input_ids'] = np.concatenate([labels['input_ids'][0], im_end_tokens])
+    labels['input_ids'] = np.expand_dims(labels['input_ids'], axis=0)
+
     # Concatenate input_ids and labels
     input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)
@@ -79,4 +85,4 @@ class DataCollatorForVisionLanguageModeling:
         # Stack pixel_values
         batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])

-        return batch
+        return batch
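For context on how the two pieces in this file fit together, here is a minimal sketch of the intended usage. It is not part of the diff; the checkpoint id, the toy example fields, and the collator's constructor arguments are assumptions.

# Sketch only: prepare_data_for_qwen2_training and DataCollatorForVisionLanguageModeling
# are assumed to be imported from the dataprep module edited above (path not shown here).
import base64
from io import BytesIO

from datasets import Dataset
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # assumed checkpoint

# One toy row using the fields the code above references; a real dataset would carry
# whatever additional prompt fields prepare_data_for_qwen2_training expects.
buf = BytesIO()
Image.new("RGB", (64, 64), "white").save(buf, format="PNG")
dataset = Dataset.from_list([{
    "input_prompt_image_base64": base64.b64encode(buf.getvalue()).decode(),
    "response": "Toy assistant answer",
}])

# Map raw rows to padded numpy arrays, then batch them into torch tensors.
train_dataset = dataset.map(lambda ex: prepare_data_for_qwen2_training(ex, processor))
collator = DataCollatorForVisionLanguageModeling()  # hypothetical: real constructor args not shown in this hunk
loader = DataLoader(train_dataset, batch_size=1, collate_fn=collator)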
@@ -42,7 +42,7 @@ class TestDataprep(unittest.TestCase):
             }
         ]

-        text = processor.apply_chat_template(full_messages, tokenize=False, add_generation_prompt=True)
+        text = processor.apply_chat_template(full_messages, tokenize=False, add_generation_prompt=False)

         # Decode image from base64
         main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
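The add_generation_prompt flag is what separates an inference-style prompt from a fully assembled training transcript; a small illustration follows (the exact template text depends on the processor's chat template, and the checkpoint id is an assumption).

# Illustration only, not part of the diff.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # assumed checkpoint
messages = [{"role": "user", "content": "Hello"}]

# True appends the opening of a new assistant turn, which is what you want at
# inference time when the model should continue the text.
with_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# False stops after the last message, which is what you want when the messages
# already contain the assistant's reply, as in the training-data test above.
without_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

print(with_prompt)
print(without_prompt)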
@@ -63,3 +63,30 @@ class TestDataprep(unittest.TestCase):
         print(training_inputs)
         print(training_inputs["input_ids"].shape)

+        print("Original tokenization")
+        print(processor.tokenizer.decode(inference_inputs["input_ids"][0]))
+        print("\n\n")
+
+        print("Assembled tokenization")
+        print(processor.tokenizer.decode(training_inputs["input_ids"]))
+        print("\n\n")
+
+        # Make sure that the token streams are the same
+        self.assertEqual(processor.tokenizer.decode(inference_inputs["input_ids"][0]),
+                         processor.tokenizer.decode(training_inputs["input_ids"]))
+
+        # Make sure that the labels are masked with -100s properly
+        # You only want the last assistant generation itself to be not -100, and thus contributing to the loss
+
+        # Find the positions where labels are not -100
+        non_masked_positions = training_inputs['labels'] != -100
+
+        # Extract the tokens at those positions
+        label_tokens = training_inputs['input_ids'][non_masked_positions]
+
+        # Decode those tokens
+        decoded_labels = processor.tokenizer.decode(label_tokens)
+        assistant_response_with_end = example["response"] + "<|im_end|>\n"
+
+        # Assert that the decoded labels match the assistant's response with <|im_end|>\n
+        self.assertEqual(decoded_labels, assistant_response_with_end)