Merge pull request #52 from rasbt/use-embedding-dropout

Add dropout for embedding layers
Sebastian Raschka 2024-03-04 07:07:46 -06:00 committed by GitHub
commit b50c42ffbb
4 changed files with 66 additions and 24 deletions
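
For context, the substantive change in all four files is small: the summed token and position embeddings are now passed through an `nn.Dropout` layer (`drop_emb`) before the transformer blocks. Below is a minimal sketch of the pattern; the `EmbeddingWithDropout` wrapper, its constructor arguments, and the `drop_rate` default are illustrative assumptions, since the diff only shows the forward-pass line that was added:

```python
import torch
import torch.nn as nn

class EmbeddingWithDropout(nn.Module):
    """Sketch of the embedding stage of a GPT-style model with the new dropout step."""

    def __init__(self, vocab_size, context_length, emb_dim, drop_rate=0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, emb_dim)      # token embeddings
        self.pos_emb = nn.Embedding(context_length, emb_dim)  # absolute position embeddings
        self.drop_emb = nn.Dropout(drop_rate)                 # new in this PR

    def forward(self, in_idx):
        seq_len = in_idx.shape[1]
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        return self.drop_emb(x)      # randomly zeroes embedding elements during training
```

In the actual `GPTModel.forward` (see the hunks below), the dropout output then flows into `self.trf_blocks`, `self.final_norm`, and `self.out_head`.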

View File

@@ -113,7 +113,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 60,
+"execution_count": 2,
 "id": "619c2eed-f8ea-4ff5-92c3-feda0f29b227",
 "metadata": {},
 "outputs": [],
@@ -181,7 +181,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 61,
+"execution_count": 3,
 "id": "794b6b6c-d36f-411e-a7db-8ac566a87fee",
 "metadata": {},
 "outputs": [
@@ -212,7 +212,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 62,
+"execution_count": 4,
 "id": "009238cd-0160-4834-979c-309710986bb0",
 "metadata": {},
 "outputs": [
@@ -279,7 +279,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 25,
+"execution_count": 5,
 "id": "79e1b463-dc3f-44ac-9cdb-9d5b6f64eb9d",
 "metadata": {},
 "outputs": [
@@ -314,7 +314,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 26,
+"execution_count": 6,
 "id": "9888f79e-8e69-44aa-8a19-cd34292adbf5",
 "metadata": {},
 "outputs": [
@@ -365,7 +365,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 27,
+"execution_count": 7,
 "id": "9a1d1bb9-3341-4c9a-bc2a-d2489bf89cda",
 "metadata": {},
 "outputs": [
@@ -378,8 +378,8 @@
 " [-0.0189, 0.1121, -1.0876, 1.5173, 0.5647, -1.0876]],\n",
 " grad_fn=<DivBackward0>)\n",
 "Mean:\n",
-" tensor([[ 0.0000],\n",
-" [ 0.0000]], grad_fn=<MeanBackward1>)\n",
+" tensor([[2.9802e-08],\n",
+" [3.9736e-08]], grad_fn=<MeanBackward1>)\n",
 "Variance:\n",
 " tensor([[1.],\n",
 " [1.]], grad_fn=<VarBackward0>)\n"
@@ -406,7 +406,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 28,
+"execution_count": 8,
 "id": "3e06c34b-c68a-4b36-afbe-b30eda4eca39",
 "metadata": {},
 "outputs": [
@@ -440,7 +440,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 29,
+"execution_count": 9,
 "id": "3333a305-aa3d-460a-bcce-b80662d464d9",
 "metadata": {},
 "outputs": [],
@@ -482,7 +482,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 30,
+"execution_count": 10,
 "id": "23b1000a-e613-4b43-bd90-e54deed8d292",
 "metadata": {},
 "outputs": [],
@@ -493,7 +493,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 31,
+"execution_count": 11,
 "id": "94c12de2-1cab-46e0-a099-e2e470353bff",
 "metadata": {},
 "outputs": [
@@ -558,7 +558,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 32,
+"execution_count": 12,
 "id": "f84694b7-95f3-4323-b6d6-0a73df278e82",
 "metadata": {},
 "outputs": [],
@@ -576,7 +576,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 33,
+"execution_count": 13,
 "id": "fc5487d2-2576-4118-80a7-56c4caac2e71",
 "metadata": {},
 "outputs": [
@@ -626,7 +626,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 34,
+"execution_count": 14,
 "id": "9275c879-b148-4579-a107-86827ca14d4d",
 "metadata": {},
 "outputs": [],
@@ -647,7 +647,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 35,
+"execution_count": 15,
 "id": "7c4976e2-0261-418e-b042-c5be98c2ccaf",
 "metadata": {},
 "outputs": [
@@ -673,7 +673,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 36,
+"execution_count": 16,
 "id": "928e7f7c-d0b1-499f-8d07-4cadb428a6f9",
 "metadata": {},
 "outputs": [
@@ -734,7 +734,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 37,
+"execution_count": 17,
 "id": "05473938-799c-49fd-86d4-8ed65f94fee6",
 "metadata": {},
 "outputs": [],
@@ -792,7 +792,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 38,
+"execution_count": 18,
 "id": "c75f43cc-6923-4018-b980-26023086572c",
 "metadata": {},
 "outputs": [
@@ -830,7 +830,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 39,
+"execution_count": 19,
 "id": "11b7c0c2-f9dd-4dd5-b096-a05c48c5f6d6",
 "metadata": {},
 "outputs": [
@@ -883,7 +883,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 40,
+"execution_count": 20,
 "id": "0e1e8176-e5e3-4152-b1aa-0bbd7891dfd9",
 "metadata": {},
 "outputs": [],
@@ -943,7 +943,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 64,
+"execution_count": 21,
 "id": "3fb45a63-b1f3-4b08-b525-dafbc8228405",
 "metadata": {},
 "outputs": [
@@ -969,7 +969,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 42,
+"execution_count": 22,
 "id": "01e737a6-fc99-42bb-9f7e-4da899168811",
 "metadata": {},
 "outputs": [
@@ -1036,7 +1036,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 43,
+"execution_count": 23,
 "id": "c61de39c-d03c-4a32-8b57-f49ac3834857",
 "metadata": {},
 "outputs": [],
@@ -1061,6 +1061,7 @@
 " tok_embeds = self.tok_emb(in_idx)\n",
 " pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
 " x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]\n",
+" x = self.drop_emb(x)\n",
 " x = self.trf_blocks(x)\n",
 " x = self.final_norm(x)\n",
 " logits = self.out_head(x)\n",
@@ -1075,6 +1076,44 @@
 "- Using the configuration of the 124M parameter model, we can now instantiate this GPT model with random initial weights as follows:"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": 24,
+"id": "ef94fd9c-4e9d-470d-8f8e-dd23d1bb1f64",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Input batch:\n",
+" tensor([[6109, 3626, 6100, 345],\n",
+" [6109, 1110, 6622, 257]])\n",
+"\n",
+"Output shape: torch.Size([2, 4, 50257])\n",
+"tensor([[[ 0.6525, 0.5753, 0.0174, ..., 0.2988, 0.1441, 0.0032],\n",
+" [ 0.0839, -0.6789, -0.6605, ..., -0.2912, 0.4267, -0.2696],\n",
+" [ 0.8440, 0.1894, 0.0708, ..., 0.0982, -0.2183, 0.0920],\n",
+" [-0.7958, 0.5066, 0.0209, ..., 0.7497, 0.3233, -0.1251]],\n",
+"\n",
+" [[ 0.0181, 0.2606, -0.3022, ..., 0.2940, 0.1998, -0.6246],\n",
+" [ 0.0596, 0.3041, -0.0293, ..., 0.6796, -0.1226, 0.1303],\n",
+" [ 1.1895, 1.0891, 0.0237, ..., 0.8299, 0.1794, -0.2250],\n",
+" [ 0.5457, 0.1861, 0.3872, ..., 1.3537, -0.4062, -0.0268]]],\n",
+" grad_fn=<UnsafeViewBackward0>)\n"
+]
+}
+],
+"source": [
+"torch.manual_seed(123)\n",
+"model = GPTModel(GPT_CONFIG_124M)\n",
+"\n",
+"out = model(batch)\n",
+"print(\"Input batch:\\n\", batch)\n",
+"print(\"\\nOutput shape:\", out.shape)\n",
+"print(out)"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 44,

View File

@@ -336,6 +336,7 @@
 " tok_embeds = self.tok_emb(in_idx)\n",
 " pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
 " x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]\n",
+" x = self.drop_emb(x)\n",
 " x = self.trf_blocks(x)\n",
 " x = self.final_norm(x)\n",
 " logits = self.out_head(x)\n",

View File

@@ -202,6 +202,7 @@ class GPTModel(nn.Module):
         tok_embeds = self.tok_emb(in_idx)
         pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
         x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
+        x = self.drop_emb(x)
         x = self.trf_blocks(x)
         x = self.final_norm(x)
         logits = self.out_head(x)

View File

@@ -202,6 +202,7 @@ class GPTModel(nn.Module):
         tok_embeds = self.tok_emb(in_idx)
         pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
         x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
+        x = self.drop_emb(x)
         x = self.trf_blocks(x)
         x = self.final_norm(x)
         logits = self.out_head(x)
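
One practical note on the new layer (general PyTorch behavior, not shown in this diff): `nn.Dropout` only perturbs activations in training mode, so the embedding dropout automatically becomes a no-op during inference once `model.eval()` is called:

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
drop_emb = nn.Dropout(0.1)
x = torch.ones(2, 4, 8)

drop_emb.train()   # training mode: elements are randomly zeroed,
y = drop_emb(x)    # and survivors are scaled by 1 / (1 - 0.1)
print((y == 0).sum().item(), "of", y.numel(), "elements zeroed")

drop_emb.eval()                     # eval mode: dropout passes the input through unchanged
print(torch.equal(drop_emb(x), x))  # True
```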