mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-12-14 16:47:33 +00:00
use need_weights=False
This commit is contained in:
parent
5643c88db9
commit
29ca41799a
@ -19,14 +19,14 @@
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
|
||||
"outputId": "02205088-47f1-4fc1-83a4-dd0be4cd64dd"
|
||||
"outputId": "999a54ca-36b5-4e26-e25a-94953b4d1590"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Running on cuda\n"
|
||||
"Running on cpu\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -62,12 +62,12 @@
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
|
||||
"outputId": "a1eefc3c-21ea-463e-e75e-06af9f6262dd"
|
||||
"outputId": "4d694519-8283-49a6-b27f-c2ae06a2fa4e"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([8, 1024, 9216])\n"
|
||||
]
|
||||
@ -108,12 +108,12 @@
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
|
||||
"outputId": "c66ee5fd-b0cd-4ab4-e097-4d64902ea0d0"
|
||||
"outputId": "f7cd4e03-5622-4dd6-a9e0-31b80e15838b"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([8, 1024, 768])\n"
|
||||
]
|
||||
@ -172,12 +172,12 @@
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
|
||||
"outputId": "9c4ffbe8-6684-429c-b86a-b68121341a4c"
|
||||
"outputId": "86c643c3-c1dd-4021-d7e6-b79e3cea3d99"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([8, 1024, 768])\n"
|
||||
]
|
||||
@ -338,12 +338,12 @@
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
|
||||
"outputId": "027b5a66-4e17-49e8-9e80-9c70eaf201ab"
|
||||
"outputId": "ff45b9c9-19a7-4769-efaf-b9f417a31631"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([8, 1024, 768])\n"
|
||||
]
|
||||
@ -392,20 +392,23 @@
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
|
||||
"outputId": "9d9afbbd-2e85-44cb-afc9-8cb3c91e8368"
|
||||
"outputId": "1a6ce118-be18-4d7b-99bc-a0c9bb78f8b1"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([8, 1024, 768])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch.nn as nn\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class MHAPyTorchClass(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
|
||||
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False, need_weights=True):\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.block_size = block_size\n",
|
||||
@ -415,9 +418,10 @@
|
||||
" dropout=dropout,\n",
|
||||
" bias=qkv_bias,\n",
|
||||
" add_bias_kv=qkv_bias,\n",
|
||||
" batch_first=True\n",
|
||||
" batch_first=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" self.need_weights = need_weights\n",
|
||||
" self.proj = nn.Linear(d_out, d_out)\n",
|
||||
" self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n",
|
||||
"\n",
|
||||
@ -432,14 +436,16 @@
|
||||
" attn_mask = self.mask[:self.block_size, :self.block_size]\n",
|
||||
"\n",
|
||||
" # attn_mask broadcasting will handle batch_size dimension implicitly\n",
|
||||
" attn_output, _ = self.multihead_attn(x, x, x, attn_mask=attn_mask)\n",
|
||||
" attn_output, _ = self.multihead_attn(\n",
|
||||
" x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" output = self.proj(attn_output)\n",
|
||||
"\n",
|
||||
" return output\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"mha_pytorch_class = MHAPyTorchClass(\n",
|
||||
"mha_pytorch_class_default = MHAPyTorchClass(\n",
|
||||
" d_in=embed_dim,\n",
|
||||
" d_out=embed_dim,\n",
|
||||
" block_size=context_len,\n",
|
||||
@ -448,7 +454,65 @@
|
||||
" qkv_bias=False\n",
|
||||
").to(device)\n",
|
||||
"\n",
|
||||
"out = mha_pytorch_class(embeddings)\n",
|
||||
"out = mha_pytorch_class_default(embeddings)\n",
|
||||
"print(out.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d2164859-31a0-4537-b4fb-27d57675ba77",
|
||||
"metadata": {
|
||||
"id": "d2164859-31a0-4537-b4fb-27d57675ba77"
|
||||
},
|
||||
"source": [
|
||||
"- Set `need_weights` (default `True`) to `False` so that `MultiheadAttention` uses the optimized `scaled_dot_product_attention`, [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096):\n",
|
||||
"\n",
|
||||
"> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n",
|
||||
" Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n",
|
||||
" and achieve the best performance for MHA.\n",
|
||||
" Default: ``True``."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
|
||||
"outputId": "ef1a0698-5b18-426d-df0f-00bdbaeeaccc"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([8, 1024, 768])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"mha_pytorch_class_noweights = MHAPyTorchClass(\n",
|
||||
" d_in=embed_dim,\n",
|
||||
" d_out=embed_dim,\n",
|
||||
" block_size=context_len,\n",
|
||||
" dropout=0.0,\n",
|
||||
" num_heads=12,\n",
|
||||
" qkv_bias=False,\n",
|
||||
" need_weights=False # NEW!\n",
|
||||
").to(device)\n",
|
||||
"\n",
|
||||
"out = mha_pytorch_class_noweights(embeddings)\n",
|
||||
"print(out.shape)"
|
||||
]
|
||||
},
|
||||
@ -459,12 +523,12 @@
|
||||
"id": "8877de71-f84f-4f6d-bc87-7552013b6301"
|
||||
},
|
||||
"source": [
|
||||
"## Speed comparison"
|
||||
"## Speed comparison (M1 MacBook Air CPU)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 9,
|
||||
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -475,10 +539,10 @@
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"41.1 ms ± 12.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||
"914 ms ± 50.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -489,7 +553,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 10,
|
||||
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -500,10 +564,10 @@
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"6.58 ms ± 143 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||
"252 ms ± 9.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -514,7 +578,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 11,
|
||||
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -525,10 +589,10 @@
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"7.19 ms ± 294 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||
"300 ms ± 8.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -539,7 +603,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 12,
|
||||
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -550,10 +614,10 @@
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2.37 ms ± 432 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
|
||||
"94.2 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -564,7 +628,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 13,
|
||||
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -575,16 +639,193 @@
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"6.66 ms ± 397 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||
"297 ms ± 2.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"## 5) Using PyTorch's torch.nn.MultiheadAttention\n",
|
||||
"%timeit mha_pytorch_class(embeddings)"
|
||||
"%timeit mha_pytorch_class_default(embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"274 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
|
||||
"%timeit mha_pytorch_class_noweights(embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a78ff594-6cc2-496d-a302-789fa104c3c9",
|
||||
"metadata": {
|
||||
"id": "8877de71-f84f-4f6d-bc87-7552013b6301"
|
||||
},
|
||||
"source": [
|
||||
"## Speed comparison (Nvidia A100 GPU)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "4d21edc6",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
|
||||
"outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"914 ms ± 50.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"## 1) CausalAttention MHA wrapper class from chapter 3\n",
|
||||
"%timeit mha_ch03_wrapper(embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "98fda51b",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
|
||||
"outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"252 ms ± 9.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"## 2) The multi-head attention class from chapter 3\n",
|
||||
"%timeit mha_ch03(embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "af8d11a7",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
|
||||
"outputId": "92b634f8-43f8-468f-87a1-bb774b64c212"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"300 ms ± 8.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"## 3) An alternative multi-head attention with combined weights\n",
|
||||
"%timeit mha_combined_qkv(embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "de1e9a77",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
|
||||
"outputId": "80c6e314-0771-470e-b090-628984ce2d85"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"94.2 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"## 4) Multihead attention with PyTorch's scaled dot product attention\n",
|
||||
"%timeit mha_pytorch_scaled(embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "481e3fea",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
|
||||
"outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"297 ms ± 2.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"## 5) Using PyTorch's torch.nn.MultiheadAttention\n",
|
||||
"%timeit mha_pytorch_class_default(embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "9d52a9eb",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"274 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
|
||||
"%timeit mha_pytorch_class_noweights(embeddings)"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user