mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 19:10:19 +00:00
Merge pull request #176 from rasbt/dataset-length-warning
Add assertion about data set length
This commit is contained in:
commit
209a103d66
@ -837,7 +837,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 1,
|
||||
"id": "2992d779-f9fb-4812-a117-553eb790a5a9",
|
||||
"metadata": {
|
||||
"id": "2992d779-f9fb-4812-a117-553eb790a5a9"
|
||||
@ -861,7 +861,13 @@
|
||||
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])"
|
||||
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n",
|
||||
"\n",
|
||||
"assert train_dataset.max_length <= BASE_CONFIG[\"context_length\"], (\n",
|
||||
" f\"Dataset length {train_dataset.max_length} exceeds model's context \"\n",
|
||||
" f\"length {BASE_CONFIG['context_length']}. Reinitialize data sets with \"\n",
|
||||
" f\"`max_length={BASE_CONFIG['context_length']}`\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -373,6 +373,12 @@ if __name__ == "__main__":
|
||||
|
||||
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
|
||||
|
||||
assert train_dataset.max_length <= BASE_CONFIG["context_length"], (
|
||||
f"Dataset length {train_dataset.max_length} exceeds model's context "
|
||||
f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with "
|
||||
f"`max_length={BASE_CONFIG['context_length']}`"
|
||||
)
|
||||
|
||||
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
|
||||
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
|
||||
|
||||
|
@ -548,6 +548,12 @@ if __name__ == "__main__":
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
assert train_dataset.max_length <= model.pos_emb.weight.shape[0], (
|
||||
f"Dataset length {train_dataset.max_length} exceeds model's context "
|
||||
f"length {model.pos_emb.weight.shape[0]}. Reinitialize data sets with "
|
||||
f"`max_length={model.pos_emb.weight.shape[0]}`"
|
||||
)
|
||||
|
||||
###############################
|
||||
# Train model
|
||||
###############################
|
||||
|
Loading…
x
Reference in New Issue
Block a user