Merge pull request #52 from rasbt/use-embedding-dropout

Add dropout for embedding layers

Commit b50c42ffbb in https://github.com/rasbt/LLMs-from-scratch.git
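The change itself is small: the summed token and positional embeddings are routed through a dropout layer before the transformer blocks, a standard regularization step for GPT-style models. A minimal sketch of the idea follows; the hunks below only show the forward pass, so the constructor side (an nn.Dropout built from a drop-rate hyperparameter) is an assumption for illustration, not the repository's exact code.

    # Sketch, assuming drop_emb is created as nn.Dropout(drop_rate) in __init__;
    # the diff below only shows its use in the forward pass.
    import torch
    import torch.nn as nn

    class EmbeddingStage(nn.Module):
        def __init__(self, vocab_size, context_length, emb_dim, drop_rate):
            super().__init__()
            self.tok_emb = nn.Embedding(vocab_size, emb_dim)       # token embeddings
            self.pos_emb = nn.Embedding(context_length, emb_dim)   # positional embeddings
            self.drop_emb = nn.Dropout(drop_rate)                  # the layer this PR starts using

        def forward(self, in_idx):
            batch_size, seq_len = in_idx.shape
            tok_embeds = self.tok_emb(in_idx)
            pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
            x = tok_embeds + pos_embeds   # Shape [batch_size, num_tokens, emb_size]
            return self.drop_emb(x)       # randomly zeroes elements during training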
@@ -113,7 +113,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 60,
+"execution_count": 2,
 "id": "619c2eed-f8ea-4ff5-92c3-feda0f29b227",
 "metadata": {},
 "outputs": [],
@@ -181,7 +181,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 61,
+"execution_count": 3,
 "id": "794b6b6c-d36f-411e-a7db-8ac566a87fee",
 "metadata": {},
 "outputs": [
@@ -212,7 +212,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 62,
+"execution_count": 4,
 "id": "009238cd-0160-4834-979c-309710986bb0",
 "metadata": {},
 "outputs": [
@@ -279,7 +279,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 25,
+"execution_count": 5,
 "id": "79e1b463-dc3f-44ac-9cdb-9d5b6f64eb9d",
 "metadata": {},
 "outputs": [
@@ -314,7 +314,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 26,
+"execution_count": 6,
 "id": "9888f79e-8e69-44aa-8a19-cd34292adbf5",
 "metadata": {},
 "outputs": [
@@ -365,7 +365,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 27,
+"execution_count": 7,
 "id": "9a1d1bb9-3341-4c9a-bc2a-d2489bf89cda",
 "metadata": {},
 "outputs": [
@@ -378,8 +378,8 @@
 " [-0.0189, 0.1121, -1.0876, 1.5173, 0.5647, -1.0876]],\n",
 " grad_fn=<DivBackward0>)\n",
 "Mean:\n",
-" tensor([[ 0.0000],\n",
-" [ 0.0000]], grad_fn=<MeanBackward1>)\n",
+" tensor([[2.9802e-08],\n",
+" [3.9736e-08]], grad_fn=<MeanBackward1>)\n",
 "Variance:\n",
 " tensor([[1.],\n",
 " [1.]], grad_fn=<VarBackward0>)\n"
@@ -406,7 +406,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 28,
+"execution_count": 8,
 "id": "3e06c34b-c68a-4b36-afbe-b30eda4eca39",
 "metadata": {},
 "outputs": [
@@ -440,7 +440,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 29,
+"execution_count": 9,
 "id": "3333a305-aa3d-460a-bcce-b80662d464d9",
 "metadata": {},
 "outputs": [],
@@ -482,7 +482,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 30,
+"execution_count": 10,
 "id": "23b1000a-e613-4b43-bd90-e54deed8d292",
 "metadata": {},
 "outputs": [],
@@ -493,7 +493,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 31,
+"execution_count": 11,
 "id": "94c12de2-1cab-46e0-a099-e2e470353bff",
 "metadata": {},
 "outputs": [
@@ -558,7 +558,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 32,
+"execution_count": 12,
 "id": "f84694b7-95f3-4323-b6d6-0a73df278e82",
 "metadata": {},
 "outputs": [],
@@ -576,7 +576,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 33,
+"execution_count": 13,
 "id": "fc5487d2-2576-4118-80a7-56c4caac2e71",
 "metadata": {},
 "outputs": [
@@ -626,7 +626,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 34,
+"execution_count": 14,
 "id": "9275c879-b148-4579-a107-86827ca14d4d",
 "metadata": {},
 "outputs": [],
@@ -647,7 +647,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 35,
+"execution_count": 15,
 "id": "7c4976e2-0261-418e-b042-c5be98c2ccaf",
 "metadata": {},
 "outputs": [
@@ -673,7 +673,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 36,
+"execution_count": 16,
 "id": "928e7f7c-d0b1-499f-8d07-4cadb428a6f9",
 "metadata": {},
 "outputs": [
@@ -734,7 +734,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 37,
+"execution_count": 17,
 "id": "05473938-799c-49fd-86d4-8ed65f94fee6",
 "metadata": {},
 "outputs": [],
@@ -792,7 +792,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 38,
+"execution_count": 18,
 "id": "c75f43cc-6923-4018-b980-26023086572c",
 "metadata": {},
 "outputs": [
@@ -830,7 +830,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 39,
+"execution_count": 19,
 "id": "11b7c0c2-f9dd-4dd5-b096-a05c48c5f6d6",
 "metadata": {},
 "outputs": [
@@ -883,7 +883,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 40,
+"execution_count": 20,
 "id": "0e1e8176-e5e3-4152-b1aa-0bbd7891dfd9",
 "metadata": {},
 "outputs": [],
@@ -943,7 +943,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 64,
+"execution_count": 21,
 "id": "3fb45a63-b1f3-4b08-b525-dafbc8228405",
 "metadata": {},
 "outputs": [
@@ -969,7 +969,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 42,
+"execution_count": 22,
 "id": "01e737a6-fc99-42bb-9f7e-4da899168811",
 "metadata": {},
 "outputs": [
@@ -1036,7 +1036,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 43,
+"execution_count": 23,
 "id": "c61de39c-d03c-4a32-8b57-f49ac3834857",
 "metadata": {},
 "outputs": [],
@@ -1061,6 +1061,7 @@
 " tok_embeds = self.tok_emb(in_idx)\n",
 " pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
 " x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]\n",
+" x = self.drop_emb(x)\n",
 " x = self.trf_blocks(x)\n",
 " x = self.final_norm(x)\n",
 " logits = self.out_head(x)\n",
@@ -1075,6 +1076,44 @@
 "- Using the configuration of the 124M parameter model, we can now instantiate this GPT model with random initial weights as follows:"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": 24,
+"id": "ef94fd9c-4e9d-470d-8f8e-dd23d1bb1f64",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Input batch:\n",
+" tensor([[6109, 3626, 6100, 345],\n",
+" [6109, 1110, 6622, 257]])\n",
+"\n",
+"Output shape: torch.Size([2, 4, 50257])\n",
+"tensor([[[ 0.6525, 0.5753, 0.0174, ..., 0.2988, 0.1441, 0.0032],\n",
+" [ 0.0839, -0.6789, -0.6605, ..., -0.2912, 0.4267, -0.2696],\n",
+" [ 0.8440, 0.1894, 0.0708, ..., 0.0982, -0.2183, 0.0920],\n",
+" [-0.7958, 0.5066, 0.0209, ..., 0.7497, 0.3233, -0.1251]],\n",
+"\n",
+" [[ 0.0181, 0.2606, -0.3022, ..., 0.2940, 0.1998, -0.6246],\n",
+" [ 0.0596, 0.3041, -0.0293, ..., 0.6796, -0.1226, 0.1303],\n",
+" [ 1.1895, 1.0891, 0.0237, ..., 0.8299, 0.1794, -0.2250],\n",
+" [ 0.5457, 0.1861, 0.3872, ..., 1.3537, -0.4062, -0.0268]]],\n",
+" grad_fn=<UnsafeViewBackward0>)\n"
+]
+}
+],
+"source": [
+"torch.manual_seed(123)\n",
+"model = GPTModel(GPT_CONFIG_124M)\n",
+"\n",
+"out = model(batch)\n",
+"print(\"Input batch:\\n\", batch)\n",
+"print(\"\\nOutput shape:\", out.shape)\n",
+"print(out)"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 44,
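The re-run cell above prints logits of shape [2, 4, 50257]: 2 batch rows, 4 token positions, and one score per entry of the 50,257-token vocabulary. As a sketch of how such logits are typically consumed (not part of this commit), greedy decoding takes the argmax at the last position:

    # Sketch only, reusing out from the cell above.
    last_logits = out[:, -1, :]                         # final position, shape [2, 50257]
    next_token_ids = torch.argmax(last_logits, dim=-1)  # one token id per batch row
    print(next_token_ids.shape)                         # torch.Size([2])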
@@ -336,6 +336,7 @@
 " tok_embeds = self.tok_emb(in_idx)\n",
 " pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
 " x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]\n",
+" x = self.drop_emb(x)\n",
 " x = self.trf_blocks(x)\n",
 " x = self.final_norm(x)\n",
 " logits = self.out_head(x)\n",
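The same one-line change is applied in the hunk above. For reference, nn.Dropout(p) zeroes each element independently with probability p during training and rescales the survivors by 1/(1-p), so the expected activation magnitude is unchanged; in eval mode it is the identity. A small standalone check (illustrative, not from the repository):

    import torch
    import torch.nn as nn

    torch.manual_seed(123)
    drop = nn.Dropout(p=0.1)          # drop each element with probability 0.1

    x = torch.ones(2, 4, 8)           # stand-in for tok_embeds + pos_embeds
    y = drop(x)
    print((y == 0).float().mean())    # fraction of zeroed elements, roughly 0.1
    print(y.max())                    # survivors are scaled to 1 / 0.9 = 1.1111...

    drop.eval()                       # inference mode
    print(torch.equal(drop(x), x))    # True: dropout is a no-op in eval mode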
@@ -202,6 +202,7 @@ class GPTModel(nn.Module):
         tok_embeds = self.tok_emb(in_idx)
         pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
         x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
+        x = self.drop_emb(x)
         x = self.trf_blocks(x)
         x = self.final_norm(x)
         logits = self.out_head(x)

@@ -202,6 +202,7 @@ class GPTModel(nn.Module):
         tok_embeds = self.tok_emb(in_idx)
         pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
         x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
+        x = self.drop_emb(x)
         x = self.trf_blocks(x)
         x = self.final_norm(x)
         logits = self.out_head(x)
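A practical consequence of the new dropout layer: in training mode (the default after construction) repeated forward passes over the same batch produce different activations, while model.eval() disables dropout again. A usage sketch with the names from the diff (GPTModel, GPT_CONFIG_124M, and batch as defined in the notebook); the eval/no_grad toggling is standard PyTorch, not something this commit adds:

    torch.manual_seed(123)
    model = GPTModel(GPT_CONFIG_124M)

    model.eval()                # turns drop_emb (and any other dropout) into a no-op
    with torch.no_grad():       # no gradients needed for pure inference
        logits = model(batch)   # deterministic for fixed weights and input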