mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-11-04 03:40:21 +00:00 
			
		
		
		
	Merge pull request #55 from rayedbw/patch-4
Update mha-implementations.ipynb
This commit is contained in:
		
						commit
						083d11fbd0
					
				@ -168,7 +168,7 @@
 | 
			
		||||
    "        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
 | 
			
		||||
    "        queries, keys, values = qkv.unbind(0)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # (b, num_head, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
 | 
			
		||||
    "        # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
 | 
			
		||||
    "        attn_scores = queries @ keys.transpose(-2, -1)\n",
 | 
			
		||||
    "        attn_scores = attn_scores.masked_fill(\n",
 | 
			
		||||
    "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
 | 
			
		||||
@ -258,12 +258,12 @@
 | 
			
		||||
    "        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
 | 
			
		||||
    "        qkv = qkv.permute(2, 0, 3, 1, 4)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n",
 | 
			
		||||
    "        q, k, v = qkv.unbind(0)\n",
 | 
			
		||||
    "        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
 | 
			
		||||
    "        queries, keys, values = qkv.unbind(0)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        use_dropout = 0. if not self.training else self.dropout\n",
 | 
			
		||||
    "        context_vec = torch.nn.functional.scaled_dot_product_attention(q, k, v, \n",
 | 
			
		||||
    "            attn_mask=None, dropout_p=use_dropout, is_causal=True)\n",
 | 
			
		||||
    "        context_vec = nn.functional.scaled_dot_product_attention(\n",
 | 
			
		||||
    "            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
 | 
			
		||||
    "        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
 | 
			
		||||
@ -396,7 +396,7 @@
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.10.6"
 | 
			
		||||
   "version": "3.10.12"
 | 
			
		||||
  }
 | 
			
		||||
 },
 | 
			
		||||
 "nbformat": 4,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user