Cleaned up loader

This commit is contained in:
Jake Poznanski 2025-06-12 03:27:39 +00:00
parent 60338810bc
commit c93ac4a95d

View File

@ -292,6 +292,7 @@ class Tokenizer(PipelineStep):
"""Tokenizes messages and creates training labels with proper masking."""
processor: Any # The model processor (e.g., AutoProcessor)
masking_index: int = -100
end_of_message_token: str = "<|im_end|>" # Configurable, defaults to Qwen format
def __call__(self, sample: Sample) -> Sample:
"""Tokenize messages and create labels for training."""
@ -323,9 +324,9 @@ class Tokenizer(PipelineStep):
# Get labels by tokenizing the output text
labels = self.processor(text=[response], padding=True, return_tensors="np")
# Append <|im_end|>\n to the labels
im_end_tokens = self.processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
im_end_tokens = np.array(im_end_tokens, dtype=inputs.input_ids.dtype)
# Append end-of-message token to the labels
end_tokens = self.processor.tokenizer(self.end_of_message_token, add_special_tokens=False)["input_ids"]
end_tokens = np.array(end_tokens, dtype=inputs.input_ids.dtype)
# Handle the case where labels['input_ids'] is empty
if labels["input_ids"].shape[1] == 0:
@ -333,7 +334,7 @@ class Tokenizer(PipelineStep):
else:
labels_input_ids_0 = labels["input_ids"][0].astype(inputs.input_ids.dtype)
labels["input_ids"] = np.concatenate([labels_input_ids_0, im_end_tokens])
labels["input_ids"] = np.concatenate([labels_input_ids_0, end_tokens])
labels["input_ids"] = np.expand_dims(labels["input_ids"], axis=0)
# Concatenate input_ids and labels
@ -519,6 +520,29 @@ if __name__ == "__main__":
print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")
# Calculate and show token statistics after the table
print(f"\nToken statistics:")
# Count consecutive high-value tokens that represent the image
# Qwen uses tokens like 151859, 151860, etc. for image patches
image_token_threshold = 151000 # Typical threshold for Qwen image tokens
image_token_count = np.sum(input_ids > image_token_threshold)
# Calculate prompt tokens (everything masked)
prompt_token_count = masked_count
# Calculate output tokens (everything not masked)
output_token_count = total_count - masked_count
# Calculate non-image prompt tokens
non_image_prompt_tokens = prompt_token_count - image_token_count
print(f" Image tokens: {image_token_count}")
print(f" Prompt tokens (total): {prompt_token_count}")
print(f" Prompt tokens (non-image): {non_image_prompt_tokens}")
print(f" Output tokens: {output_token_count}")
print(f" Total sequence length: {total_count}")
except ImportError as e:
print(f"\nCould not import transformers: {e}")
print("Install with: pip install transformers")