minor fixes: Llama 3.2 standalone (#420)

* minor fixes

* reformat rope base as float

---------

Co-authored-by: rasbt <mail@sebastianraschka.com>
This commit is contained in:
Daniel Kleine 2024-10-26 04:08:06 +02:00 committed by GitHub
parent 75ede3e340
commit 2b24a7ef30
2 changed files with 10 additions and 10 deletions

View File

@ -907,7 +907,7 @@
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # NEW: Larger size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # NEW: Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # NEW: The base in RoPE's \"theta\" was increased to 500_000\n",
" \"rope_base\": 500_000.0, # NEW: The base in RoPE's \"theta\" was increased to 500_000\n",
" \"rope_freq\": None, # NEW: Additional configuration for adjusting the RoPE frequencies\n",
" \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n",
"}"
@ -2060,7 +2060,7 @@
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"rope_freq\": None, # Additional configuration for adjusting the RoPE frequencies\n",
" \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n",
"}\n",
@ -2073,7 +2073,7 @@
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # NEW: RoPE frequency scaling\n",
" \"factor\": 8.0,\n",
@ -2447,7 +2447,7 @@
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # NEW: RoPE frequency scaling\n",
" \"factor\": 8.0,\n",
@ -2466,7 +2466,7 @@
" \"n_layers\": 16, # NEW: Half the number of layers\n",
" \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n",

View File

@ -437,7 +437,7 @@
" \"n_layers\": 16, # Number of layers\n",
" \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0,\n",
@ -451,13 +451,13 @@
"\n",
"# LLAMA32_CONFIG = {\n",
"# \"vocab_size\": 128_256, # Vocabulary size\n",
"# \"context_length\": 131_000, # Context length\n",
"# \"context_length\": 131_072, # Context length\n",
"# \"emb_dim\": 3072, # Embedding dimension\n",
"# \"n_heads\": 24, # Number of attention heads\n",
"# \"n_layers\": 28, # Number of layers\n",
"# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
"# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
"# \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
"# \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
"# \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
"# \"rope_freq\": { # RoPE frequency scaling\n",
"# \"factor\": 32.0,\n",
@ -697,7 +697,6 @@
" def __init__(self, model_path):\n",
" assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
" mergeable_ranks = load_tiktoken_bpe(model_path)\n",
" num_base_tokens = len(mergeable_ranks)\n",
"\n",
" self.special_tokens = {\n",
" \"<|begin_of_text|>\": 128000,\n",
@ -1013,7 +1012,8 @@
"\n",
"\n",
"load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)\n",
"model.to(device);"
"model.to(device)\n",
"del combined_weights # free up memory"
]
},
{