Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-14 17:52:53 +00:00
Cleaned up loader

This commit is contained in: parent 60338810bc, commit c93ac4a95d
@@ -292,6 +292,7 @@ class Tokenizer(PipelineStep):
     """Tokenizes messages and creates training labels with proper masking."""
     processor: Any  # The model processor (e.g., AutoProcessor)
     masking_index: int = -100
+    end_of_message_token: str = "<|im_end|>"  # Configurable, defaults to Qwen format

     def __call__(self, sample: Sample) -> Sample:
         """Tokenize messages and create labels for training."""
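For context: this hunk makes the end-of-message marker a configurable field instead of a value hard-coded to the Qwen chat format. Not part of the commit, a minimal sketch of how the field might be set, assuming the Tokenizer above is a dataclass-style PipelineStep constructed with keyword arguments (the checkpoint name and constructor call are illustrative assumptions):

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # assumed checkpoint

    # Default keeps the Qwen chat format.
    step = Tokenizer(processor=processor)

    # Other chat templates can now supply their own end-of-message marker,
    # e.g. a hypothetical model whose turns end with "</s>".
    step = Tokenizer(processor=processor, end_of_message_token="</s>")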
@@ -323,9 +324,9 @@ class Tokenizer(PipelineStep):
         # Get labels by tokenizing the output text
         labels = self.processor(text=[response], padding=True, return_tensors="np")

-        # Append <|im_end|>\n to the labels
-        im_end_tokens = self.processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
-        im_end_tokens = np.array(im_end_tokens, dtype=inputs.input_ids.dtype)
+        # Append end-of-message token to the labels
+        end_tokens = self.processor.tokenizer(self.end_of_message_token, add_special_tokens=False)["input_ids"]
+        end_tokens = np.array(end_tokens, dtype=inputs.input_ids.dtype)

         # Handle the case where labels['input_ids'] is empty
         if labels["input_ids"].shape[1] == 0:
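A note on the renamed call: without a return_tensors argument, a Hugging Face tokenizer returns plain Python lists, which is why the result is immediately cast to a numpy array with the dtype of inputs.input_ids. Roughly (a sketch reusing the processor from the example above, not repo code):

    ids = processor.tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"]
    # -> a short Python list of token IDs (a single special-token ID for Qwen-style
    #    tokenizers); casting it to inputs.input_ids.dtype keeps the later
    #    np.concatenate with the tokenized labels dtype-consistent.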
@@ -333,7 +334,7 @@ class Tokenizer(PipelineStep):
         else:
             labels_input_ids_0 = labels["input_ids"][0].astype(inputs.input_ids.dtype)

-            labels["input_ids"] = np.concatenate([labels_input_ids_0, im_end_tokens])
+            labels["input_ids"] = np.concatenate([labels_input_ids_0, end_tokens])
             labels["input_ids"] = np.expand_dims(labels["input_ids"], axis=0)

         # Concatenate input_ids and labels
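For orientation, the label construction these two hunks touch boils down to: tokenize the response, append the end-of-message token, and mask everything else with masking_index so only the response contributes to the loss. A simplified sketch under those assumptions (variable names follow the diff; the real __call__ also handles attention masks, pixel values, and the empty-response branch shown above):

    import numpy as np

    MASKING_INDEX = -100  # matches masking_index above; positions with this value are ignored by the loss

    def build_labels(prompt_ids: np.ndarray, response_ids: np.ndarray, end_tokens: np.ndarray):
        """Concatenate prompt + response + end-of-message token and mask the prompt (sketch only)."""
        input_ids = np.concatenate([prompt_ids, response_ids, end_tokens])
        labels = np.concatenate([
            np.full_like(prompt_ids, MASKING_INDEX),  # prompt tokens do not contribute to the loss
            response_ids,                             # the model learns to produce the response...
            end_tokens,                               # ...and to emit the end-of-message token
        ])
        return input_ids, labels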
@@ -519,6 +520,29 @@ if __name__ == "__main__":

             print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")

+        # Calculate and show token statistics after the table
+        print(f"\nToken statistics:")
+
+        # Count consecutive high-value tokens that represent the image
+        # Qwen uses tokens like 151859, 151860, etc. for image patches
+        image_token_threshold = 151000  # Typical threshold for Qwen image tokens
+        image_token_count = np.sum(input_ids > image_token_threshold)
+
+        # Calculate prompt tokens (everything masked)
+        prompt_token_count = masked_count
+
+        # Calculate output tokens (everything not masked)
+        output_token_count = total_count - masked_count
+
+        # Calculate non-image prompt tokens
+        non_image_prompt_tokens = prompt_token_count - image_token_count
+
+        print(f"  Image tokens: {image_token_count}")
+        print(f"  Prompt tokens (total): {prompt_token_count}")
+        print(f"  Prompt tokens (non-image): {non_image_prompt_tokens}")
+        print(f"  Output tokens: {output_token_count}")
+        print(f"  Total sequence length: {total_count}")
+
     except ImportError as e:
         print(f"\nCould not import transformers: {e}")
         print("Install with: pip install transformers")
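The new debug output relies on a simple ID-threshold heuristic: any token ID above 151000 is counted as an image-patch token, which matches the commit's Qwen-specific assumption rather than a general rule. The same bookkeeping as a standalone helper (a sketch, assuming input_ids and labels are 1-D numpy arrays like the ones printed in the table above):

    import numpy as np

    def token_stats(input_ids: np.ndarray, labels: np.ndarray,
                    masking_index: int = -100,
                    image_token_threshold: int = 151000) -> dict:
        """Summarize a tokenized sample the same way the __main__ debug block does (sketch)."""
        total_count = len(input_ids)
        masked_count = int(np.sum(labels == masking_index))                  # prompt positions, masked from the loss
        image_token_count = int(np.sum(input_ids > image_token_threshold))  # Qwen-style image patch IDs
        return {
            "image_tokens": image_token_count,
            "prompt_tokens_total": masked_count,
            "prompt_tokens_non_image": masked_count - image_token_count,
            "output_tokens": total_count - masked_count,
            "total_sequence_length": total_count,
        }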