Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-14 09:42:47 +00:00
Cleaned up loader
This commit is contained in:
parent 60338810bc
commit c93ac4a95d
@@ -292,6 +292,7 @@ class Tokenizer(PipelineStep):
     """Tokenizes messages and creates training labels with proper masking."""
     processor: Any  # The model processor (e.g., AutoProcessor)
     masking_index: int = -100
+    end_of_message_token: str = "<|im_end|>"  # Configurable, defaults to Qwen format

     def __call__(self, sample: Sample) -> Sample:
         """Tokenize messages and create labels for training."""
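The one-line addition above turns the end-of-message marker into a configurable field instead of a hardcoded Qwen string. A minimal sketch, not part of the commit, of how such a dataclass-style step can be parameterized; TokenizerConfigSketch and the "</s>" override are hypothetical, only the field names and defaults mirror the diff:

from dataclasses import dataclass
from typing import Any

@dataclass
class TokenizerConfigSketch:
    # Hypothetical stand-in for the Tokenizer(PipelineStep) fields shown above.
    processor: Any
    masking_index: int = -100
    end_of_message_token: str = "<|im_end|>"  # Qwen default, per the diff

# Qwen-style default, plus a hypothetical override for a model whose chat
# template ends turns with "</s>" instead.
qwen_step = TokenizerConfigSketch(processor=None)
other_step = TokenizerConfigSketch(processor=None, end_of_message_token="</s>")
print(qwen_step.end_of_message_token, other_step.end_of_message_token)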
@@ -323,9 +324,9 @@ class Tokenizer(PipelineStep):
         # Get labels by tokenizing the output text
         labels = self.processor(text=[response], padding=True, return_tensors="np")

-        # Append <|im_end|>\n to the labels
-        im_end_tokens = self.processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
-        im_end_tokens = np.array(im_end_tokens, dtype=inputs.input_ids.dtype)
+        # Append end-of-message token to the labels
+        end_tokens = self.processor.tokenizer(self.end_of_message_token, add_special_tokens=False)["input_ids"]
+        end_tokens = np.array(end_tokens, dtype=inputs.input_ids.dtype)

         # Handle the case where labels['input_ids'] is empty
         if labels["input_ids"].shape[1] == 0:
@@ -333,7 +334,7 @@ class Tokenizer(PipelineStep):
         else:
             labels_input_ids_0 = labels["input_ids"][0].astype(inputs.input_ids.dtype)

-        labels["input_ids"] = np.concatenate([labels_input_ids_0, im_end_tokens])
+        labels["input_ids"] = np.concatenate([labels_input_ids_0, end_tokens])
         labels["input_ids"] = np.expand_dims(labels["input_ids"], axis=0)

         # Concatenate input_ids and labels
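The two hunks above switch the label construction to the configurable token: tokenize the response, append the end-of-message ids, fall back gracefully when the response tokenizes to nothing, and keep prompt positions masked at -100 so the loss only covers the response. A self-contained sketch of that flow using NumPy only; toy_tokenize, build_labels, and the sample strings are hypothetical stand-ins for the real processor:

import numpy as np

MASKING_INDEX = -100

def toy_tokenize(text):
    # Hypothetical stand-in for processor.tokenizer(text, add_special_tokens=False)["input_ids"].
    return [ord(c) % 1000 for c in text]

def build_labels(prompt_ids, response, end_of_message_token):
    response_ids = np.array(toy_tokenize(response), dtype=prompt_ids.dtype)
    end_tokens = np.array(toy_tokenize(end_of_message_token), dtype=prompt_ids.dtype)

    # Mirror of the empty-response guard: if the response tokenizes to nothing,
    # the label sequence is just the end-of-message ids.
    if response_ids.shape[0] == 0:
        label_ids = end_tokens
    else:
        label_ids = np.concatenate([response_ids, end_tokens])

    # Prompt positions are masked out; only response (plus end token) positions
    # contribute to the training loss.
    prompt_mask = np.full(prompt_ids.shape[0], MASKING_INDEX, dtype=prompt_ids.dtype)
    return np.concatenate([prompt_mask, label_ids])

prompt_ids = np.array(toy_tokenize("User: transcribe this page"), dtype=np.int64)
labels = build_labels(prompt_ids, "A scanned receipt.", "<|im_end|>")
print(labels)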
@@ -519,6 +520,29 @@ if __name__ == "__main__":

             print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")

+        # Calculate and show token statistics after the table
+        print(f"\nToken statistics:")
+
+        # Count consecutive high-value tokens that represent the image
+        # Qwen uses tokens like 151859, 151860, etc. for image patches
+        image_token_threshold = 151000  # Typical threshold for Qwen image tokens
+        image_token_count = np.sum(input_ids > image_token_threshold)
+
+        # Calculate prompt tokens (everything masked)
+        prompt_token_count = masked_count
+
+        # Calculate output tokens (everything not masked)
+        output_token_count = total_count - masked_count
+
+        # Calculate non-image prompt tokens
+        non_image_prompt_tokens = prompt_token_count - image_token_count
+
+        print(f" Image tokens: {image_token_count}")
+        print(f" Prompt tokens (total): {prompt_token_count}")
+        print(f" Prompt tokens (non-image): {non_image_prompt_tokens}")
+        print(f" Output tokens: {output_token_count}")
+        print(f" Total sequence length: {total_count}")

     except ImportError as e:
         print(f"\nCould not import transformers: {e}")
         print("Install with: pip install transformers")
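The block added to the __main__ section derives its statistics from the id arrays alone: image tokens are counted by the Qwen id threshold, prompt tokens are the masked (-100) label positions, and output tokens are whatever remains unmasked. A standalone sketch with synthetic arrays; the ids below are invented, and only the 151000 threshold and the -100 masking value come from the diff:

import numpy as np

IMAGE_TOKEN_THRESHOLD = 151000  # per the diff: typical cutoff for Qwen image-patch token ids
MASKING_INDEX = -100

# Synthetic example: 4 image-patch ids, 3 text prompt ids, 2 response ids.
input_ids = np.array([151859, 151859, 151860, 151860, 101, 102, 103, 2001, 2002], dtype=np.int64)
labels = np.array([MASKING_INDEX] * 7 + [2001, 2002], dtype=np.int64)

image_token_count = int(np.sum(input_ids > IMAGE_TOKEN_THRESHOLD))
masked_count = int(np.sum(labels == MASKING_INDEX))   # prompt tokens (total)
total_count = labels.shape[0]
output_token_count = total_count - masked_count       # unmasked = supervised response tokens
non_image_prompt_tokens = masked_count - image_token_count

print(f" Image tokens: {image_token_count}")
print(f" Prompt tokens (total): {masked_count}")
print(f" Prompt tokens (non-image): {non_image_prompt_tokens}")
print(f" Output tokens: {output_token_count}")
print(f" Total sequence length: {total_count}")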