mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-25 23:11:23 +00:00 
			
		
		
		
	fixed gqa qkv code comments (#660)
This commit is contained in:
		
							parent
							
								
									7632eb018b
								
							
						
					
					
						commit
						c2cfb47b1a
					
				| @ -501,9 +501,9 @@ | |||||||
|     "        ################################################\n", |     "        ################################################\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "        # Transpose keys, values, and queries\n", |     "        # Transpose keys, values, and queries\n", | ||||||
|     "        keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n", |     "        keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", | ||||||
|     "        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n", |     "        values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", | ||||||
|     "        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)\n", |     "        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "        ##################### NEW #####################\n", |     "        ##################### NEW #####################\n", | ||||||
|     "        # Apply RoPE\n", |     "        # Apply RoPE\n", | ||||||
|  | |||||||
| @ -257,9 +257,9 @@ | |||||||
|     "        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", |     "        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "        # Transpose keys, values, and queries\n", |     "        # Transpose keys, values, and queries\n", | ||||||
|     "        keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n", |     "        keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", | ||||||
|     "        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n", |     "        values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", | ||||||
|     "        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)\n", |     "        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "        # Apply RoPE\n", |     "        # Apply RoPE\n", | ||||||
|     "        keys = apply_rope(keys, cos, sin)\n", |     "        keys = apply_rope(keys, cos, sin)\n", | ||||||
|  | |||||||
| @ -166,9 +166,9 @@ class GroupedQueryAttention(nn.Module): | |||||||
|         values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim) |         values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim) | ||||||
| 
 | 
 | ||||||
|         # Transpose keys, values, and queries |         # Transpose keys, values, and queries | ||||||
|         keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim) |         keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim) | ||||||
|         values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim) |         values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim) | ||||||
|         queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim) |         queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim) | ||||||
| 
 | 
 | ||||||
|         # Apply RoPE |         # Apply RoPE | ||||||
|         keys = apply_rope(keys, cos, sin) |         keys = apply_rope(keys, cos, sin) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Daniel Kleine
						Daniel Kleine