olmocr/pdelfin/train/loaddataset.py

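"""Sanity-check dataset loading for pdelfin training.

Loads the train/validation datasets described by a TrainConfig, prints a few
sample entries, then iterates the training DataLoader to report the longest
tokenized sequence seen.
"""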
from transformers import (
    AutoProcessor,
    DataCollatorForSeq2Seq,
)
from pdelfin.train.core.cli import make_cli
from pdelfin.train.core.config import TrainConfig
from tqdm import tqdm
from .utils import (
    make_dataset,
)
from torch.utils.data import DataLoader


def main():
    # Parse a TrainConfig from the command line
    train_config = make_cli(TrainConfig)  # pyright: ignore
    processor = AutoProcessor.from_pretrained(train_config.model.name_or_path)

    # Build the training and validation datasets for this config
    train_dataset, valid_dataset = make_dataset(train_config, processor)

    print("Training dataset........")
    print(train_dataset)
    print(train_dataset[0])
    print("\n\n")

    print("Validation dataset........")
    print(valid_dataset)
    # valid_dataset is a dict of datasets keyed by name; show the first sample of the first one
    print(valid_dataset[list(valid_dataset.keys())[0]][0])
    print("\n\n")

    print("Datasets loaded into the Hugging Face cache directory")

    # Collate with the processor's tokenizer, without padding, so each
    # single-item batch keeps its true tokenized length
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=processor.tokenizer,  # use the processor's tokenizer
        max_length=4096,
        padding=False,
    )
    train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=data_collator)

    # Scan the full training set and report the longest input sequence seen
    max_seen_len = 0
    for entry in tqdm(train_dataloader):
        num_input_tokens = entry["input_ids"].shape[1]
        max_seen_len = max(max_seen_len, num_input_tokens)

    print(max_seen_len)


if __name__ == "__main__":
    main()