mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-10-23 05:49:26 +00:00
fixed gqa qkv code comments (#660)
This commit is contained in:
parent
7632eb018b
commit
c2cfb47b1a
@ -501,9 +501,9 @@
|
|||||||
" ################################################\n",
|
" ################################################\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Transpose keys, values, and queries\n",
|
" # Transpose keys, values, and queries\n",
|
||||||
" keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
" keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||||
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
" values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||||
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
|
" queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" ##################### NEW #####################\n",
|
" ##################### NEW #####################\n",
|
||||||
" # Apply RoPE\n",
|
" # Apply RoPE\n",
|
||||||
|
@ -257,9 +257,9 @@
|
|||||||
" values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
|
" values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Transpose keys, values, and queries\n",
|
" # Transpose keys, values, and queries\n",
|
||||||
" keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
" keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||||
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
" values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||||
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
|
" queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Apply RoPE\n",
|
" # Apply RoPE\n",
|
||||||
" keys = apply_rope(keys, cos, sin)\n",
|
" keys = apply_rope(keys, cos, sin)\n",
|
||||||
|
@ -166,9 +166,9 @@ class GroupedQueryAttention(nn.Module):
|
|||||||
values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
|
values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
|
||||||
|
|
||||||
# Transpose keys, values, and queries
|
# Transpose keys, values, and queries
|
||||||
keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
|
keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
|
||||||
values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
|
values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
|
||||||
queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)
|
queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||||
|
|
||||||
# Apply RoPE
|
# Apply RoPE
|
||||||
keys = apply_rope(keys, cos, sin)
|
keys = apply_rope(keys, cos, sin)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user