From c35cf65dbf8e9bd790b837ff17ea1d590859d8f3 Mon Sep 17 00:00:00 2001 From: rasbt Date: Thu, 23 May 2024 06:50:43 -0500 Subject: [PATCH] add assertion about data set length --- ch06/01_main-chapter-code/ch06.ipynb | 10 ++++++++-- ch06/01_main-chapter-code/gpt-class-finetune.py | 6 ++++++ .../additional-experiments.py | 6 ++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/ch06/01_main-chapter-code/ch06.ipynb b/ch06/01_main-chapter-code/ch06.ipynb index 904db06..c0cba26 100644 --- a/ch06/01_main-chapter-code/ch06.ipynb +++ b/ch06/01_main-chapter-code/ch06.ipynb @@ -837,7 +837,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 1, "id": "2992d779-f9fb-4812-a117-553eb790a5a9", "metadata": { "id": "2992d779-f9fb-4812-a117-553eb790a5a9" @@ -861,7 +861,13 @@ " \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n", "}\n", "\n", - "BASE_CONFIG.update(model_configs[CHOOSE_MODEL])" + "BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n", + "\n", + "assert train_dataset.max_length <= BASE_CONFIG[\"context_length\"], (\n", + " f\"Dataset length {train_dataset.max_length} exceeds model's context \"\n", + " f\"length {BASE_CONFIG['context_length']}. Reinitialize data sets with \"\n", + " f\"`max_length={BASE_CONFIG['context_length']}`\"\n", + ")" ] }, { diff --git a/ch06/01_main-chapter-code/gpt-class-finetune.py b/ch06/01_main-chapter-code/gpt-class-finetune.py index c440038..bc5666b 100644 --- a/ch06/01_main-chapter-code/gpt-class-finetune.py +++ b/ch06/01_main-chapter-code/gpt-class-finetune.py @@ -373,6 +373,12 @@ if __name__ == "__main__": BASE_CONFIG.update(model_configs[CHOOSE_MODEL]) + assert train_dataset.max_length <= BASE_CONFIG["context_length"], ( + f"Dataset length {train_dataset.max_length} exceeds model's context " + f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with " + f"`max_length={BASE_CONFIG['context_length']}`" + ) + model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")") settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2") diff --git a/ch06/02_bonus_additional-experiments/additional-experiments.py b/ch06/02_bonus_additional-experiments/additional-experiments.py index 68deae3..dd3d559 100644 --- a/ch06/02_bonus_additional-experiments/additional-experiments.py +++ b/ch06/02_bonus_additional-experiments/additional-experiments.py @@ -548,6 +548,12 @@ if __name__ == "__main__": drop_last=False, ) + assert train_dataset.max_length <= model.pos_emb.weight.shape[0], ( + f"Dataset length {train_dataset.max_length} exceeds model's context " + f"length {model.pos_emb.weight.shape[0]}. Reinitialize data sets with " + f"`max_length={model.pos_emb.weight.shape[0]}`" + ) + ############################### # Train model ###############################