mirror of https://github.com/rasbt/LLMs-from-scratch.git

add weight sizes

parent 1c173e4f44
commit 83adc4a2ac
File diff suppressed because one or more lines are too long
@@ -219,7 +219,7 @@ if __name__ == "__main__":
 
     torch.manual_seed(123)
 
-    CHOOSE_MODEL = "gpt2-small"
+    CHOOSE_MODEL = "gpt2-small (124M)"
     INPUT_PROMPT = "Every effort moves"
 
     BASE_CONFIG = {
@@ -230,19 +230,14 @@ if __name__ == "__main__":
     }
 
     model_configs = {
-        "gpt2-small": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
-        "gpt2-medium": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
-        "gpt2-large": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
-        "gpt2-xl": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
+        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
+        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
+        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
+        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
     }
 
-    model_sizes = {
-        "gpt2-small": "124M",
-        "gpt2-medium": "355M",
-        "gpt2-large": "774M",
-        "gpt2-xl": "1558"
-    }
+    model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
 
     BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
 
-    main(BASE_CONFIG, INPUT_PROMPT, model_sizes[CHOOSE_MODEL])
+    main(BASE_CONFIG, INPUT_PROMPT, model_size)
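The second hunk replaces the separate model_sizes lookup table with string parsing, which works because every key in the renamed model_configs now ends in a parenthesized parameter count. A minimal sketch of how the new expression behaves (the chosen value below is only for illustration, not taken from the commit):

    # Illustrative only: how the new model_size expression behaves.
    CHOOSE_MODEL = "gpt2-xl (1558M)"

    # Take the last whitespace-separated token and strip the parentheses.
    model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
    print(model_size)  # -> "1558M"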
@@ -126,13 +126,13 @@
     "\n",
     "# allowed model names\n",
     "model_names = {\n",
-    "    \"gpt2-small\": \"openai-community/gpt2\",  # 124M\n",
-    "    \"gpt2-medium\": \"openai-community/gpt2-medium\",  # 355M\n",
-    "    \"gpt2-large\": \"openai-community/gpt2-large\",  # 774M\n",
-    "    \"gpt2-xl\": \"openai-community/gpt2-xl\"  # 1558M\n",
+    "    \"gpt2-small (124M)\": \"openai-community/gpt2\",\n",
+    "    \"gpt2-medium (355M)\": \"openai-community/gpt2-medium\",\n",
+    "    \"gpt2-large (774M)\": \"openai-community/gpt2-large\",\n",
+    "    \"gpt2-xl (1558M)\": \"openai-community/gpt2-xl\"\n",
     "}\n",
     "\n",
-    "CHOOSE_MODEL = \"gpt2-small\"\n",
+    "CHOOSE_MODEL = \"gpt2-small (124M)\"\n",
     "\n",
     "gpt_hf = GPT2Model.from_pretrained(model_names[CHOOSE_MODEL], cache_dir=\"checkpoints\")\n",
     "gpt_hf.eval()"
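As a rough sanity check of the sizes that this commit bakes into the key names (not part of the commit itself), the parameter count of the downloaded Hugging Face model should be close to the number in parentheses. The sketch below assumes the transformers package and reuses the renamed model_names dict from the notebook cell above:

    # Hypothetical check, not part of the commit: confirm that the parameter
    # count of the loaded model roughly matches the size in the key name.
    from transformers import GPT2Model

    model_names = {
        "gpt2-small (124M)": "openai-community/gpt2",
        "gpt2-medium (355M)": "openai-community/gpt2-medium",
        "gpt2-large (774M)": "openai-community/gpt2-large",
        "gpt2-xl (1558M)": "openai-community/gpt2-xl",
    }

    CHOOSE_MODEL = "gpt2-small (124M)"
    gpt_hf = GPT2Model.from_pretrained(model_names[CHOOSE_MODEL], cache_dir="checkpoints")

    n_params = sum(p.numel() for p in gpt_hf.parameters())
    print(f"{CHOOSE_MODEL}: {n_params / 1e6:.0f}M parameters")  # roughly 124M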