Merge pull request #176 from rasbt/dataset-length-warning

Add assertion about data set length
This commit is contained in:
Sebastian Raschka 2024-05-23 07:58:47 -04:00 committed by GitHub
commit 209a103d66
3 changed files with 20 additions and 2 deletions

View File

@ -837,7 +837,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 1,
"id": "2992d779-f9fb-4812-a117-553eb790a5a9",
"metadata": {
"id": "2992d779-f9fb-4812-a117-553eb790a5a9"
@ -861,7 +861,13 @@
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
"}\n",
"\n",
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])"
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n",
"\n",
"assert train_dataset.max_length <= BASE_CONFIG[\"context_length\"], (\n",
" f\"Dataset length {train_dataset.max_length} exceeds model's context \"\n",
" f\"length {BASE_CONFIG['context_length']}. Reinitialize data sets with \"\n",
" f\"`max_length={BASE_CONFIG['context_length']}`\"\n",
")"
]
},
{

View File

@ -373,6 +373,12 @@ if __name__ == "__main__":
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
assert train_dataset.max_length <= BASE_CONFIG["context_length"], (
f"Dataset length {train_dataset.max_length} exceeds model's context "
f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with "
f"`max_length={BASE_CONFIG['context_length']}`"
)
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

View File

@ -548,6 +548,12 @@ if __name__ == "__main__":
drop_last=False,
)
assert train_dataset.max_length <= model.pos_emb.weight.shape[0], (
f"Dataset length {train_dataset.max_length} exceeds model's context "
f"length {model.pos_emb.weight.shape[0]}. Reinitialize data sets with "
f"`max_length={model.pos_emb.weight.shape[0]}`"
)
###############################
# Train model
###############################