2024-10-10 19:57:51 +00:00
|
|
|
from transformers import (
|
2024-10-16 13:18:24 -07:00
|
|
|
AutoProcessor,
|
|
|
|
DataCollatorForSeq2Seq
|
2024-10-10 19:57:51 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
from pdelfin.train.core.cli import make_cli
|
|
|
|
from pdelfin.train.core.config import TrainConfig
|
2024-10-16 13:18:24 -07:00
|
|
|
from tqdm import tqdm
|
2024-10-10 19:57:51 +00:00
|
|
|
from .utils import (
|
2024-10-16 13:28:12 -07:00
|
|
|
make_dataset, TruncatingCollator
|
2024-10-10 19:57:51 +00:00
|
|
|
)
|
|
|
|
|
2024-10-16 13:18:24 -07:00
|
|
|
from torch.utils.data import DataLoader
|
2024-10-10 19:57:51 +00:00
|
|
|
|
|
|
|
def main():
    """Build the training and validation datasets and print a sample of each.

    Parses a ``TrainConfig`` from the command line via ``make_cli``, loads the
    processor named by ``train_config.model.name_or_path``, and hands both to
    ``make_dataset``. Each returned dataset object and its first entry are
    printed to stdout, followed by a completion message.
    """
    train_config = make_cli(TrainConfig)  # pyright: ignore

    model_name = train_config.model.name_or_path
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    train_dataset, valid_dataset = make_dataset(train_config, processor)

    print("Training dataset........")
    print(train_dataset)
    print(train_dataset[0])
    print("\n\n")

    print("Validation dataset........")
    print(valid_dataset)
    # The validation object is accessed as a mapping: take its first key and
    # show the first entry stored under it.
    split_names = list(valid_dataset.keys())
    first_split = valid_dataset[split_names[0]]
    print(first_split[0])
    print("\n\n")

    # NOTE(review): presumably make_dataset populates the Hugging Face cache
    # as a side effect; that is not visible from this function -- confirm.
    print("Datasets loaded into hugging face cache directory")

    # Disabled diagnostic, kept for reference: iterates the training set one
    # sample at a time and reports the longest "input_ids" sequence seen.
    #
    # data_collator = TruncatingCollator(
    #     max_length=4096
    # )
    #
    # train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=data_collator)
    # max_seen_len = 0
    # for index, entry in tqdm(enumerate(train_dataloader)):
    #     if index == 0:
    #         print(entry)
    #
    #     num_input_tokens = entry["input_ids"].shape[1]
    #     max_seen_len = max(max_seen_len, num_input_tokens)
    #
    # print(max_seen_len)
|
2024-10-10 19:57:51 +00:00
|
|
|
|
|
|
|
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|