mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-31 20:08:08 +00:00
fixed gqa qkv code comments (#660)
This commit is contained in:
parent
7632eb018b
commit
c2cfb47b1a
@ -501,9 +501,9 @@
|
||||
" ################################################\n",
|
||||
"\n",
|
||||
" # Transpose keys, values, and queries\n",
|
||||
" keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
|
||||
" keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||
" values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||
" queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
"\n",
|
||||
" ##################### NEW #####################\n",
|
||||
" # Apply RoPE\n",
|
||||
|
@ -257,9 +257,9 @@
|
||||
" values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
|
||||
"\n",
|
||||
" # Transpose keys, values, and queries\n",
|
||||
" keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
|
||||
" keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||
" values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||
" queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
"\n",
|
||||
" # Apply RoPE\n",
|
||||
" keys = apply_rope(keys, cos, sin)\n",
|
||||
|
@ -166,9 +166,9 @@ class GroupedQueryAttention(nn.Module):
|
||||
values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
|
||||
|
||||
# Transpose keys, values, and queries
|
||||
keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||
values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||
queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)
|
||||
keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
|
||||
values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
|
||||
queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||
|
||||
# Apply RoPE
|
||||
keys = apply_rope(keys, cos, sin)
|
||||
|
Loading…
x
Reference in New Issue
Block a user