mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-02 19:00:14 +00:00
add weight sizes
This commit is contained in:
parent
1c173e4f44
commit
83adc4a2ac
File diff suppressed because one or more lines are too long
@ -219,7 +219,7 @@ if __name__ == "__main__":
|
||||
|
||||
torch.manual_seed(123)
|
||||
|
||||
CHOOSE_MODEL = "gpt2-small"
|
||||
CHOOSE_MODEL = "gpt2-small (124M)"
|
||||
INPUT_PROMPT = "Every effort moves"
|
||||
|
||||
BASE_CONFIG = {
|
||||
@ -230,19 +230,14 @@ if __name__ == "__main__":
|
||||
}
|
||||
|
||||
model_configs = {
|
||||
"gpt2-small": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
|
||||
"gpt2-medium": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
|
||||
"gpt2-large": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
|
||||
"gpt2-xl": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
|
||||
"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
|
||||
"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
|
||||
"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
|
||||
"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
|
||||
}
|
||||
|
||||
model_sizes = {
|
||||
"gpt2-small": "124M",
|
||||
"gpt2-medium": "355M",
|
||||
"gpt2-large": "774M",
|
||||
"gpt2-xl": "1558"
|
||||
}
|
||||
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
|
||||
|
||||
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
|
||||
|
||||
main(BASE_CONFIG, INPUT_PROMPT, model_sizes[CHOOSE_MODEL])
|
||||
main(BASE_CONFIG, INPUT_PROMPT, model_size)
|
||||
|
||||
@ -126,13 +126,13 @@
|
||||
"\n",
|
||||
"# allowed model names\n",
|
||||
"model_names = {\n",
|
||||
" \"gpt2-small\": \"openai-community/gpt2\", # 124M\n",
|
||||
" \"gpt2-medium\": \"openai-community/gpt2-medium\", # 355M\n",
|
||||
" \"gpt2-large\": \"openai-community/gpt2-large\", # 774M\n",
|
||||
" \"gpt2-xl\": \"openai-community/gpt2-xl\" # 1558M\n",
|
||||
" \"gpt2-small (124M)\": \"openai-community/gpt2\",\n",
|
||||
" \"gpt2-medium (355M)\": \"openai-community/gpt2-medium\",\n",
|
||||
" \"gpt2-large (774M)\": \"openai-community/gpt2-large\",\n",
|
||||
" \"gpt2-xl (1558M)\": \"openai-community/gpt2-xl\"\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"CHOOSE_MODEL = \"gpt2-small\"\n",
|
||||
"CHOOSE_MODEL = \"gpt2-small (124M)\"\n",
|
||||
"\n",
|
||||
"gpt_hf = GPT2Model.from_pretrained(model_names[CHOOSE_MODEL], cache_dir=\"checkpoints\")\n",
|
||||
"gpt_hf.eval()"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user