Fix qk_norm comment (#769)

Sebastian Raschka 2025-08-15 08:38:48 -05:00 committed by GitHub
parent b14325e56d
commit e9c1c1da38
2 changed files with 2 additions and 2 deletions

@@ -436,7 +436,7 @@
     "    \"n_layers\": 28,             # Number of layers\n",
     "    \"hidden_dim\": 3072,         # Size of the intermediate dimension in FeedForward\n",
     "    \"head_dim\": 128,            # Size of the heads in GQA\n",
-    "    \"qk_norm\": True,            # Whether to normalize queries and values in GQA\n",
+    "    \"qk_norm\": True,            # Whether to normalize queries and keys in GQA\n",
     "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 1_000_000.0,   # The base in RoPE's \"theta\"\n",
     "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",

@@ -22,7 +22,7 @@ QWEN_CONFIG_06_B = {
     "n_layers": 28,             # Number of layers
     "hidden_dim": 3072,         # Size of the intermediate dimension in FeedForward
     "head_dim": 128,            # Size of the heads in GQA
-    "qk_norm": True,            # Whether to normalize queries and values in GQA
+    "qk_norm": True,            # Whether to normalize queries and keys in GQA
     "n_kv_groups": 8,           # Key-Value groups for grouped-query attention
     "rope_base": 1_000_000.0,   # The base in RoPE's "theta"
     "dtype": torch.bfloat16,    # Lower-precision dtype to reduce memory usage