fixed gqa qkv code comments (#660)

This commit is contained in:
Daniel Kleine 2025-06-13 15:21:28 +02:00 committed by GitHub
parent 7632eb018b
commit c2cfb47b1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 9 deletions

View File

@ -501,9 +501,9 @@
" ################################################\n",
"\n",
" # Transpose keys, values, and queries\n",
" keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
" keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
" values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
" queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
"\n",
" ##################### NEW #####################\n",
" # Apply RoPE\n",

View File

@ -257,9 +257,9 @@
" values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
"\n",
" # Transpose keys, values, and queries\n",
" keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
" keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
" values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
" queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
"\n",
" # Apply RoPE\n",
" keys = apply_rope(keys, cos, sin)\n",

View File

@ -166,9 +166,9 @@ class GroupedQueryAttention(nn.Module):
values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
# Transpose keys, values, and queries
keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)
keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
# Apply RoPE
keys = apply_rope(keys, cos, sin)