From c2cfb47b1a71d13818a49443dcfca49b5ff36efb Mon Sep 17 00:00:00 2001
From: Daniel Kleine <53251018+d-kleine@users.noreply.github.com>
Date: Fri, 13 Jun 2025 15:21:28 +0200
Subject: [PATCH] fixed gqa qkv code comments (#660)

---
 ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb | 6 +++---
 ch05/07_gpt_to_llama/standalone-llama32.ipynb          | 6 +++---
 pkg/llms_from_scratch/llama3.py                        | 6 +++---
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
index 908a034..7766dca 100644
--- a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
+++ b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
@@ -501,9 +501,9 @@
     "        ################################################\n",
     "\n",
     "        # Transpose keys, values, and queries\n",
-    "        keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
+    "        keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
+    "        values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
+    "        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
     "\n",
     "        ##################### NEW #####################\n",
     "        # Apply RoPE\n",
diff --git a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb
index dbec8ad..afb27c2 100644
--- a/ch05/07_gpt_to_llama/standalone-llama32.ipynb
+++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb
@@ -257,9 +257,9 @@
     "        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
     "\n",
     "        # Transpose keys, values, and queries\n",
-    "        keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
+    "        keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
+    "        values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
+    "        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
     "\n",
     "        # Apply RoPE\n",
     "        keys = apply_rope(keys, cos, sin)\n",
diff --git a/pkg/llms_from_scratch/llama3.py b/pkg/llms_from_scratch/llama3.py
index df7bc72..785e8af 100644
--- a/pkg/llms_from_scratch/llama3.py
+++ b/pkg/llms_from_scratch/llama3.py
@@ -166,9 +166,9 @@ class GroupedQueryAttention(nn.Module):
         values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
 
         # Transpose keys, values, and queries
-        keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)
-        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)
-        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)
+        keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)
+        values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)
+        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)
 
         # Apply RoPE
         keys = apply_rope(keys, cos, sin)
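
Why the corrected comments are the right ones (not part of the patch): in grouped-query attention, keys and values are projected to num_kv_groups heads while queries keep num_heads, so after the transpose their head axes differ. The following is a minimal standalone sketch illustrating those shapes; the sizes (b=2, num_tokens=8, num_heads=8, num_kv_groups=2, head_dim=8) are hypothetical and chosen only for illustration, not taken from the repository's configs.

import torch

b, num_tokens, d_in = 2, 8, 64
num_heads, num_kv_groups, head_dim = 8, 2, 8

# Queries get num_heads * head_dim outputs; keys/values only num_kv_groups * head_dim
W_query = torch.nn.Linear(d_in, num_heads * head_dim, bias=False)
W_key = torch.nn.Linear(d_in, num_kv_groups * head_dim, bias=False)
W_value = torch.nn.Linear(d_in, num_kv_groups * head_dim, bias=False)

x = torch.randn(b, num_tokens, d_in)
queries = W_query(x).view(b, num_tokens, num_heads, head_dim)
keys = W_key(x).view(b, num_tokens, num_kv_groups, head_dim)
values = W_value(x).view(b, num_tokens, num_kv_groups, head_dim)

# After transposing, keys/values carry num_kv_groups, queries carry num_heads
print(keys.transpose(1, 2).shape)     # torch.Size([2, 2, 8, 8]) -> (b, num_kv_groups, num_tokens, head_dim)
print(values.transpose(1, 2).shape)   # torch.Size([2, 2, 8, 8]) -> (b, num_kv_groups, num_tokens, head_dim)
print(queries.transpose(1, 2).shape)  # torch.Size([2, 8, 8, 8]) -> (b, num_heads, num_tokens, head_dim)

# Only later are keys/values expanded to num_heads via repeat_interleave
keys = keys.transpose(1, 2).repeat_interleave(num_heads // num_kv_groups, dim=1)
print(keys.shape)                     # torch.Size([2, 8, 8, 8]) -> (b, num_heads, num_tokens, head_dim)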