Okay, reasonably happy with the dataprep pipeline

This commit is contained in:
Jake Poznanski 2024-09-20 13:04:47 -07:00
parent a47afe5c8d
commit 4eddb1b45f
2 changed files with 36 additions and 3 deletions

View File

@ -2,7 +2,7 @@ import numpy as np
from io import BytesIO
from PIL import Image
import base64
import torch # Make sure to import torch as it's used in the DataCollator
def prepare_data_for_qwen2_training(example, processor):
# Prepare messages
@ -40,6 +40,12 @@ def prepare_data_for_qwen2_training(example, processor):
padding=True,
return_tensors="np"
)
# Append an <|im_end|>\n" to the labels, because this is what it would look like
# if we passed the whole message stream in there
im_end_tokens = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
labels['input_ids'] = np.concatenate([labels['input_ids'][0], im_end_tokens])
labels['input_ids'] = np.expand_dims(labels['input_ids'], axis=0)
# Concatenate input_ids and labels
input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)
@ -79,4 +85,4 @@ class DataCollatorForVisionLanguageModeling:
# Stack pixel_values
batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])
return batch
return batch

View File

@ -42,7 +42,7 @@ class TestDataprep(unittest.TestCase):
}
]
text = processor.apply_chat_template(full_messages, tokenize=False, add_generation_prompt=True)
text = processor.apply_chat_template(full_messages, tokenize=False, add_generation_prompt=False)
# Decode image from base64
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
@ -63,3 +63,30 @@ class TestDataprep(unittest.TestCase):
print(training_inputs)
print(training_inputs["input_ids"].shape)
print("Original tokenization")
print(processor.tokenizer.decode(inference_inputs["input_ids"][0]))
print("\n\n")
print("Assembled tokenization")
print(processor.tokenizer.decode(training_inputs["input_ids"]))
print("\n\n")
# Make sure that the token streams are the same
self.assertEqual(processor.tokenizer.decode(inference_inputs["input_ids"][0]),
processor.tokenizer.decode(training_inputs["input_ids"]))
# Make sure that the labels are masked with -100s properly
# You only want the last assistant generation itself to be not -100, and thus contributing to the loss
# Find the positions where labels are not -100
non_masked_positions = training_inputs['labels'] != -100
# Extract the tokens at those positions
label_tokens = training_inputs['input_ids'][non_masked_positions]
# Decode those tokens
decoded_labels = processor.tokenizer.decode(label_tokens)
assistant_response_with_end = example["response"] + "<|im_end|>\n"
# Assert that the decoded labels match the assistant's response with <|im_end|>\n
self.assertEqual(decoded_labels, assistant_response_with_end)