diff --git a/ch06/03_bonus_imdb-classification/train-bert-hf.py b/ch06/03_bonus_imdb-classification/train-bert-hf.py
index ef3773a..aa207c9 100644
--- a/ch06/03_bonus_imdb-classification/train-bert-hf.py
+++ b/ch06/03_bonus_imdb-classification/train-bert-hf.py
@@ -236,6 +236,12 @@ if __name__ == "__main__":
 
     pad_token_id = tokenizer.encode(tokenizer.pad_token)
 
+    # The dataset paths below are built with the `/` operator, which a plain
+    # str does not support ("." / "train.csv" raises TypeError), so base_path
+    # must be a pathlib.Path. The CSV files are looked up relative to the
+    # current working directory.
+    from pathlib import Path
+
+    base_path = Path(".")
+
     train_dataset = IMDBDataset(base_path / "train.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
     val_dataset = IMDBDataset(base_path / "val.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
     test_dataset = IMDBDataset(base_path / "test.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)