mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 19:10:19 +00:00
Merge pull request #176 from rasbt/dataset-length-warning
Add assertion about data set length
This commit is contained in:
commit
209a103d66
@ -837,7 +837,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 1,
|
||||||
"id": "2992d779-f9fb-4812-a117-553eb790a5a9",
|
"id": "2992d779-f9fb-4812-a117-553eb790a5a9",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "2992d779-f9fb-4812-a117-553eb790a5a9"
|
"id": "2992d779-f9fb-4812-a117-553eb790a5a9"
|
||||||
@ -861,7 +861,13 @@
|
|||||||
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
|
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])"
|
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n",
|
||||||
|
"\n",
|
||||||
|
"assert train_dataset.max_length <= BASE_CONFIG[\"context_length\"], (\n",
|
||||||
|
" f\"Dataset length {train_dataset.max_length} exceeds model's context \"\n",
|
||||||
|
" f\"length {BASE_CONFIG['context_length']}. Reinitialize data sets with \"\n",
|
||||||
|
" f\"`max_length={BASE_CONFIG['context_length']}`\"\n",
|
||||||
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -373,6 +373,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
|
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
|
||||||
|
|
||||||
|
assert train_dataset.max_length <= BASE_CONFIG["context_length"], (
|
||||||
|
f"Dataset length {train_dataset.max_length} exceeds model's context "
|
||||||
|
f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with "
|
||||||
|
f"`max_length={BASE_CONFIG['context_length']}`"
|
||||||
|
)
|
||||||
|
|
||||||
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
|
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
|
||||||
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
|
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
|
||||||
|
|
||||||
|
@ -548,6 +548,12 @@ if __name__ == "__main__":
|
|||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert train_dataset.max_length <= model.pos_emb.weight.shape[0], (
|
||||||
|
f"Dataset length {train_dataset.max_length} exceeds model's context "
|
||||||
|
f"length {model.pos_emb.weight.shape[0]}. Reinitialize data sets with "
|
||||||
|
f"`max_length={model.pos_emb.weight.shape[0]}`"
|
||||||
|
)
|
||||||
|
|
||||||
###############################
|
###############################
|
||||||
# Train model
|
# Train model
|
||||||
###############################
|
###############################
|
||||||
|
Loading…
x
Reference in New Issue
Block a user