add weight sizes

This commit is contained in:
rasbt 2024-03-31 08:45:14 -05:00
parent 1c173e4f44
commit 83adc4a2ac
3 changed files with 90 additions and 95 deletions

File diff suppressed because one or more lines are too long

View File

@ -219,7 +219,7 @@ if __name__ == "__main__":
torch.manual_seed(123) torch.manual_seed(123)
CHOOSE_MODEL = "gpt2-small" CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves" INPUT_PROMPT = "Every effort moves"
BASE_CONFIG = { BASE_CONFIG = {
@ -230,19 +230,14 @@ if __name__ == "__main__":
} }
model_configs = { model_configs = {
"gpt2-small": {"emb_dim": 768, "n_layers": 12, "n_heads": 12}, "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
"gpt2-medium": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16}, "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
"gpt2-large": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20}, "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
"gpt2-xl": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25}, "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
} }
model_sizes = { model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
"gpt2-small": "124M",
"gpt2-medium": "355M",
"gpt2-large": "774M",
"gpt2-xl": "1558"
}
BASE_CONFIG.update(model_configs[CHOOSE_MODEL]) BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
main(BASE_CONFIG, INPUT_PROMPT, model_sizes[CHOOSE_MODEL]) main(BASE_CONFIG, INPUT_PROMPT, model_size)

View File

@ -126,13 +126,13 @@
"\n", "\n",
"# allowed model names\n", "# allowed model names\n",
"model_names = {\n", "model_names = {\n",
" \"gpt2-small\": \"openai-community/gpt2\", # 124M\n", " \"gpt2-small (124M)\": \"openai-community/gpt2\",\n",
" \"gpt2-medium\": \"openai-community/gpt2-medium\", # 355M\n", " \"gpt2-medium (355M)\": \"openai-community/gpt2-medium\",\n",
" \"gpt2-large\": \"openai-community/gpt2-large\", # 774M\n", " \"gpt2-large (774M)\": \"openai-community/gpt2-large\",\n",
" \"gpt2-xl\": \"openai-community/gpt2-xl\" # 1558M\n", " \"gpt2-xl (1558M)\": \"openai-community/gpt2-xl\"\n",
"}\n", "}\n",
"\n", "\n",
"CHOOSE_MODEL = \"gpt2-small\"\n", "CHOOSE_MODEL = \"gpt2-small (124M)\"\n",
"\n", "\n",
"gpt_hf = GPT2Model.from_pretrained(model_names[CHOOSE_MODEL], cache_dir=\"checkpoints\")\n", "gpt_hf = GPT2Model.from_pretrained(model_names[CHOOSE_MODEL], cache_dir=\"checkpoints\")\n",
"gpt_hf.eval()" "gpt_hf.eval()"