mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-11-04 03:40:21 +00:00 
			
		
		
		
	removed old args in GQA class (#674)
This commit is contained in:
		
							parent
							
								
									ece59ba587
								
							
						
					
					
						commit
						58b8672452
					
				@ -452,10 +452,8 @@
 | 
			
		||||
    "\n",
 | 
			
		||||
    "class GroupedQueryAttention(nn.Module):\n",
 | 
			
		||||
    "    def __init__(\n",
 | 
			
		||||
    "            self, d_in, d_out, context_length, num_heads,\n",
 | 
			
		||||
    "            self, d_in, d_out, num_heads,\n",
 | 
			
		||||
    "            num_kv_groups,       # NEW\n",
 | 
			
		||||
    "            rope_base=10_000,    # NEW\n",
 | 
			
		||||
    "            rope_config=None,    # NEW\n",
 | 
			
		||||
    "            dtype=None\n",
 | 
			
		||||
    "        ):\n",
 | 
			
		||||
    "        super().__init__()\n",
 | 
			
		||||
@ -645,10 +643,8 @@
 | 
			
		||||
    "gqa = GroupedQueryAttention(\n",
 | 
			
		||||
    "    d_in=embed_dim,\n",
 | 
			
		||||
    "    d_out=embed_dim,\n",
 | 
			
		||||
    "    context_length=max_context_len,\n",
 | 
			
		||||
    "    num_heads=num_heads,\n",
 | 
			
		||||
    "    num_kv_groups=8,\n",
 | 
			
		||||
    "    rope_base=llama_3_theta_base\n",
 | 
			
		||||
    ")\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "gqa(example_batch)\n",
 | 
			
		||||
@ -753,11 +749,8 @@
 | 
			
		||||
    "        self.att =  GroupedQueryAttention(  # MultiHeadAttention(\n",
 | 
			
		||||
    "            d_in=cfg[\"emb_dim\"],\n",
 | 
			
		||||
    "            d_out=cfg[\"emb_dim\"],\n",
 | 
			
		||||
    "            context_length=cfg[\"context_length\"],\n",
 | 
			
		||||
    "            num_heads=cfg[\"n_heads\"],\n",
 | 
			
		||||
    "            num_kv_groups=cfg[\"n_kv_groups\"],  # NEW\n",
 | 
			
		||||
    "            rope_base=cfg[\"rope_base\"],        # NEW\n",
 | 
			
		||||
    "            rope_config=cfg[\"rope_freq\"],      # NEW\n",
 | 
			
		||||
    "            dtype=cfg[\"dtype\"]\n",
 | 
			
		||||
    "        )\n",
 | 
			
		||||
    "        self.ff = FeedForward(cfg)\n",
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user