mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-24 22:39:07 +00:00 
			
		
		
		
	update pr
This commit is contained in:
		
							parent
							
								
									2088d75966
								
							
						
					
					
						commit
						ec6e09136a
					
				| @ -38,7 +38,7 @@ | |||||||
|      "name": "stdout", |      "name": "stdout", | ||||||
|      "output_type": "stream", |      "output_type": "stream", | ||||||
|      "text": [ |      "text": [ | ||||||
|       "torch version: 2.2.1\n" |       "torch version: 2.2.2\n" | ||||||
|      ] |      ] | ||||||
|     } |     } | ||||||
|    ], |    ], | ||||||
| @ -228,7 +228,7 @@ | |||||||
|     "            [CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias) \n", |     "            [CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias) \n", | ||||||
|     "             for _ in range(num_heads)]\n", |     "             for _ in range(num_heads)]\n", | ||||||
|     "        )\n", |     "        )\n", | ||||||
|     "        self.out_proj = nn.Linear(d_in*num_heads, d_out*num_heads)\n", |     "        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "    def forward(self, x):\n", |     "    def forward(self, x):\n", | ||||||
|     "        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n", |     "        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n", | ||||||
| @ -365,6 +365,14 @@ | |||||||
|     "\n", |     "\n", | ||||||
|     "print(\"context_vecs.shape:\", context_vecs.shape)" |     "print(\"context_vecs.shape:\", context_vecs.shape)" | ||||||
|    ] |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "id": "f1d965a5-9b98-4554-8646-7ecd497874cb", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [] | ||||||
|   } |   } | ||||||
|  ], |  ], | ||||||
|  "metadata": { |  "metadata": { | ||||||
| @ -383,7 +391,7 @@ | |||||||
|    "name": "python", |    "name": "python", | ||||||
|    "nbconvert_exporter": "python", |    "nbconvert_exporter": "python", | ||||||
|    "pygments_lexer": "ipython3", |    "pygments_lexer": "ipython3", | ||||||
|    "version": "3.12.3" |    "version": "3.11.4" | ||||||
|   } |   } | ||||||
|  }, |  }, | ||||||
|  "nbformat": 4, |  "nbformat": 4, | ||||||
|  | |||||||
| @ -341,7 +341,7 @@ | |||||||
|     "        self.d_out = d_out\n", |     "        self.d_out = d_out\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", |     "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", | ||||||
|     "        self.proj = nn.Linear(d_in, d_out)\n", |     "        self.proj = nn.Linear(d_out, d_out)\n", | ||||||
|     "        self.dropout = dropout\n", |     "        self.dropout = dropout\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "    def forward(self, x):\n", |     "    def forward(self, x):\n", | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 rasbt
						rasbt