From c93ac4a95ddfa16deff8f16e6f055c17c4d6fd11 Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Thu, 12 Jun 2025 03:27:39 +0000 Subject: [PATCH] Cleaned up loader --- olmocr/train/dataloader.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py index 49fbe36..8b97578 100644 --- a/olmocr/train/dataloader.py +++ b/olmocr/train/dataloader.py @@ -292,6 +292,7 @@ class Tokenizer(PipelineStep): """Tokenizes messages and creates training labels with proper masking.""" processor: Any # The model processor (e.g., AutoProcessor) masking_index: int = -100 + end_of_message_token: str = "<|im_end|>" # Configurable, defaults to Qwen format def __call__(self, sample: Sample) -> Sample: """Tokenize messages and create labels for training.""" @@ -323,9 +324,9 @@ class Tokenizer(PipelineStep): # Get labels by tokenizing the output text labels = self.processor(text=[response], padding=True, return_tensors="np") - # Append <|im_end|>\n to the labels - im_end_tokens = self.processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"] - im_end_tokens = np.array(im_end_tokens, dtype=inputs.input_ids.dtype) + # Append end-of-message token to the labels + end_tokens = self.processor.tokenizer(self.end_of_message_token, add_special_tokens=False)["input_ids"] + end_tokens = np.array(end_tokens, dtype=inputs.input_ids.dtype) # Handle the case where labels['input_ids'] is empty if labels["input_ids"].shape[1] == 0: @@ -333,7 +334,7 @@ class Tokenizer(PipelineStep): else: labels_input_ids_0 = labels["input_ids"][0].astype(inputs.input_ids.dtype) - labels["input_ids"] = np.concatenate([labels_input_ids_0, im_end_tokens]) + labels["input_ids"] = np.concatenate([labels_input_ids_0, end_tokens]) labels["input_ids"] = np.expand_dims(labels["input_ids"], axis=0) # Concatenate input_ids and labels @@ -519,6 +520,29 @@ if __name__ == "__main__": print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}") + # Calculate and show token statistics after the table + print(f"\nToken statistics:") + + # Count consecutive high-value tokens that represent the image + # Qwen uses tokens like 151859, 151860, etc. for image patches + image_token_threshold = 151000 # Typical threshold for Qwen image tokens + image_token_count = np.sum(input_ids > image_token_threshold) + + # Calculate prompt tokens (everything masked) + prompt_token_count = masked_count + + # Calculate output tokens (everything not masked) + output_token_count = total_count - masked_count + + # Calculate non-image prompt tokens + non_image_prompt_tokens = prompt_token_count - image_token_count + + print(f" Image tokens: {image_token_count}") + print(f" Prompt tokens (total): {prompt_token_count}") + print(f" Prompt tokens (non-image): {non_image_prompt_tokens}") + print(f" Output tokens: {output_token_count}") + print(f" Total sequence length: {total_count}") + except ImportError as e: print(f"\nCould not import transformers: {e}") print("Install with: pip install transformers")