From 569f6bc7f011412874c7971fbf29eb7e55f06823 Mon Sep 17 00:00:00 2001 From: rasbt Date: Wed, 13 Mar 2024 07:12:10 -0500 Subject: [PATCH 1/3] benchmark numbers --- .../mha-implementations.ipynb | 104 +++++++++--------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb index b7b27df..1eda8cc 100644 --- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb @@ -29,7 +29,7 @@ "base_uri": "https://localhost:8080/" }, "id": "7898551e-f582-48ac-9f66-3632abe2a93f", - "outputId": "2ddf0145-94d3-4490-8087-d1ffeb6f30ab" + "outputId": "7d088260-3fa1-44f2-bd65-2a46e289f9d4" }, "outputs": [ { @@ -74,14 +74,14 @@ "base_uri": "https://localhost:8080/" }, "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", - "outputId": "ae6d707f-eae8-467a-ed4d-a88051bf776f" + "outputId": "f8a33752-2cd6-4101-8feb-9d1699984719" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "torch.Size([8, 1024, 9216])\n" + "torch.Size([8, 1024, 768])\n" ] } ], @@ -120,7 +120,7 @@ "base_uri": "https://localhost:8080/" }, "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", - "outputId": "5df88462-8b1a-4b1f-ce71-3909ad2ca9c2" + "outputId": "b704a040-3547-422c-ecda-df9982a2da35" }, "outputs": [ { @@ -184,7 +184,7 @@ "base_uri": "https://localhost:8080/" }, "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", - "outputId": "1240afaf-139a-4d01-ddac-4a186ff4a4fd" + "outputId": "5d948671-176f-4633-bede-97767e36becc" }, "outputs": [ { @@ -350,7 +350,7 @@ "base_uri": "https://localhost:8080/" }, "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", - "outputId": "83ef0a2f-3fe6-4123-c8de-f481f2a9e415" + "outputId": "af9e4855-7f20-4d61-8532-4827df8dfb30" }, "outputs": [ { @@ -404,7 +404,7 @@ "base_uri": "https://localhost:8080/" }, "id": "3799c7ef-3155-42c6-a829-f95656453ae0", - 
"outputId": "aabf134e-c9bc-474b-ee57-0c24b5fb604c" + "outputId": "2a085df8-0445-4818-9978-6dc74469f568" }, "outputs": [ { @@ -504,7 +504,7 @@ "base_uri": "https://localhost:8080/" }, "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", - "outputId": "5b577a7c-4199-4e52-8d08-a0974a5a3685" + "outputId": "234771f4-8a53-4478-8a9b-cf19f79a5e07" }, "outputs": [ { @@ -537,7 +537,7 @@ "id": "8877de71-f84f-4f6d-bc87-7552013b6301" }, "source": [ - "## Quick speed comparison (M1 Macbook Air CPU)" + "## Quick speed comparison (M3 Macbook Air CPU)" ] }, { @@ -556,7 +556,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "1.15 s ± 86.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "200 ms ± 5.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -581,7 +581,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "273 ms ± 3.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "198 ms ± 6.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -606,7 +606,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "324 ms ± 17.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "236 ms ± 13.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -631,7 +631,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "106 ms ± 598 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "71.6 ms ± 3.32 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -656,7 +656,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "351 ms ± 7.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "217 ms ± 4.27 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" ] } ], @@ -670,16 +670,16 @@ "execution_count": null, "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", "metadata": { - "tags": [], "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", - "outputId": "2e86bdb4-7fa0-4051-b000-4a2b591060a2" + "outputId": "2e86bdb4-7fa0-4051-b000-4a2b591060a2", + "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "333 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "205 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -707,14 +707,14 @@ "base_uri": "https://localhost:8080/" }, "id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0", - "outputId": "07a711f6-f7ff-496c-ce16-be67308aeadf" + "outputId": "e99a17e9-8139-4b04-dac8-fa1dd5027735" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "41.1 ms ± 5.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "8.35 ms ± 1.44 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -732,14 +732,14 @@ "base_uri": "https://localhost:8080/" }, "id": "8686dd69-3655-40e4-a57b-a2c55532a010", - "outputId": "b0c29336-55e8-4194-89e4-9201f77e5375" + "outputId": "5553b42c-b709-41a4-8a8b-be36dae408ab" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "6.58 ms ± 256 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "6.59 ms ± 231 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -757,14 +757,14 @@ "base_uri": "https://localhost:8080/" }, "id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf", - "outputId": "ba357440-47d4-450d-b859-08031056ccf8" + "outputId": "01b0da88-510b-4b21-919a-0a7519a55ed8" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "7.19 ms ± 590 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "7.21 ms ± 716 ns per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" ] } ], @@ -782,14 +782,14 @@ "base_uri": "https://localhost:8080/" }, "id": "1075abe2-4839-4fd6-af3e-c09bb3651e26", - "outputId": "b2126630-7fae-4c44-8180-226ff5509d78" + "outputId": "542706db-5041-45ca-f667-9e1bd1c2c7aa" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "2.37 ms ± 569 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" + "2.38 ms ± 362 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" ] } ], @@ -807,14 +807,14 @@ "base_uri": "https://localhost:8080/" }, "id": "868e3670-8edc-47bc-9e06-eb505e44dc9d", - "outputId": "453d9b7b-3f45-4907-b4fd-77d395534d6b" + "outputId": "13cfc808-2b11-4041-fe67-e5a63abe4f28" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "6.66 ms ± 301 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "6.67 ms ± 408 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -832,14 +832,14 @@ "base_uri": "https://localhost:8080/" }, "id": "944870e6-de54-4e3b-a455-b8f21f6f92c8", - "outputId": "ccfe127c-c069-4dcd-f37d-ea6a40406955" + "outputId": "c52858e7-999c-4782-adc9-731f8d69dfa6" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "4.52 ms ± 317 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "4.54 ms ± 7.17 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -860,7 +860,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 15, "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000", "metadata": { "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000" @@ -892,6 +892,28 @@ }, { "cell_type": "code", + "execution_count": 16, + "id": "CDJAPZaszaqx", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 489 + }, + "id": "CDJAPZaszaqx", + "outputId": "f23e9b83-7fd6-4011-9434-0e6934cf762a" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + } + ], "source": [ "\n", "import matplotlib.pyplot as plt\n", @@ -945,28 +967,6 @@ "plt.tight_layout()\n", "plt.savefig(\"1.pdf\")\n", "plt.show()\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 488 - }, - "id": "CDJAPZaszaqx", - "outputId": "47c9ef93-438e-4455-faef-7e253eaaaa8d" - }, - "id": "CDJAPZaszaqx", - "execution_count": 11, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "
" - ], - "image/png": "\n" - }, - "metadata": {} - } ] } ], From f2c8eeb6b8df47157a4ef9692e0c4bae43e445e9 Mon Sep 17 00:00:00 2001 From: rasbt Date: Wed, 13 Mar 2024 08:34:39 -0500 Subject: [PATCH 2/3] pretraining on project gutenberg --- .../2.pdf | Bin 0 -> 16780 bytes .../mha-implementations-Copy1.ipynb | 850 ++++++++++++++++++ .../hparam_search.py | 0 .../previous_chapters.py | 0 .../the-verdict.txt | 0 .../README.md | 121 +++ .../prepare_dataset.py | 66 ++ .../pretraining_simple.py | 212 +++++ .../previous_chapters.py | 313 +++++++ 9 files changed, 1562 insertions(+) create mode 100644 ch03/02_bonus_efficient-multihead-attention/2.pdf create mode 100644 ch03/02_bonus_efficient-multihead-attention/mha-implementations-Copy1.ipynb rename ch05/{02_hparam_tuning => 02_bonus_hparam_tuning}/hparam_search.py (100%) rename ch05/{02_hparam_tuning => 02_bonus_hparam_tuning}/previous_chapters.py (100%) rename ch05/{02_hparam_tuning => 02_bonus_hparam_tuning}/the-verdict.txt (100%) create mode 100644 ch05/03_bonus_pretraining_on_gutenberg/README.md create mode 100644 ch05/03_bonus_pretraining_on_gutenberg/prepare_dataset.py create mode 100644 ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py create mode 100644 ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py diff --git a/ch03/02_bonus_efficient-multihead-attention/2.pdf b/ch03/02_bonus_efficient-multihead-attention/2.pdf new file mode 100644 index 0000000000000000000000000000000000000000..06ef06cb0c6d0e3418038a49286766386157f48a GIT binary patch literal 16780 zcmb`v2{={V7eB6$T=SelS4f0AU#=L~Zu4O-q_6gXzN z1Hx**(_s>k0+(Gm)eNL)n^7DoBsg}Z!Pt>PA^Ca1anLCYVdm!O?C#|X$FE-X^Cp^+ zC~#XqRznN0LJFk75xSm$1+~?q`sz^^e(*1LQ0oPN_=dRmBLVJK*hiR>$lm^bL?92u zKcH_&a&~uA^9}?Qk>C%DN6N^-F<6`o9t@9A2ZDjS$#B97d#YYu-rx)cb^Tiz(B(hc zGaz}nQrzIERf<~f9zb3=LdydvL4!o}b|yh=lPP{AM^9K_cBf^lk*GHN-3iO_MCF8s zN$>A7Z0mA}Bee!51S5GL7*?FkB25Yrm+sp0+FGCvYVk58s2nK28WzZbNE=+J|M~vJ zt6w*2iNnEY+xp1NVrByjqm7d3BrRGL+6Up3vtMt=lpmK}lt}ENLzma`>I)5VbhDXE* 
zqbf;~uln6T%XaO2tOZ9g&&QUFI(<6Xa_@dBPx|&d7DhzZkbC~WInVU4>wQ!It;a9p z_Dz0#c+|$uWdBv1Oxy+DQ|}i|isHjedpVXOW-^k}IhMrjLK%!CBF^j^;2-hMImMg4 znOi!aw|rCub6m)OVGE*(@rcZ==Gy@_i-}9mt9pyWo`l&mOB$T6HKBi)GR+>Nk}ov) zVEjQin+wC(%<~gd1EQBR(6NGp9@kEZMreqV5RojEM-zk37+(ugxQA5u>GGb|c6aTm zr(K=Dh+k`mM#3|HO+|SgKTf6u4-NItkGfT75A2)%t^M`n@N{NQf+O?uNr|U-E>z8x zm^f;6KNA_)exF`lVEdO=^=ML9>>&BrQLTmSuxL5W2)98F+r?*lROU9RrKcooaNP}W z5GA+xhJ`4Lj-Bv#nXu$Lm17^!95B;t|H1R<$)KtQEho|NGYba!%CZgi2vou2{1fUb z5+ayUbQV$TrN$|%#uCo2(iIn%1}<91Pu{-oWjE9zWn*xy?<=M@d#7~GSmRj9v)A9p zxGH`97yawy6>Z9D+Mmro6MNur+Oa5B=%&5f;76Oq?|O!0&m%2G=j;sP`}OiKO9@;# zEDeX5jUS0vSg33sYJJkzboe{I7GaLrGNg`#4hi%yU z5cl%PcGiJ;B!j!X?i$H<`{>o)Ku% zKo-3#wYexOAfF?i_nCgXI?LAGT@C5tbM*CrT9@*-bRD`9{94x@{rrxAae}&i&IsRJ z5|3}Izf&{zL#)AkUJ?fl^OB?8_PKk(CZXdn%_k9pPEXTS+oN7N9;viR9ml#_+!blg zxLd3t>5nzqc4A4E+tzMc>}BV@=k-L(3BC(T{dcT{I$kX8{xLQrack%AM%oYZoyrP^ zS^LMnPBe`y#!MiTs#p)*{rNjnI;@LLpJ{e)yZiYqTem#l^_+Y0{<~(YVA>xg*+~Za zYNw9{-fXFtD^)C3q3O=$jDPE9MX$i5ws_9~R!RHWPQyp5!r!v2MEFII#;F<%e_y7- zD9l&l>_SubC-xF;QDgbgE5|jgi8r<=aU;b#VtaSyTM_ixX?hMV`0v)Igb^h&%=dt-AJQ*9EyV6bgs?e>A~(l)Z(Q~< zGcYwOI?8;wP*zZ3QmpeAwlzp*`!_btrKX_HJu6QeZJ+f!ZN*q=~l_7-rWBFo=# zYSlHPF}MxzXrk*ygW6gAVK3lseXYwlW8XG?J4H(CUt46!q zX}3=ydlV-btChF;60cZabJBS~WO}#0@5@EWHtpv_Qn!r;di93#T&H#I0uJr>bhqE; z&n-)z$huRZj$wY+TPYd&M6Vy`C8AObgyp0|cE;qIcd?dp>^|ut8rYU&G?S05`=Q$N z{qAv9ZLi(VlDCbuBy&o!kIv>U#JTM*o>>ab`F^+lH;uZXc{LTCQ3+>q8<ov;=o)qFtGA0jcSYgCo zL=@OBv?CictXitdWi7Ay)$DZx-@s$lz+?#<;7Jx|y26vo<=y|`iR9+a>hdPUdiSF(Drv>V5}ZuCn0lBT z?~gm&tcW4>($O)AGF0IiTx?k_w^(kcd7dJ`Ooc57eymZI9QJ>+g^N?=Y(Ua!&KkYm zDy%wqrDsZQ=AN`}&Ff>~<$3+xuDm#PVsk8fB9E`-cC2Q&gl?d~6}6~9`5)DGC*5!1 z0?KP;C`BEx*nJ`+p+9N66}C%J;SGX{Yj`6d|2J<+@$Hb?P@gflzxB1*=AY&V*jj)0 zojwtCpUt-A4kIt~&iogVskGi<72^I52M^KmsPXbUkH1vli`8@Cy_dDSnCq#i?<Td#1t$PuaF- zQo^~3?FbD^%5-W}^P`lwSvTVT2x+@Fle06Ey+PN$A~|Nh(hjRCJ5v!E0_1B%M&mZ* zDD}NGP+MP4aN9W)sLY8@WFGF-Jig@V`Sdm;?@nU!3o)}@Zwjjl)AXua)8qZ^(1s_Q z(MO#5llDIJ(Dg{3?Jnco){{Lt;$SybN@JaQUA=3H|GP${RP%UJ+fdleWX1a~rCW5o 
zjz9JYyLnjcQEgbX{(ve+cJ&P%wcPZc@ro}(!y>#I(s!Kll$OqZ`mWO4ubCnbh4NME+Kw`6M59xc9^3rhxdqh+*=n zaHee)3(omP?^u5fXJ%De&xMPU>TILW^80UXy!+#XS8RH_>btp?_1M;dqIHgme1{;xmy*GCb}%tX^RCE(Gx@E&&lOPhSQXq6@GtoF<$QXLEh<+rux{L@D=8}p`I=vM3@Ih)d(oo6L!}%k!``qk zkn^=*PffG;=wCq%>GYq*_=$gX zd-S%qWi;XUt2jFSbJyI2zmAl9J(qE)Gn3bJayIfSb&ow@G;7S=Q{8x%nf<$mrn&?< z@!323#>Y%Y(mZ$r^%m%b@`rZovsw(iEHauYQyl1?-D%TD*-1tBR6;r$yMcJ1nP})o zFT6EkHYX49>FT(pQdfbgwZX@GES+_tdu<}$()MiT(*DG$J{^Vk}<_mNX{37aXV%d^CN%-NQDF`@3GL_=QW@NshKpM^lcR@aNiJUNLUoDQweO#lgIT2*4%Yh6(QQGR(u^Hc;T%^NbMg3UnGi~GEsv=f(N6d*RvKfJzBSa6*CE@aSpLF%JT$I`$sr{gA zM3*}>{q>=Mu=X8~Nc&hZJL}Rd9v3`HW^U2MD=BiI`^DJhk+oa=`(0kd)C<3^6fy18 zO6%$4{Pt9muFb4!cq)w3vt)t%vy{p6=L;^QfaC9U-*81wq?I<1#(0vOthOkbRO_RH z6W)2`is#ntxft%Dd^lPoeUIx%%>Da&x>={IKe&z_5OzB_bj6ccq@~x~bC1Q)Flp<& z2BT{g_9O4_Q(XY8{$w)}Hs_$CFe;b`joAQ##O)^}(FrrRAB$8*9@zF$<=&MmRrvn- zQ#DDMMhTyUTYIWQoF1eX7F=otPipGKQMT1$?GcflAh?E{gBY=b@m`vMVx=x6*DVsWG-_ArK&Ln zPBt(2)*6p~J1Lhao)bIs z#o#*_d7qm&eawCnVShJ(B6MpHCs{2^RfbErb4v0udzNu&W&R~o6Q5$`E0d)cPrb3X zlcu6CDqszb-2jMGGFFJ!5N3W)h(vr6{UYKcq*WsAGBbRQHcP_T;A9m>w3{PTn;Df- zJ|00Q$Shdeyqwm8Xz)RFk;g5sZqAAmt@rhjWg4vjTj?z7+C5GLE zpATRWK_P_DgJ*YqrZ3i_SU<$ztmrHU>2@{`Vnk4r@|2YU9u%$yVOxo!W#*(kM{ zBGfULee-cRfv0A6EC$bD824gwdX~#w5@A{^XmjDDKWFODwux<=>8_{vhn<(^8Gc9E z?4qJUDv(lk1MaO7S!IkgMq-g>NUW?WQWkHDMBz+yP~)2!Lo=%FkoN_U?c0i^7^Gqy zRGl#m9wKIDMMCUMdClzCIPZtGxcC%qqA{uG=cU3CH3*5>01z-dr^-r;+DbcfEb^iX zTqUr};sD#rC42q;fF35H)({VqEumsjfl)hNKY7VEtcmAw9m}>im2eWy*za=@X(kXS zar6uET!?SVVW}%_frs=T8#eMuALWuREK5)3_rCKneO5cwGm%BJOPWco&m(Cb9m;TG zTZXp(n0sOK&<~IDAB!}PJ+Qh|1fm8IH)5oQCc0;6fft04z*5Gzu8F_~ele>u+tc$KLdYdZDWyI_+z-m+Q;LZ`J#!;ImigANp zKeb#)yCYc6IjaqB3yFrm-+W`6R0c7`+=pE#m1qpc!x=B*xr=qjFdc!hgUmB zD&Ia2kX4udGEx-$dBM#h-68Uc{&%ZD`$*-C;>>`LNoql_4Hb5E2*C=_3etOT9e?=F zD|BIByXmK2t+R82t>>tFqV55S!2-@> zORTQD6N-*KAzqkeBK{Onq{1Z?*oeVyz15Z=;>`unUjMOV0P(zo!x_=uNqqd8C$ ztNS4TOORJUo=5(Hsd)h!3@g&FAzI8Q#_nzZ19lbUg7e(hB;u~u1Zc@}X06{=Oo zhxhH4_`5!eO_@IWa*sBvoVh!?E$EY2|67FRwdkx*jQ%ywh`I+kFWye5C2u`iM#Fl@ 
zxv{JyzbNg)lw#NqS4}^yv%Rr$^1?|rPYGe8PcI4CjM>RnjCzVxR9g;z-1#~iVfW)b zpZ;{jLn^aF1w^Cp8!@7UW?y;$qGzsd7DI->zLz$CWqbd-{PedT)u3%^%uf3sT(`6I zb)+<&yDg1jIi=12@zDaBSoh2dE@n7qr%Jtw*Mh|u8|mJtUz z^j<2BY=Gp#b?WRD`91S~s+ULEi>m{W^9N=V@NBmU7v${Jy8Y zIHp=nhEIq2VcVuk=8&6W;HA22NCw=Ky$Q~%QIQ$LuOXPdDY1P-kCsO4C2#9`ey~)4 zAa(-_7z{Oj(%3Cz_za{NIyOH^^MUQkxu3d)rrMQ7oC;ZLa0(~)zZXuLiTZwY@U0rN zIbk9pi{pyPdL&f!5}xI;Gd>`_=vRcE$YoBc3GQRa;MOml&Q()K2^I=F=(#obWV9Zb ztk_(lVD>QcNX+5)MAG=39)&GUJ02e-6g&IopY=EQm-I2{O8ss!G#_A{60FTig7ue@ z1w2ODJR7%(Yn?Cq$T}MX-!E>a-Sh!I@QFb@E;zMhug?XKA5Zj=taQKBxH86LRo*j) zKI%04HqJRcHWn}u>fn8hLuY`!B~f^rU|7HfyQR_R^LLkgPb+lY;=d*^GkP@Rc|fzh zQxXq(Q|zEb(ZmrMn?8Q|SIL;>@ykC?KdyF2uMIbtncw%kYvDJokc7(u6;)C}S|~hb z1FB3=;G`91RzGf(D>5fStK8*+mcY+?6Z<_(8+xqZ;k8_0y)VLICc~nt#D?f+%II8) zi&JYcV>h0c!j*0n%3V1&4ZTmjmMyiOO-SZx96cWrep0M+46c)S`_y@b7sWa~bvq9nLKoV} z__s=0JANrt_dj#_M2^elMb_^LV;MpN!Zl{k(~fWBO^Z3ECs>gsHS5EKucyi4Br|=U z-TG;7h)ILSC9lu3Y&F)LD%-gixhJzq(yW{|v&2T04e*Rw3`tnsDVYlpt!7~KLs?In zv1*=?xZQAqUuq!HCi}ggHEhW%c=LA0NV)0CSdEcm4YUTLznw!ox|D6F-@Vhs2b`9h z>JEtwIo~O|N!3iQMEQ>6&%?EECo=bl%MqTJ=b`FZjNa;-zOt-t?z*J2A@a_VK zzACt1Gut-i`<;C|&+6QXNvv;iOIE+WkXg#j&WR^_dDo8~ z$}VQ5}s%)4ynC;==YABA~nn^KVlyGW|_w4-aFz`*oAl_ zZq`r5TB*QGxs5!S|9newpgXzy2l(jl;MF}Ab}_Fz;zLSyf{d9GUYJJpEji3mrw{su zeoI3#*=ZZcUzpB!d{+OCF7Eqf)2}mZDLK=nqo?0*+N`$Sg$ip_a2p!4f#sQ@ajPeo zy`Qx(9I|0jpjL<>R`JHKQ|fJuN$GC-qq$#FY}49EvmylV)|xl!QAVlmU%a|I%PW58 zGj9%9GB#%~6RXhcuzTuNfz?F}G|I)LAGsAjQG7hYIzRB!SbvX1=#&!rzDwm@08X+ok!*GdK6McX=7>>7plxQ{~u@TxL&FeM=HUQ@FOMyJ}lvnF?Flg4A#i;t08!cN={iXxl3z3JS|i z>)1$}CO02^AwP!I3yY_DFudpN_k}H27Z1r&(Hj*ciNbB6`my$q}GI#^1tSHbF|={p zf_UD3M-pR_v@agef4_;IYwqd%sgGVW0l}Lbnvl!iJpcQd6iW??;Whx2ap&4ufOzUN zSz%0c-CZ{n_n)~GMOiu`c7>if1KBD*z#E=;R^7m1I)nIH%zS^DvZ%=I8eIZ=v9T=O zem4VVOm>0(Na|EQN%3|dp{$qwLPE=swI^1CzfZ)2W9Q?%cCe)qXwGX=5-?FR#q&!|iv74)?c@i6$CD{E?o zB(2jbX^AmvXN385XdZUF31fU4Vhuj3<;gKz+a0wxd@l>l>zgbY>q!G6S?4cN_zhT& zW&(i|ES~Aecf-VW?80VC+i%@pSSn_CuAMiv87ol3YuFaV?WR0cp_NB|_3gKsU=MHHO+MquC)<9P9|6ugi+)W=5Q&Bun-{o++V;0sY*Zh^m 
zAWZf>6+x*$I}Bz6a9-L}->(VI{E`eicnLGdFs?X7sR_KhRF#yhsm1E>4fW=9jNS=r z`0f1X4{yhHs6@UlFX@Qv(s|h;AKji1dnZXc^0&q6};r5&x6437{4 zN7?rN=#3x=k2UeXaO7!r`BBa(2jc$ea(g4T}Ru+ekbLVl98j_8#;zFwyy^TDx$a+SZC&0KCB6YuiN>n+|oQ{k%L-A%Zbb2rV}jdGr=t{|f8kfdGO;Y3&WGd4_1twy^*XRGe zvjMDu>A}aaE{QuTg`f66{(ALlZ2eE~BV2cR43G|E5(w+3`t`kHFAE#TZFkO!Rj9o> z_aUppv}u;NlDkW~T)T_owls{+giFaP6T!2qx9i)ae36gd#Zb$wx4C!Jh3nFUAD&|n zD#XBzCsZEQI-yNz!$kYZ?!H@N{C3BjDyYjD7B4Hv{ibA{**=`CU)U{ZXf(O0sjW%( z(m~z%OPLvZ8cFHzgDcd&oo+}{RMtkdhD_sQJoD&xPY4x?bv)#jzuvsL&L!*Ffk}r4 zLQxWDEtx8dggn03a-|jsmW}ZbAYZj3&75*yU#IVV>qH9s_+|`O{vi5;N%Fx)pO0ys zN4Gw4`@Ms=ZG^~^6rz2lf|JEvt|7W1=a5{}p|4y7)`n=}g)Av%qTB8(LGdvS-8q$o zj*ps!Y`%YGGoLyUMP(wWVAze|rez-#(#Wk2zS|4Fr8T88G8IUM!fXU^oYQBe69$W_ z?7aAf=A9G1=&4BZ(ua@>UR*?_Nb}I@uv{gXBeWNIb_(-XeZF0s7O!coY#wB4!PYtw zRr;u1c<_?cr!ZZ~)56BF8QT-3NnW=@UgBA9+(N;;h0kkW<;B*YD%gZ#q#_L!^n?M+)eRUd!3#i?_rGAl@DQ(;GXntYFWzG* z9^j_+rI@BD*q8^jM@5?7nx@R(Ont#AJARZD`H-=-g!6;xX`P}v3*l!ondtJ{I9IwB z)4~Vi1>?T0Vfz{juP`;wS5=h1X%RdpWW^(rtI>dNdN{c~?eWL!F19Dr+#AjXMr7}O zF5PTJZ@Obh?!%ec++)%=wnX2}OgjS;w20w7srvbnc3kj}ZRS?+Cm#fKea^-hCx!%s zzR+PN9G40(+4&8#2gNn{Vu$9SYN%LJ2BE2j@@;soyW-OC{a^PnG;KjOM1jm_JVLaFVn;#@V0s~ODC#X_$zBpBR`|w zf+F|5b7EYADJ{dSEV-xbh4Rm4<6mulo43!ChkC#*XCY*O^^P*%r@A z3{Ez^M)AAe?j!ZGPaiB-%2aJFwR|icr&JX;1quL%Ac}(^jQl##4k2`pSs& zAxZ_leaF-CnYfpJQPFAm<;C`$d?R@{P=DevZ_n8A&eksv%pID}_j#${pN$UJMUVYF zTRt%{zesEEU>i+Eht#m-M&9rnCg`%NgN6HS_5;462u#C-;=QXiZZr4F)EWFk7suf~ zii+L7qB0cc$+Pi$&kKw-RKkg)Pvn+b?csHWiGC+W4{d&Qlt)3EMd;)Ttmq zxecs)p#6OzNM?=W`(Ja6-Wy;LhV|cT()A$@4<#Qm9aVWF&k=t)R9y9Y<{|jN)Funf zQ4`&cCBwVLnr2;>Ka?2KiDTlw`wJ9h^#pa>?D0(IedQBOlP&mRJ%a!euEU8NV$iUf z12FhrDN?)PoK78Evq{7++vx~wn$bm$HijePo5EhlA%y4h5d0ChnAjwL>QMSMSX_BT zxf;u*K3&b`@g$3J4OkoGSE{yuEBoLbE>yE`C|&?Yg+VI#4z-b&=}PDBgC!8y>ogBr zQYq+{$UU+D%HgPwOXVW^wM;>qiDx6byLVj8YCP?5eyr)KnE8=8?j{v6*_KQP*203P zg4&5ePIBAqgl_JLAFY^N;HXV<3GRNRRwL#Xa=g^ceekj`*F!-Z{X2{zw*vcgG9hQ5?yUhBQu9VcZ&(q&MD)XC)R8-&*8nvOlLIhUu7HxVKY^oGJiwNAG 
zx2N;>z0R{wEQIdcoxYhOocYBxxk&hpnQ~C3Om@v+U&mqWMA}Vha>$*kIz(AgLySvw z_vkMDKbYB0QG~081kWTr`Z7;SV;;qp1D&dJ3xYHhfDb>zmO)g%~ zF7X1~J@69aRr-ga7PG=fY1=}#Y_gCO=$q{Ol&Z9U_Y(>ay0-+^zKxMfbC=?L{N-m{ z?`^P8(VV@HI7$!G)+zi#8qv^j`L=I%zNnW{IsVYQ;v3k!zVLg~MQ!u6_4)y-_}45~ zvMH8RBoJzjWYTKQ(8NT;Vy_12u%m^)nWGmO3e|pOin^PlAGjG|;J8``d$3>#=pGFc znds;4L-F>5qaf^J3P+gvJ5iQ%H6R)Qx)?cnf{X{XwL`DmueOrqAdNzzQE+K2c~XaIUZsslvA(2?TjzMKjngG54>3;zD89fTnu=8Qpx2FQSbLK;Y)(DVWc4N(5Z zzj-5nnqUYGcNZ5D$i9HmEo|Y?h6%EdBN1eR0Qr2NvUvfz>-Ob&2j3`dy2p{eq^Gz9|=wXJssf)Cpdx#M>qolfHWKd@?zZJ2zQVQ0{HZRBRpXUFTg5bAC3ak zg8&SIDFxE`gLA;ZQ8*$HjyMMPPJ`qRgr&PP#SLf?LbreCkt|Q_pUq?c8qQx+zdTR> zhj|7vdN{h0fu=V!-n1NI;a(1|n}`b4en=%Yhgo)9DqVfCj9@@_yzkULS; z%hiJf7)O{q6(1n`paAQ|!*Rf><-kKK3mybnIlurC)S)AE z9c*C7gN!Fxa6dFA0mMAK+)5jCTowi+3xkdj6Od3)!2^SX?!^JYSH=K+0o~;~c69B-%gLXU$4Fp`Fiv@IXz=L22(0y1q4jkoxvY_K~K|}rU%SwQb5dGyz zgA@pO01AacG;!dbWu^RK4mvK|4@3zIogg642goM=lx164wF$@`{w@$- z%jaMW$Tt2INOC+92j&p)3fTaaLXZP<0@(!!UVoQAHu2{ODFw_sZ21V)|0>IKy>cDu z^G|{70-9Cm_;2#a->V{35R8_6A@J8giObPV6Am2ba!m`qyz}1_a^GNPm)rE=;8WIe z%@7VI8Ps6Fm7`Z9qX`^L`0`bAIC?o&L7pAGO2ZM3UZvp#^t#fI2w&dj2S#xQRlo891GT9#|>;Hg=uc>w!a zsd>Ty99+KY1qaS^r3M8v?5a@S@a3;ZKzJw+VpnPTK(TVU^S_Z4e2V_JWi_BKQ-=K$ zR9DIS0W)3cO9oHUY7L5{s|>8fY22!~6gXINfRSM<0T`Ic${kQl2A*n}rau&gS8D%; zPVkL-x#=jF`ITBAutTT@1ECN+Rm)eGqdb0v^1u0|5bZyaPzyYND>VPT4srD7x+>~# z{P9o5DMDQpbV3jv;V395Aym=-$8G>Ptake=d_g2vN&HI}l6PGrB>K8WNTNR+uEg?{ zOj?A#yE6ocz(Wbizno_34`S~!ldCCM!14ckvpPr>g<#(vfX{q9yeS^;PVfMXjI0b& z3hqXs_>koh2+w~m$$0y@N`Pm<&)eCb2r^IqzRSni1@7cXJOZe!jsSWt5AW{nr2%E> z?$MA(BhgqS1_iceW6>zIG!iF{M2buNrI$cIk_!xY3M>o+wbdVpSvc?Svr!`8MTk-#+8^g}|A#F{n~^uVlbgG^{`8}MWQZ3AZj4z6v3-ml%(ZFFKXTPQ&R(9vcsyucu4E06e&6JOUnMaQwFoKqY7$1O5Je7QhzU_5HA5ZMC+a?0O%8Lap~D zDC~NFfWiYjv}Qb%-1_Gd1!1i<{m>{VibKDDc}8Qv_SQ9R;C0$xHbn7rboU_nE&ta9 yGxs2{0tHJyQ*Up`nJ>pST`w1J0DYEEKnx*M9Q`QE?h}Q>qLDB$G0pv2u>TJ!#>K7x literal 0 HcmV?d00001 diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations-Copy1.ipynb 
b/ch03/02_bonus_efficient-multihead-attention/mha-implementations-Copy1.ipynb new file mode 100644 index 0000000..41a4801 --- /dev/null +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations-Copy1.ipynb @@ -0,0 +1,850 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6f678e62-7bcb-4405-86ae-dce94f494303", + "metadata": { + "id": "6f678e62-7bcb-4405-86ae-dce94f494303" + }, + "source": [ + "# Efficient Multi-Head Attention Implementations" + ] + }, + { + "cell_type": "markdown", + "id": "b742938a-4bfc-4527-a1f1-d5963508967d", + "metadata": { + "id": "b742938a-4bfc-4527-a1f1-d5963508967d" + }, + "source": [ + "This code notebook compares different ways to implement causal multi-head attention used in decoder-style LLMs like GPT, Llama, etc." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7898551e-f582-48ac-9f66-3632abe2a93f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7898551e-f582-48ac-9f66-3632abe2a93f", + "outputId": "7d088260-3fa1-44f2-bd65-2a46e289f9d4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PyTorch version: 2.1.0\n", + "Running on cpu\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "torch.manual_seed(123)\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"PyTorch version: {torch.__version__}\")\n", + "print(f\"Running on {device}\")\n", + "\n", + "batch_size = 8\n", + "context_len = 1024\n", + "embed_dim = 768\n", + "embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)" + ] + }, + { + "cell_type": "markdown", + "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6", + "metadata": { + "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6" + }, + "source": [ + "## 1) CausalAttention MHA wrapper class from chapter 3" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", + "metadata": { + "colab": { + "base_uri": 
"https://localhost:8080/" + }, + "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", + "outputId": "f8a33752-2cd6-4101-8feb-9d1699984719" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper\n", + "\n", + "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim//12,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_ch03_wrapper(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "21930804-b327-40b1-8e63-94dcad39ce7b", + "metadata": { + "id": "21930804-b327-40b1-8e63-94dcad39ce7b" + }, + "source": [ + "## 2) The multi-head attention class from chapter 3" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", + "outputId": "b704a040-3547-422c-ecda-df9982a2da35" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "from ch03 import MultiHeadAttention as Ch03_MHA\n", + "\n", + "mha_ch03 = Ch03_MHA(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_ch03(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4", + "metadata": { + "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4" + }, + "source": [ + "## 3) An alternative multi-head attention with combined weights" + ] + }, + { + "cell_type": "markdown", + "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd", + "metadata": { + "id": 
"1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd" + }, + "source": [ + "- The code for the `MultiHeadAttentionAlt` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n", + "- The main difference between the `MultiHeadAttentionAlt` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionAlt` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n", + "\n", + " - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + " - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + " - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + "\n", + "- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n", + "- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", + "outputId": "5d948671-176f-4633-bede-97767e36becc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "import torch.nn as nn\n", + "\n", + "\n", + "class MultiHeadAttentionCombinedQKV(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", + " super().__init__()\n", + "\n", + " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.block_size = block_size\n", + " self.head_dim = d_out // num_heads\n", + "\n", + " self.qkv = 
nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", + " self.proj = nn.Linear(d_in, d_out)\n", + " self.dropout = nn.Dropout(dropout)\n", + "\n", + " self.register_buffer(\n", + " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, embed_dim = x.shape\n", + "\n", + " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", + " qkv = self.qkv(x)\n", + "\n", + " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", + " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", + "\n", + " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", + " qkv = qkv.permute(2, 0, 3, 1, 4)\n", + "\n", + " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", + " queries, keys, values = qkv.unbind(0)\n", + "\n", + " # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", + " attn_scores = queries @ keys.transpose(-2, -1)\n", + " attn_scores = attn_scores.masked_fill(\n", + " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n", + " )\n", + "\n", + " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**-0.5, dim=-1)\n", + " attn_weights = self.dropout(attn_weights)\n", + "\n", + " # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n", + " context_vec = attn_weights @ values\n", + "\n", + " # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n", + " context_vec = context_vec.transpose(1, 2)\n", + "\n", + " # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n", + " context_vec = context_vec.reshape(batch_size, num_tokens, embed_dim)\n", + "\n", + " context_vec = self.proj(context_vec)\n", + "\n", + " return context_vec\n", + "\n", + "\n", + "mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", +
" block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_combined_qkv(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "48a042d3-ee78-4c29-bf63-d92fe6706632", + "metadata": { + "id": "48a042d3-ee78-4c29-bf63-d92fe6706632" + }, + "source": [ + "## 4) Multihead attention with PyTorch's scaled dot product attention" + ] + }, + { + "cell_type": "markdown", + "id": "f78e346f-3b85-44e6-9feb-f01131381148", + "metadata": { + "id": "f78e346f-3b85-44e6-9feb-f01131381148" + }, + "source": [ + "- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention called [flash attention](https://arxiv.org/abs/2205.14135)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5", + "metadata": { + "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5" + }, + "outputs": [], + "source": [ + "class MHAPyTorchScaledDotProduct(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", + " super().__init__()\n", + "\n", + " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.block_size = block_size\n", + " self.head_dim = d_out // num_heads\n", + " self.d_out = d_out\n", + "\n", + " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", + " self.proj = nn.Linear(d_in, d_out)\n", + " self.dropout = dropout\n", + "\n", + " self.register_buffer(\n", + " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, embed_dim = x.shape\n", + "\n", + " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", + " qkv = self.qkv(x)\n", + "\n", +
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", + " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", + "\n", + " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", + " qkv = qkv.permute(2, 0, 3, 1, 4)\n", + "\n", + " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", + " queries, keys, values = qkv.unbind(0)\n", + "\n", + " use_dropout = 0. if not self.training else self.dropout\n", + " context_vec = nn.functional.scaled_dot_product_attention(\n", + " queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n", + "\n", + " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", + " context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n", + "\n", + " return context_vec" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", + "outputId": "af9e4855-7f20-4d61-8532-4827df8dfb30" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_pytorch_scaled(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "351c318f-4835-4d74-8d58-a070222447c4", + "metadata": { + "id": "351c318f-4835-4d74-8d58-a070222447c4" + }, + "source": [ + "## 5) Using PyTorch's torch.nn.MultiheadAttention" + ] + }, + { + "cell_type": "markdown", + "id": "74a6d060-6324-48fa-a35c-cb09f2a48965", + "metadata": { + "id": 
"74a6d060-6324-48fa-a35c-cb09f2a48965" + }, + "source": [ + "- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3799c7ef-3155-42c6-a829-f95656453ae0", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3799c7ef-3155-42c6-a829-f95656453ae0", + "outputId": "2a085df8-0445-4818-9978-6dc74469f568" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "import torch.nn as nn\n", + "\n", + "\n", + "class MHAPyTorchClass(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False, need_weights=True):\n", + " super().__init__()\n", + "\n", + " self.block_size = block_size\n", + " self.multihead_attn = nn.MultiheadAttention(\n", + " embed_dim=d_out,\n", + " num_heads=num_heads,\n", + " dropout=dropout,\n", + " bias=qkv_bias,\n", + " add_bias_kv=qkv_bias,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.need_weights = need_weights\n", + " self.proj = nn.Linear(d_out, d_out)\n", + " self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, _ = x.shape\n", + "\n", + " # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n", + " # No need to manually adjust for num_heads; ensure it's right for the sequence\n", + " if self.block_size >= num_tokens:\n", + " attn_mask = self.mask[:num_tokens, :num_tokens]\n", + " else:\n", + " attn_mask = self.mask[:self.block_size, :self.block_size]\n", + "\n", + " # attn_mask broadcasting will handle batch_size dimension implicitly\n", + " attn_output, _ = self.multihead_attn(\n", + " x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n", + " )\n", + "\n", + " output = 
self.proj(attn_output)\n", + "\n", + " return output\n", + "\n", + "\n", + "mha_pytorch_class_default = MHAPyTorchClass(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_pytorch_class_default(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4", + "metadata": { + "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4" + }, + "source": [ + "## 6) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`" + ] + }, + { + "cell_type": "markdown", + "id": "d2164859-31a0-4537-b4fb-27d57675ba77", + "metadata": { + "id": "d2164859-31a0-4537-b4fb-27d57675ba77" + }, + "source": [ + "- Set `need_weights` (default `True`) to need_weights=False so that MultiheadAttention uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n", + "\n", + "> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n", + " Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n", + " and achieve the best performance for MHA.\n", + " Default: ``True``." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", + "outputId": "234771f4-8a53-4478-8a9b-cf19f79a5e07" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "mha_pytorch_class_noweights = MHAPyTorchClass(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False,\n", + " need_weights=False # NEW!\n", + ").to(device)\n", + "\n", + "out = mha_pytorch_class_noweights(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "8877de71-f84f-4f6d-bc87-7552013b6301", + "metadata": { + "id": "8877de71-f84f-4f6d-bc87-7552013b6301" + }, + "source": [ + "## Quick speed comparison (M3 Macbook Air CPU)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", + "outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "194 ms ± 2.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 1) CausalAttention MHA wrapper class from chapter 3\n", + "%timeit mha_ch03_wrapper(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", + "outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "198 ms ± 4.12 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 2) The multi-head attention class from chapter 3\n", + "%timeit mha_ch03(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", + "outputId": "92b634f8-43f8-468f-87a1-bb774b64c212" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "234 ms ± 4.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 3) An alternative multi-head attention with combined weights\n", + "%timeit mha_combined_qkv(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", + "outputId": "80c6e314-0771-470e-b090-628984ce2d85" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "71.7 ms ± 3.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "## 4) Multihead attention with PyTorch's scaled dot product attention\n", + "%timeit mha_pytorch_scaled(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", + "outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "211 ms ± 5.31 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", + "%timeit mha_pytorch_class_default(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", + "metadata": { + "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", + "outputId": "2e86bdb4-7fa0-4051-b000-4a2b591060a2", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "207 ms ± 18.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n", + "%timeit mha_pytorch_class_noweights(embeddings)" + ] + }, + { + "cell_type": "markdown", + "id": "dabc6575-0316-4640-a729-e616d5c17b73", + "metadata": { + "id": "dabc6575-0316-4640-a729-e616d5c17b73" + }, + "source": [ + "## Speed comparison (Nvidia A100 GPU) with warmup" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000", + "metadata": { + "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000" + }, + "outputs": [], + "source": [ + "# CUDA benchmark code shared by Andrei Aksionov\n", + "# and based on code from\n", + "# https://github.com/cuda-mode/lectures/blob/main/lecture1/pytorch_square.py\n", + "\n", + "import time\n", + "\n", + "def time_pytorch_function(func, *input, num_repeats = 100):\n", + " # CUDA IS ASYNC so can't use python time module\n", + " #start = torch.cuda.Event(enable_timing=True)\n", + " #end = torch.cuda.Event(enable_timing=True)\n", + " start = time.time()\n", + " # Warmup\n", + " #for _ in range(5):\n", + " # func(*input)\n", + " #torch.cuda.synchronize()\n", + "\n", + " #start.record()\n", + " for _ in range(num_repeats):\n", + " func(*input)\n", + " #torch.cuda.synchronize()\n", + " #end.record()\n", + " #torch.cuda.synchronize()\n", + " return (time.time()-start) / num_repeats" + ] + }, + { + "cell_type": "code", + 
"execution_count": 16, + "id": "CDJAPZaszaqx", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 489 + }, + "id": "CDJAPZaszaqx", + "outputId": "f23e9b83-7fd6-4011-9434-0e6934cf762a" + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "#embeddings_cuda = embeddings.to(torch.device(\"cuda\"))\n", + "\n", + "functions = {\n", + " \"1) MHA wrapper class\": mha_ch03_wrapper,\n", + " \"2) MHA Ch03\": mha_ch03,\n", + " \"3) MHA with combined QKV weights\": mha_combined_qkv,\n", + " \"4) MHA with PyTorch scaled_dot_product_attention\": mha_pytorch_scaled,\n", + " \"5) PyTorch MHA class defaults\": mha_pytorch_class_default,\n", + " \"6) PyTorch MHA with need_weights=False\": mha_pytorch_class_noweights\n", + "}\n", + "execution_times = [time_pytorch_function(fn, embeddings) for name,fn in functions.items()]\n", + "\n", + "\n", + "# Plotting\n", + "\n", + "# Customize further for dark mode aesthetics\n", + "plt.rcParams['figure.facecolor'] = '#121212' # Dark figure background\n", + "plt.rcParams['axes.facecolor'] = '#121212' # Dark axes background\n", + "plt.rcParams['axes.edgecolor'] = 'white' # White axes border\n", + "plt.rcParams['axes.labelcolor'] = 'white' # White labels\n", + "plt.rcParams['text.color'] = 'white' # White text\n", + "plt.rcParams['xtick.color'] = 'white' # White x ticks\n", + "plt.rcParams['ytick.color'] = 'white' # White y ticks\n", + "plt.rcParams['grid.color'] = '#444444' # Lighter grid lines for contrast\n", + "plt.rcParams['lines.linewidth'] = 2 # Thicker plot lines for visibility\n", + "plt.rcParams['lines.markersize'] = 8 # Larger markers for visibility\n", + "\n", + "fig, ax = plt.subplots()\n", + "bars = plt.bar(functions.keys(), execution_times)\n", + "\n", + "plt.ylabel('Execution time (ms)')\n", + "plt.xticks(rotation=45, ha=\"right\")\n", + "\n", + "# Calculate new ylim with a margin\n", + "max_execution_time = max(execution_times)\n", + "upper_ylim = max_execution_time + 0.2 * max_execution_time # Adding a 20% margin\n", + "\n", + "plt.ylim(0, upper_ylim) # Setting new ylim\n", + "\n", + "# Annotate bars with execution 
times\n", + "for bar in bars:\n", + " yval = bar.get_height()\n", + " plt.text(bar.get_x() + bar.get_width()/2, yval + (0.05 * upper_ylim), round(yval, 2), ha='center', va='bottom')\n", + "\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(\"2.pdf\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3e1137b-9acc-4cc5-bcbf-0e8533839f06", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ch05/02_hparam_tuning/hparam_search.py b/ch05/02_bonus_hparam_tuning/hparam_search.py similarity index 100% rename from ch05/02_hparam_tuning/hparam_search.py rename to ch05/02_bonus_hparam_tuning/hparam_search.py diff --git a/ch05/02_hparam_tuning/previous_chapters.py b/ch05/02_bonus_hparam_tuning/previous_chapters.py similarity index 100% rename from ch05/02_hparam_tuning/previous_chapters.py rename to ch05/02_bonus_hparam_tuning/previous_chapters.py diff --git a/ch05/02_hparam_tuning/the-verdict.txt b/ch05/02_bonus_hparam_tuning/the-verdict.txt similarity index 100% rename from ch05/02_hparam_tuning/the-verdict.txt rename to ch05/02_bonus_hparam_tuning/the-verdict.txt diff --git a/ch05/03_bonus_pretraining_on_gutenberg/README.md b/ch05/03_bonus_pretraining_on_gutenberg/README.md new file mode 100644 index 0000000..1c0f94d --- /dev/null +++ b/ch05/03_bonus_pretraining_on_gutenberg/README.md @@ -0,0 +1,121 @@ +# Pretraining GPT on the Project Gutenberg Dataset + +The code in this directory 
contains code for training a small GPT model on the free books provided by Project Gutenberg. + +As the Project Gutenberg website states, "the vast majority of Project Gutenberg eBooks are in the public domain in the US." + +Please read the [Project Gutenberg Permissions, Licensing and other Common Requests](https://www.gutenberg.org/policy/permission.html) page for more information about using the resources provided by Project Gutenberg. + +  +## How to use this code + +  + +### 1) Download the dataset + +As of this writing, this will require approximately 50 GB of disk space, but it may be more depending on how much Project Gutenberg has grown since then. + +Follow these steps to download the dataset: + + +1. `git clone https://github.com/pgcorpus/gutenberg.git` + +2. `cd gutenberg` + +3. `pip install -r requirements.txt` + +4. `python get_data.py` + +5. `cd ..` + +  +### 2) Prepare the dataset + +Next, run the `prepare_dataset.py` script, which concatenates the (as of this writing, 60,173) text files into fewer larger files so that they can be more efficiently transferred and accessed: + +``` +python prepare_dataset.py \ + --data_dir "gutenberg/data" \ + --max_size_mb 500 \ + --output_dir "gutenberg_preprocessed" +``` + +> [!TIP] +> Note that the produced files are stored in plaintext format and are not pre-tokenized for simplicity. However, you may want to update the code to store the dataset in a pre-tokenized form to save computation time if you are planning to use the dataset more often or train for multiple epochs. See the *Design Decisions and Improvements* at the bottom of this page for more information. + +> [!TIP] +> You can choose smaller file sizes, for example, 50 MB. This will result in more files but might be useful for quicker pretraining runs on a small number of files for testing purposes. + + +  +### 3) Run the pretraining script + +You can run the pretraining script as follows.
Note that the additional command line arguments (apart from `--data_dir`) are shown with their default values for illustration purposes: + +```bash +python pretraining_simple.py \ + --data_dir "gutenberg_preprocessed" \ + --n_epochs 1 \ + --batch_size 4 \ + --output_dir model_checkpoints +``` + +The output will be formatted in the following way: + +``` +Total files: 3 +Tokenizing file 1 of 3: data_small/combined_1.txt +Training ... +Ep 1 (Step 0): Train loss 9.694, Val loss 9.724 +Ep 1 (Step 100): Train loss 6.672, Val loss 6.683 +Ep 1 (Step 200): Train loss 6.543, Val loss 6.434 +Ep 1 (Step 300): Train loss 5.772, Val loss 6.313 +Ep 1 (Step 400): Train loss 5.547, Val loss 6.249 +Ep 1 (Step 500): Train loss 6.182, Val loss 6.155 +Ep 1 (Step 600): Train loss 5.742, Val loss 6.122 +Ep 1 (Step 700): Train loss 6.309, Val loss 5.984 +Ep 1 (Step 800): Train loss 5.435, Val loss 5.975 +Ep 1 (Step 900): Train loss 5.582, Val loss 5.935 +... +Ep 1 (Step 31900): Train loss 3.664, Val loss 3.946 +Ep 1 (Step 32000): Train loss 3.493, Val loss 3.939 +Ep 1 (Step 32100): Train loss 3.940, Val loss 3.961 +Saved model_checkpoints/model_pg_32188.pth +Book processed 3h 46m 55s +Total time elapsed 3h 46m 55s +ETA for remaining books: 7h 33m 50s +Tokenizing file 2 of 3: data_small/combined_2.txt +Training ... +Ep 1 (Step 32200): Train loss 2.982, Val loss 4.094 +Ep 1 (Step 32300): Train loss 3.920, Val loss 4.097 +... +``` + + +  +> [!TIP] +> In practice, if you are using macOS or Linux, I recommend using the `tee` command to save the log outputs to a `log.txt` file in addition to printing them on the terminal: + +```bash +python -u pretraining_simple.py | tee log.txt +``` + +  +> [!WARNING] +> Note that training on 1 of the ~500 MB text files in the `gutenberg_preprocessed` folder will take approximately 4 hours on a V100 GPU. +> The folder contains 47 files and will take approximately 200 hours (more than 1 week) to complete. You may want to run it on a smaller number of files.
+ + +  +## Design Decisions and Improvements + +Note that this code focuses on keeping things simple and minimal for educational purposes. The code could be extended in the following ways to improve modeling performance and training efficiency: + +1. Modify the `prepare_dataset.py` script to strip the Gutenberg boilerplate text from each book file. +2. Update the data preparation and loading utilities to pre-tokenize the dataset and save it in a tokenized form so that it doesn't have to be re-tokenized each time when calling the pretraining script. +3. Update the `train_model_simple` function by adding the features introduced in [Appendix D: Adding Bells and Whistles to the Training Loop](../../appendix-D/01_main-chapter-code/appendix-D.ipynb), namely, cosine decay, linear warmup, and gradient clipping. +4. Update the pretraining script to save the optimizer state (see section *5.4 Loading and saving weights in PyTorch* in chapter 5; [ch05.ipynb](../../ch05/01_main-chapter-code/ch05.ipynb)) and add the option to load an existing model and optimizer checkpoint and continue training if the training run was interrupted. +5. Add a more advanced logger (for example, Weights and Biases) to view the loss and validation curves live. +6. Add distributed data parallelism (DDP) and train the model on multiple GPUs (see section *A.9.3 Training with multiple GPUs* in appendix A; [DDP-script.py](../../appendix-A/03_main-chapter-code/DDP-script.py)). +7. Swap the from-scratch `MultiHeadAttention` class in the `previous_chapters.py` script with the efficient `MHAPyTorchScaledDotProduct` class implemented in the [Efficient Multi-Head Attention Implementations](../../ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb) bonus section, which uses Flash Attention via PyTorch's `nn.functional.scaled_dot_product_attention` function.
+ diff --git a/ch05/03_bonus_pretraining_on_gutenberg/prepare_dataset.py b/ch05/03_bonus_pretraining_on_gutenberg/prepare_dataset.py new file mode 100644 index 0000000..5548b58 --- /dev/null +++ b/ch05/03_bonus_pretraining_on_gutenberg/prepare_dataset.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +""" +Script that processes the Project Gutenberg files into fewer larger files. +""" + +import argparse +import os + + +def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftext|>", fallback_encoding="latin1"): + if not os.path.exists(target_dir): + os.makedirs(target_dir) + + current_content = [] + current_size = 0 + file_counter = 1 + + for file_path in file_paths: + try: + with open(file_path, "r", encoding="utf-8") as file: + content = file.read() + except UnicodeDecodeError: + # Attempt to read the file with a fallback encoding + print(f"Warning: UnicodeDecodeError encountered. Trying fallback encoding for {file_path}") + with open(file_path, "r", encoding=fallback_encoding) as file: + content = file.read() + + estimated_size = len(content.encode("utf-8")) + + if current_size + estimated_size > max_size_mb * 1024 * 1024: + target_file_path = os.path.join(target_dir, f"combined_{file_counter}.txt") + with open(target_file_path, "w", encoding="utf-8") as target_file: + target_file.write(separator.join(current_content)) + file_counter += 1 + current_content = [content] + current_size = estimated_size + else: + current_content.append(content) + current_size += estimated_size + + if current_content: + target_file_path = os.path.join(target_dir, f"combined_{file_counter}.txt") + with open(target_file_path, "w", encoding="utf-8") as target_file: + target_file.write(separator.join(current_content)) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="GPT Model Training Configuration") + + parser.add_argument("--data_dir", type=str, default="gutenberg/data", + help="Directory containing the downloaded raw training data") + 
parser.add_argument("--max_size_mb", type=int, default=500, + help="The maximum file size for each concatenated file in megabytes") + parser.add_argument("--output_dir", type=str, default="gutenberg_preprocessed", + help="Directory where the preprocessed data will be saved") + + args = parser.parse_args() + + all_files = [os.path.join(path, name) for path, subdirs, files in os.walk(args.data_dir) + for name in files if name.endswith((".txt", ".txt.utf8")) and "raw" not in path] + + target_dir = "path_to_your_large_files" + print(f"{len(all_files)} files to process.") + + combine_files(all_files, args.output_dir) \ No newline at end of file diff --git a/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py b/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py new file mode 100644 index 0000000..8a738d5 --- /dev/null +++ b/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- +""" +Script for pretraining a small GPT-2 124M parameter model +on books from Project Gutenberg. + +Before running this script, make sure you downloaded and +processed the dataset as described in the README.md. 
+""" + +import argparse +import os +from pathlib import Path +import time +import torch +from previous_chapters import ( + create_dataloader_v1, + GPTModel, + generate_and_print_sample, + calc_loss_batch, + evaluate_model, + plot_losses +) + + +def read_text_file(file_path): + with open(file_path, "r", encoding="utf-8") as file: + text_data = file.read() + return text_data + + +def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride): + split_idx = int(train_ratio * len(text_data)) + train_loader = create_dataloader_v1( + text_data[:split_idx], + batch_size=batch_size, + max_length=max_length, + stride=stride, + drop_last=True, + shuffle=True + ) + val_loader = create_dataloader_v1( + text_data[split_idx:], + batch_size=batch_size, + max_length=max_length, + stride=stride, + drop_last=False, + shuffle=False + ) + return train_loader, val_loader + + +def convert_time(seconds): + hours, rem = divmod(seconds, 3600) + minutes, seconds = divmod(rem, 60) + return int(hours), int(minutes), int(seconds) + + +def print_eta(start_time, book_start_time, index, total_files): + book_end_time = time.time() # End time of processing this book + elapsed_time = book_end_time - book_start_time + total_elapsed_time = book_end_time - start_time + books_remaining = total_files - index + average_time_per_book = total_elapsed_time / index + eta = average_time_per_book * books_remaining + + book_h, book_m, book_s = convert_time(elapsed_time) + total_h, total_m, total_s = convert_time(total_elapsed_time) + eta_h, eta_m, eta_s = convert_time(eta) + + print(f"Book processed {book_h}h {book_m}m {book_s}s" + f"\nTotal time elapsed {total_h}h {total_m}m {total_s}s" + f"\nETA for remaining books: {eta_h}h {eta_m}m {eta_s}s") + + +def train_model_simple(model, optimizer, device, n_epochs, + eval_freq, eval_iter, print_sample_iter, start_context, + output_dir, save_ckpt_freq, + batch_size=1024, train_ratio=0.90): + + train_losses, val_losses, track_tokens_seen = [], [], [] + 
tokens_seen = 0 + global_step = -1 + start_time = time.time() + + try: + for epoch in range(n_epochs): + + # Iterate over the books in the training corpus + for index, file_path in enumerate(all_files, 1): + book_start_time = time.time() + text_data = read_text_file(file_path) + " <|endoftext|> " + print(f"Tokenizing file {index} of {total_files}: {file_path}") + + # Initialize new data loaders for each book + train_loader, val_loader = create_dataloaders( + text_data, + train_ratio=train_ratio, + batch_size=batch_size, + max_length=GPT_CONFIG_124M["ctx_len"], + stride=GPT_CONFIG_124M["ctx_len"] + ) + print(f"Training ...") + model.train() + for input_batch, target_batch in train_loader: + optimizer.zero_grad() + loss = calc_loss_batch(input_batch, target_batch, model, device) + loss.backward() + optimizer.step() + tokens_seen += input_batch.numel() + global_step += 1 + + # Optional evaluation step + if global_step % eval_freq == 0: + train_loss, val_loss = evaluate_model( + model, train_loader, val_loader, device, eval_iter) + train_losses.append(train_loss) + val_losses.append(val_loss) + track_tokens_seen.append(tokens_seen) + print(f"Ep {epoch+1} (Step {global_step}): " + f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}") + + # Generate text passage + if index % print_sample_iter == 0: + generate_and_print_sample( + model, train_loader.dataset.tokenizer, device, start_context + ) + + if global_step % save_ckpt_freq: + file_name = output_dir / f"model_pg_{global_step}.pth" + torch.save(model.state_dict(), file_name) + print(f"Saved {file_name}") + + print_eta(start_time, book_start_time, index, total_files) + + except KeyboardInterrupt: + file_name = output_dir / f"model_pg_{global_step}_interrupted.pth" + torch.save(model.state_dict(), file_name) + print(f"Saved {file_name}") + + return train_losses, val_losses, tokens_seen + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='GPT Model Training Configuration') + + 
parser.add_argument('--data_dir', type=str, default='gutenberg/data', + help='Directory containing the training data') + parser.add_argument('--output_dir', type=str, default='model_checkpoints', + help='Directory where the model checkpoints will be saved') + parser.add_argument('--n_epochs', type=int, default=1, + help='Number of epochs to train the model') + parser.add_argument('--print_sample_iter', type=int, default=500, + help='Iterations between printing sample outputs') + parser.add_argument('--eval_freq', type=int, default=100, + help='Frequency of evaluations during training') + parser.add_argument('--save_ckpt_freq', type=int, default=100_000, + help='Frequency of saving model checkpoints during training') + parser.add_argument('--lr', type=float, default=5e-4, + help='Learning rate for the optimizer') + parser.add_argument('--batch_size', type=int, default=4, + help='Batch size for training') + + args = parser.parse_args() + + GPT_CONFIG_124M = { + "vocab_size": 50257, # Vocabulary size + "ctx_len": 1024, # Context length + "emb_dim": 768, # Embedding dimension + "n_heads": 12, # Number of attention heads + "n_layers": 12, # Number of layers + "drop_rate": 0.1, # Dropout rate + "qkv_bias": False # Query-key-value bias + } + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.manual_seed(123) + model = GPTModel(GPT_CONFIG_124M) + model.to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.1) + + data_dir = args.data_dir + all_files = [os.path.join(path, name) for path, subdirs, files + in os.walk(data_dir) for name in files if name.endswith((".txt"))] + total_files = len(all_files) + + if total_files == 0: + print("No training text files found. 
Make sure you " + "selected the correct input directory") + quit() + print("Total files:", total_files) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + train_losses, val_losses, tokens_seen = train_model_simple( + model, optimizer, device, + batch_size=args.batch_size, + n_epochs=args.n_epochs, + eval_freq=args.eval_freq, + eval_iter=1, + print_sample_iter=args.print_sample_iter, + output_dir=output_dir, + save_ckpt_freq=args.save_ckpt_freq, + start_context="Every effort moves you", + ) + + epochs_tensor = torch.linspace(1, args.n_epochs, len(train_losses)) + plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir) + + torch.save(model.state_dict(), output_dir / "model_pg_final.pth") + print(f"Maximum GPU memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB") diff --git a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py new file mode 100644 index 0000000..0cd8d02 --- /dev/null +++ b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py @@ -0,0 +1,313 @@ +# This file collects all the relevant code that we covered thus far +# throughout Chapters 2-4. +# This file can be run as a standalone script. 
+ +import tiktoken +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +import matplotlib.pyplot as plt + + + +##################################### +# Chapter 2 +##################################### + +class GPTDatasetV1(Dataset): + def __init__(self, txt, tokenizer, max_length, stride): + self.tokenizer = tokenizer + self.input_ids = [] + self.target_ids = [] + + token_ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'}) + + for i in range(0, len(token_ids) - max_length, stride): + input_chunk = token_ids[i:i + max_length] + target_chunk = token_ids[i + 1: i + max_length + 1] + self.input_ids.append(torch.tensor(input_chunk)) + self.target_ids.append(torch.tensor(target_chunk)) + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + return self.input_ids[idx], self.target_ids[idx] + + +def create_dataloader_v1(txt, batch_size=4, max_length=256, + stride=128, shuffle=True, drop_last=True): + tokenizer = tiktoken.get_encoding("gpt2") + dataset = GPTDatasetV1(txt, tokenizer, max_length, stride) + dataloader = DataLoader( + dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) + + return dataloader + + +##################################### +# Chapter 3 +##################################### + +class MultiHeadAttention(nn.Module): + def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by n_heads" + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim + + self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) + self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs + self.dropout = nn.Dropout(dropout) + self.register_buffer('mask', 
torch.triu(torch.ones(block_size, block_size), diagonal=1)) + + def forward(self, x): + b, num_tokens, d_in = x.shape + + keys = self.W_key(x) # Shape: (b, num_tokens, d_out) + queries = self.W_query(x) + values = self.W_value(x) + + # We implicitly split the matrix by adding a `num_heads` dimension + # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) + keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) + values = values.view(b, num_tokens, self.num_heads, self.head_dim) + queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) + + # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) + keys = keys.transpose(1, 2) + queries = queries.transpose(1, 2) + values = values.transpose(1, 2) + + # Compute scaled dot-product attention (aka self-attention) with a causal mask + attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head + + # Original mask truncated to the number of tokens and converted to boolean + mask_bool = self.mask.bool()[:num_tokens, :num_tokens] + + # Use the mask to fill attention scores + attn_scores.masked_fill_(mask_bool, -torch.inf) + + attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) + attn_weights = self.dropout(attn_weights) + + # Shape: (b, num_tokens, num_heads, head_dim) + context_vec = (attn_weights @ values).transpose(1, 2) + + # Combine heads, where self.d_out = self.num_heads * self.head_dim + context_vec = context_vec.reshape(b, num_tokens, self.d_out) + context_vec = self.out_proj(context_vec) # optional projection + + return context_vec + + +##################################### +# Chapter 4 +##################################### + +class LayerNorm(nn.Module): + def __init__(self, emb_dim): + super().__init__() + self.eps = 1e-5 + self.scale = nn.Parameter(torch.ones(emb_dim)) + self.shift = nn.Parameter(torch.zeros(emb_dim)) + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + var = 
x.var(dim=-1, keepdim=True, unbiased=False) + norm_x = (x - mean) / torch.sqrt(var + self.eps) + return self.scale * norm_x + self.shift + + +class GELU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh( + torch.sqrt(torch.tensor(2.0 / torch.pi)) * + (x + 0.044715 * torch.pow(x, 3)) + )) + + +class FeedForward(nn.Module): + def __init__(self, cfg): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]), + GELU(), + nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]), + nn.Dropout(cfg["drop_rate"]) + ) + + def forward(self, x): + return self.layers(x) + + +class TransformerBlock(nn.Module): + def __init__(self, cfg): + super().__init__() + self.att = MultiHeadAttention( + d_in=cfg["emb_dim"], + d_out=cfg["emb_dim"], + block_size=cfg["ctx_len"], + num_heads=cfg["n_heads"], + dropout=cfg["drop_rate"], + qkv_bias=cfg["qkv_bias"]) + self.ff = FeedForward(cfg) + self.norm1 = LayerNorm(cfg["emb_dim"]) + self.norm2 = LayerNorm(cfg["emb_dim"]) + self.drop_resid = nn.Dropout(cfg["drop_rate"]) + + def forward(self, x): + # Shortcut connection for attention block + shortcut = x + x = self.norm1(x) + x = self.att(x) # Shape [batch_size, num_tokens, emb_size] + x = self.drop_resid(x) + x = x + shortcut # Add the original input back + + # Shortcut connection for feed-forward block + shortcut = x + x = self.norm2(x) + x = self.ff(x) + x = self.drop_resid(x) + x = x + shortcut # Add the original input back + + return x + + +class GPTModel(nn.Module): + def __init__(self, cfg): + super().__init__() + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) + self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"]) + self.drop_emb = nn.Dropout(cfg["drop_rate"]) + + self.trf_blocks = nn.Sequential( + *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + + self.final_norm = LayerNorm(cfg["emb_dim"]) + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], 
bias=False) + + def forward(self, in_idx): + batch_size, seq_len = in_idx.shape + tok_embeds = self.tok_emb(in_idx) + pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) + x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size] + x = self.drop_emb(x) + x = self.trf_blocks(x) + x = self.final_norm(x) + logits = self.out_head(x) + return logits + + +def generate_text_simple(model, idx, max_new_tokens, context_size): + # idx is (B, T) array of indices in the current context + for _ in range(max_new_tokens): + + # Crop current context if it exceeds the supported context size + # E.g., if LLM supports only 5 tokens, and the context size is 10 + # then only the last 5 tokens are used as context + idx_cond = idx[:, -context_size:] + + # Get the predictions + with torch.no_grad(): + logits = model(idx_cond) + + # Focus only on the last time step + # (batch, n_token, vocab_size) becomes (batch, vocab_size) + logits = logits[:, -1, :] + + # Get the idx of the vocab entry with the highest logits value + idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1) + + # Append sampled index to the running sequence + idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1) + + return idx + + +##################################### +# Chapter 5 +#################################### + + +def calc_loss_batch(input_batch, target_batch, model, device): + input_batch, target_batch = input_batch.to(device), target_batch.to(device) + + logits = model(input_batch) + logits = logits.view(-1, logits.size(-1)) + loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1)) + return loss + + +def calc_loss_loader(data_loader, model, device, num_batches=None): + total_loss, batches_seen = 0., 0. 
+ if num_batches is None: + num_batches = len(data_loader) + for i, (input_batch, target_batch) in enumerate(data_loader): + if i < num_batches: + loss = calc_loss_batch(input_batch, target_batch, model, device) + total_loss += loss.item() + batches_seen += 1 + else: + break + return total_loss / batches_seen + + +def evaluate_model(model, train_loader, val_loader, device, eval_iter): + model.eval() + with torch.no_grad(): + train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter) + val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter) + model.train() + return train_loss, val_loss + + +def generate_and_print_sample(model, tokenizer, device, start_context): + model.eval() + context_size = model.pos_emb.weight.shape[0] + encoded = text_to_token_ids(start_context, tokenizer).to(device) + with torch.no_grad(): + token_ids = generate_text_simple(model=model, idx=encoded, + max_new_tokens=50, context_size=context_size) + decoded_text = token_ids_to_text(token_ids, tokenizer) + print(decoded_text.replace("\n", " ")) # Compact print format + model.train() + + +def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, output_dir): + fig, ax1 = plt.subplots() + + # Plot training and validation loss against epochs + ax1.plot(epochs_seen, train_losses, label="Training loss") + ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss") + ax1.set_xlabel("Epochs") + ax1.set_ylabel("Loss") + ax1.legend(loc="upper right") + + # Create a second x-axis for tokens seen + ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis + ax2.plot(tokens_seen, train_losses, alpha=0) # Invisible plot for aligning ticks + ax2.set_xlabel("Tokens seen") + + fig.tight_layout() # Adjust layout to make room + plt.savefig(output_dir / "losses.pdf") + + +def text_to_token_ids(text, tokenizer): + encoded = tokenizer.encode(text) + encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension + return 
encoded_tensor + + +def token_ids_to_text(token_ids, tokenizer): + flat = token_ids.squeeze(0) # remove batch dimension + return tokenizer.decode(flat.tolist()) + + From 0d517e98b94e77ff453ec3b710d4418307b44f2a Mon Sep 17 00:00:00 2001 From: rasbt Date: Wed, 13 Mar 2024 08:37:54 -0500 Subject: [PATCH 3/3] update --- .../2.pdf | Bin 16780 -> 0 bytes .../mha-implementations-Copy1.ipynb | 850 ------- .../mha-implementations.ipynb | 1996 ++++++++--------- 3 files changed, 998 insertions(+), 1848 deletions(-) delete mode 100644 ch03/02_bonus_efficient-multihead-attention/2.pdf delete mode 100644 ch03/02_bonus_efficient-multihead-attention/mha-implementations-Copy1.ipynb diff --git a/ch03/02_bonus_efficient-multihead-attention/2.pdf b/ch03/02_bonus_efficient-multihead-attention/2.pdf deleted file mode 100644 index 06ef06cb0c6d0e3418038a49286766386157f48a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16780 zcmb`v2{={V7eB6$T=SelS4f0AU#=L~Zu4O-q_6gXzN z1Hx**(_s>k0+(Gm)eNL)n^7DoBsg}Z!Pt>PA^Ca1anLCYVdm!O?C#|X$FE-X^Cp^+ zC~#XqRznN0LJFk75xSm$1+~?q`sz^^e(*1LQ0oPN_=dRmBLVJK*hiR>$lm^bL?92u zKcH_&a&~uA^9}?Qk>C%DN6N^-F<6`o9t@9A2ZDjS$#B97d#YYu-rx)cb^Tiz(B(hc zGaz}nQrzIERf<~f9zb3=LdydvL4!o}b|yh=lPP{AM^9K_cBf^lk*GHN-3iO_MCF8s zN$>A7Z0mA}Bee!51S5GL7*?FkB25Yrm+sp0+FGCvYVk58s2nK28WzZbNE=+J|M~vJ zt6w*2iNnEY+xp1NVrByjqm7d3BrRGL+6Up3vtMt=lpmK}lt}ENLzma`>I)5VbhDXE* zqbf;~uln6T%XaO2tOZ9g&&QUFI(<6Xa_@dBPx|&d7DhzZkbC~WInVU4>wQ!It;a9p z_Dz0#c+|$uWdBv1Oxy+DQ|}i|isHjedpVXOW-^k}IhMrjLK%!CBF^j^;2-hMImMg4 znOi!aw|rCub6m)OVGE*(@rcZ==Gy@_i-}9mt9pyWo`l&mOB$T6HKBi)GR+>Nk}ov) zVEjQin+wC(%<~gd1EQBR(6NGp9@kEZMreqV5RojEM-zk37+(ugxQA5u>GGb|c6aTm zr(K=Dh+k`mM#3|HO+|SgKTf6u4-NItkGfT75A2)%t^M`n@N{NQf+O?uNr|U-E>z8x zm^f;6KNA_)exF`lVEdO=^=ML9>>&BrQLTmSuxL5W2)98F+r?*lROU9RrKcooaNP}W z5GA+xhJ`4Lj-Bv#nXu$Lm17^!95B;t|H1R<$)KtQEho|NGYba!%CZgi2vou2{1fUb z5+ayUbQV$TrN$|%#uCo2(iIn%1}<91Pu{-oWjE9zWn*xy?<=M@d#7~GSmRj9v)A9p 
zxGH`97yawy6>Z9D+Mmro6MNur+Oa5B=%&5f;76Oq?|O!0&m%2G=j;sP`}OiKO9@;# zEDeX5jUS0vSg33sYJJkzboe{I7GaLrGNg`#4hi%yU z5cl%PcGiJ;B!j!X?i$H<`{>o)Ku% zKo-3#wYexOAfF?i_nCgXI?LAGT@C5tbM*CrT9@*-bRD`9{94x@{rrxAae}&i&IsRJ z5|3}Izf&{zL#)AkUJ?fl^OB?8_PKk(CZXdn%_k9pPEXTS+oN7N9;viR9ml#_+!blg zxLd3t>5nzqc4A4E+tzMc>}BV@=k-L(3BC(T{dcT{I$kX8{xLQrack%AM%oYZoyrP^ zS^LMnPBe`y#!MiTs#p)*{rNjnI;@LLpJ{e)yZiYqTem#l^_+Y0{<~(YVA>xg*+~Za zYNw9{-fXFtD^)C3q3O=$jDPE9MX$i5ws_9~R!RHWPQyp5!r!v2MEFII#;F<%e_y7- zD9l&l>_SubC-xF;QDgbgE5|jgi8r<=aU;b#VtaSyTM_ixX?hMV`0v)Igb^h&%=dt-AJQ*9EyV6bgs?e>A~(l)Z(Q~< zGcYwOI?8;wP*zZ3QmpeAwlzp*`!_btrKX_HJu6QeZJ+f!ZN*q=~l_7-rWBFo=# zYSlHPF}MxzXrk*ygW6gAVK3lseXYwlW8XG?J4H(CUt46!q zX}3=ydlV-btChF;60cZabJBS~WO}#0@5@EWHtpv_Qn!r;di93#T&H#I0uJr>bhqE; z&n-)z$huRZj$wY+TPYd&M6Vy`C8AObgyp0|cE;qIcd?dp>^|ut8rYU&G?S05`=Q$N z{qAv9ZLi(VlDCbuBy&o!kIv>U#JTM*o>>ab`F^+lH;uZXc{LTCQ3+>q8<ov;=o)qFtGA0jcSYgCo zL=@OBv?CictXitdWi7Ay)$DZx-@s$lz+?#<;7Jx|y26vo<=y|`iR9+a>hdPUdiSF(Drv>V5}ZuCn0lBT z?~gm&tcW4>($O)AGF0IiTx?k_w^(kcd7dJ`Ooc57eymZI9QJ>+g^N?=Y(Ua!&KkYm zDy%wqrDsZQ=AN`}&Ff>~<$3+xuDm#PVsk8fB9E`-cC2Q&gl?d~6}6~9`5)DGC*5!1 z0?KP;C`BEx*nJ`+p+9N66}C%J;SGX{Yj`6d|2J<+@$Hb?P@gflzxB1*=AY&V*jj)0 zojwtCpUt-A4kIt~&iogVskGi<72^I52M^KmsPXbUkH1vli`8@Cy_dDSnCq#i?<Td#1t$PuaF- zQo^~3?FbD^%5-W}^P`lwSvTVT2x+@Fle06Ey+PN$A~|Nh(hjRCJ5v!E0_1B%M&mZ* zDD}NGP+MP4aN9W)sLY8@WFGF-Jig@V`Sdm;?@nU!3o)}@Zwjjl)AXua)8qZ^(1s_Q z(MO#5llDIJ(Dg{3?Jnco){{Lt;$SybN@JaQUA=3H|GP${RP%UJ+fdleWX1a~rCW5o zjz9JYyLnjcQEgbX{(ve+cJ&P%wcPZc@ro}(!y>#I(s!Kll$OqZ`mWO4ubCnbh4NME+Kw`6M59xc9^3rhxdqh+*=n zaHee)3(omP?^u5fXJ%De&xMPU>TILW^80UXy!+#XS8RH_>btp?_1M;dqIHgme1{;xmy*GCb}%tX^RCE(Gx@E&&lOPhSQXq6@GtoF<$QXLEh<+rux{L@D=8}p`I=vM3@Ih)d(oo6L!}%k!``qk zkn^=*PffG;=wCq%>GYq*_=$gX zd-S%qWi;XUt2jFSbJyI2zmAl9J(qE)Gn3bJayIfSb&ow@G;7S=Q{8x%nf<$mrn&?< z@!323#>Y%Y(mZ$r^%m%b@`rZovsw(iEHauYQyl1?-D%TD*-1tBR6;r$yMcJ1nP})o 
zFT6EkHYX49>FT(pQdfbgwZX@GES+_tdu<}$()MiT(*DG$J{^Vk}<_mNX{37aXV%d^CN%-NQDF`@3GL_=QW@NshKpM^lcR@aNiJUNLUoDQweO#lgIT2*4%Yh6(QQGR(u^Hc;T%^NbMg3UnGi~GEsv=f(N6d*RvKfJzBSa6*CE@aSpLF%JT$I`$sr{gA zM3*}>{q>=Mu=X8~Nc&hZJL}Rd9v3`HW^U2MD=BiI`^DJhk+oa=`(0kd)C<3^6fy18 zO6%$4{Pt9muFb4!cq)w3vt)t%vy{p6=L;^QfaC9U-*81wq?I<1#(0vOthOkbRO_RH z6W)2`is#ntxft%Dd^lPoeUIx%%>Da&x>={IKe&z_5OzB_bj6ccq@~x~bC1Q)Flp<& z2BT{g_9O4_Q(XY8{$w)}Hs_$CFe;b`joAQ##O)^}(FrrRAB$8*9@zF$<=&MmRrvn- zQ#DDMMhTyUTYIWQoF1eX7F=otPipGKQMT1$?GcflAh?E{gBY=b@m`vMVx=x6*DVsWG-_ArK&Ln zPBt(2)*6p~J1Lhao)bIs z#o#*_d7qm&eawCnVShJ(B6MpHCs{2^RfbErb4v0udzNu&W&R~o6Q5$`E0d)cPrb3X zlcu6CDqszb-2jMGGFFJ!5N3W)h(vr6{UYKcq*WsAGBbRQHcP_T;A9m>w3{PTn;Df- zJ|00Q$Shdeyqwm8Xz)RFk;g5sZqAAmt@rhjWg4vjTj?z7+C5GLE zpATRWK_P_DgJ*YqrZ3i_SU<$ztmrHU>2@{`Vnk4r@|2YU9u%$yVOxo!W#*(kM{ zBGfULee-cRfv0A6EC$bD824gwdX~#w5@A{^XmjDDKWFODwux<=>8_{vhn<(^8Gc9E z?4qJUDv(lk1MaO7S!IkgMq-g>NUW?WQWkHDMBz+yP~)2!Lo=%FkoN_U?c0i^7^Gqy zRGl#m9wKIDMMCUMdClzCIPZtGxcC%qqA{uG=cU3CH3*5>01z-dr^-r;+DbcfEb^iX zTqUr};sD#rC42q;fF35H)({VqEumsjfl)hNKY7VEtcmAw9m}>im2eWy*za=@X(kXS zar6uET!?SVVW}%_frs=T8#eMuALWuREK5)3_rCKneO5cwGm%BJOPWco&m(Cb9m;TG zTZXp(n0sOK&<~IDAB!}PJ+Qh|1fm8IH)5oQCc0;6fft04z*5Gzu8F_~ele>u+tc$KLdYdZDWyI_+z-m+Q;LZ`J#!;ImigANp zKeb#)yCYc6IjaqB3yFrm-+W`6R0c7`+=pE#m1qpc!x=B*xr=qjFdc!hgUmB zD&Ia2kX4udGEx-$dBM#h-68Uc{&%ZD`$*-C;>>`LNoql_4Hb5E2*C=_3etOT9e?=F zD|BIByXmK2t+R82t>>tFqV55S!2-@> zORTQD6N-*KAzqkeBK{Onq{1Z?*oeVyz15Z=;>`unUjMOV0P(zo!x_=uNqqd8C$ ztNS4TOORJUo=5(Hsd)h!3@g&FAzI8Q#_nzZ19lbUg7e(hB;u~u1Zc@}X06{=Oo zhxhH4_`5!eO_@IWa*sBvoVh!?E$EY2|67FRwdkx*jQ%ywh`I+kFWye5C2u`iM#Fl@ zxv{JyzbNg)lw#NqS4}^yv%Rr$^1?|rPYGe8PcI4CjM>RnjCzVxR9g;z-1#~iVfW)b zpZ;{jLn^aF1w^Cp8!@7UW?y;$qGzsd7DI->zLz$CWqbd-{PedT)u3%^%uf3sT(`6I zb)+<&yDg1jIi=12@zDaBSoh2dE@n7qr%Jtw*Mh|u8|mJtUz z^j<2BY=Gp#b?WRD`91S~s+ULEi>m{W^9N=V@NBmU7v${Jy8Y zIHp=nhEIq2VcVuk=8&6W;HA22NCw=Ky$Q~%QIQ$LuOXPdDY1P-kCsO4C2#9`ey~)4 zAa(-_7z{Oj(%3Cz_za{NIyOH^^MUQkxu3d)rrMQ7oC;ZLa0(~)zZXuLiTZwY@U0rN 
zIbk9pi{pyPdL&f!5}xI;Gd>`_=vRcE$YoBc3GQRa;MOml&Q()K2^I=F=(#obWV9Zb ztk_(lVD>QcNX+5)MAG=39)&GUJ02e-6g&IopY=EQm-I2{O8ss!G#_A{60FTig7ue@ z1w2ODJR7%(Yn?Cq$T}MX-!E>a-Sh!I@QFb@E;zMhug?XKA5Zj=taQKBxH86LRo*j) zKI%04HqJRcHWn}u>fn8hLuY`!B~f^rU|7HfyQR_R^LLkgPb+lY;=d*^GkP@Rc|fzh zQxXq(Q|zEb(ZmrMn?8Q|SIL;>@ykC?KdyF2uMIbtncw%kYvDJokc7(u6;)C}S|~hb z1FB3=;G`91RzGf(D>5fStK8*+mcY+?6Z<_(8+xqZ;k8_0y)VLICc~nt#D?f+%II8) zi&JYcV>h0c!j*0n%3V1&4ZTmjmMyiOO-SZx96cWrep0M+46c)S`_y@b7sWa~bvq9nLKoV} z__s=0JANrt_dj#_M2^elMb_^LV;MpN!Zl{k(~fWBO^Z3ECs>gsHS5EKucyi4Br|=U z-TG;7h)ILSC9lu3Y&F)LD%-gixhJzq(yW{|v&2T04e*Rw3`tnsDVYlpt!7~KLs?In zv1*=?xZQAqUuq!HCi}ggHEhW%c=LA0NV)0CSdEcm4YUTLznw!ox|D6F-@Vhs2b`9h z>JEtwIo~O|N!3iQMEQ>6&%?EECo=bl%MqTJ=b`FZjNa;-zOt-t?z*J2A@a_VK zzACt1Gut-i`<;C|&+6QXNvv;iOIE+WkXg#j&WR^_dDo8~ z$}VQ5}s%)4ynC;==YABA~nn^KVlyGW|_w4-aFz`*oAl_ zZq`r5TB*QGxs5!S|9newpgXzy2l(jl;MF}Ab}_Fz;zLSyf{d9GUYJJpEji3mrw{su zeoI3#*=ZZcUzpB!d{+OCF7Eqf)2}mZDLK=nqo?0*+N`$Sg$ip_a2p!4f#sQ@ajPeo zy`Qx(9I|0jpjL<>R`JHKQ|fJuN$GC-qq$#FY}49EvmylV)|xl!QAVlmU%a|I%PW58 zGj9%9GB#%~6RXhcuzTuNfz?F}G|I)LAGsAjQG7hYIzRB!SbvX1=#&!rzDwm@08X+ok!*GdK6McX=7>>7plxQ{~u@TxL&FeM=HUQ@FOMyJ}lvnF?Flg4A#i;t08!cN={iXxl3z3JS|i z>)1$}CO02^AwP!I3yY_DFudpN_k}H27Z1r&(Hj*ciNbB6`my$q}GI#^1tSHbF|={p zf_UD3M-pR_v@agef4_;IYwqd%sgGVW0l}Lbnvl!iJpcQd6iW??;Whx2ap&4ufOzUN zSz%0c-CZ{n_n)~GMOiu`c7>if1KBD*z#E=;R^7m1I)nIH%zS^DvZ%=I8eIZ=v9T=O zem4VVOm>0(Na|EQN%3|dp{$qwLPE=swI^1CzfZ)2W9Q?%cCe)qXwGX=5-?FR#q&!|iv74)?c@i6$CD{E?o zB(2jbX^AmvXN385XdZUF31fU4Vhuj3<;gKz+a0wxd@l>l>zgbY>q!G6S?4cN_zhT& zW&(i|ES~Aecf-VW?80VC+i%@pSSn_CuAMiv87ol3YuFaV?WR0cp_NB|_3gKsU=MHHO+MquC)<9P9|6ugi+)W=5Q&Bun-{o++V;0sY*Zh^m zAWZf>6+x*$I}Bz6a9-L}->(VI{E`eicnLGdFs?X7sR_KhRF#yhsm1E>4fW=9jNS=r z`0f1X4{yhHs6@UlFX@Qv(s|h;AKji1dnZXc^0&q6};r5&x6437{4 zN7?rN=#3x=k2UeXaO7!r`BBa(2jc$ea(g4T}Ru+ekbLVl98j_8#;zFwyy^TDx$a+SZC&0KCB6YuiN>n+|oQ{k%L-A%Zbb2rV}jdGr=t{|f8kfdGO;Y3&WGd4_1twy^*XRGe zvjMDu>A}aaE{QuTg`f66{(ALlZ2eE~BV2cR43G|E5(w+3`t`kHFAE#TZFkO!Rj9o> 
z_aUppv}u;NlDkW~T)T_owls{+giFaP6T!2qx9i)ae36gd#Zb$wx4C!Jh3nFUAD&|n zD#XBzCsZEQI-yNz!$kYZ?!H@N{C3BjDyYjD7B4Hv{ibA{**=`CU)U{ZXf(O0sjW%( z(m~z%OPLvZ8cFHzgDcd&oo+}{RMtkdhD_sQJoD&xPY4x?bv)#jzuvsL&L!*Ffk}r4 zLQxWDEtx8dggn03a-|jsmW}ZbAYZj3&75*yU#IVV>qH9s_+|`O{vi5;N%Fx)pO0ys zN4Gw4`@Ms=ZG^~^6rz2lf|JEvt|7W1=a5{}p|4y7)`n=}g)Av%qTB8(LGdvS-8q$o zj*ps!Y`%YGGoLyUMP(wWVAze|rez-#(#Wk2zS|4Fr8T88G8IUM!fXU^oYQBe69$W_ z?7aAf=A9G1=&4BZ(ua@>UR*?_Nb}I@uv{gXBeWNIb_(-XeZF0s7O!coY#wB4!PYtw zRr;u1c<_?cr!ZZ~)56BF8QT-3NnW=@UgBA9+(N;;h0kkW<;B*YD%gZ#q#_L!^n?M+)eRUd!3#i?_rGAl@DQ(;GXntYFWzG* z9^j_+rI@BD*q8^jM@5?7nx@R(Ont#AJARZD`H-=-g!6;xX`P}v3*l!ondtJ{I9IwB z)4~Vi1>?T0Vfz{juP`;wS5=h1X%RdpWW^(rtI>dNdN{c~?eWL!F19Dr+#AjXMr7}O zF5PTJZ@Obh?!%ec++)%=wnX2}OgjS;w20w7srvbnc3kj}ZRS?+Cm#fKea^-hCx!%s zzR+PN9G40(+4&8#2gNn{Vu$9SYN%LJ2BE2j@@;soyW-OC{a^PnG;KjOM1jm_JVLaFVn;#@V0s~ODC#X_$zBpBR`|w zf+F|5b7EYADJ{dSEV-xbh4Rm4<6mulo43!ChkC#*XCY*O^^P*%r@A z3{Ez^M)AAe?j!ZGPaiB-%2aJFwR|icr&JX;1quL%Ac}(^jQl##4k2`pSs& zAxZ_leaF-CnYfpJQPFAm<;C`$d?R@{P=DevZ_n8A&eksv%pID}_j#${pN$UJMUVYF zTRt%{zesEEU>i+Eht#m-M&9rnCg`%NgN6HS_5;462u#C-;=QXiZZr4F)EWFk7suf~ zii+L7qB0cc$+Pi$&kKw-RKkg)Pvn+b?csHWiGC+W4{d&Qlt)3EMd;)Ttmq zxecs)p#6OzNM?=W`(Ja6-Wy;LhV|cT()A$@4<#Qm9aVWF&k=t)R9y9Y<{|jN)Funf zQ4`&cCBwVLnr2;>Ka?2KiDTlw`wJ9h^#pa>?D0(IedQBOlP&mRJ%a!euEU8NV$iUf z12FhrDN?)PoK78Evq{7++vx~wn$bm$HijePo5EhlA%y4h5d0ChnAjwL>QMSMSX_BT zxf;u*K3&b`@g$3J4OkoGSE{yuEBoLbE>yE`C|&?Yg+VI#4z-b&=}PDBgC!8y>ogBr zQYq+{$UU+D%HgPwOXVW^wM;>qiDx6byLVj8YCP?5eyr)KnE8=8?j{v6*_KQP*203P zg4&5ePIBAqgl_JLAFY^N;HXV<3GRNRRwL#Xa=g^ceekj`*F!-Z{X2{zw*vcgG9hQ5?yUhBQu9VcZ&(q&MD)XC)R8-&*8nvOlLIhUu7HxVKY^oGJiwNAG zx2N;>z0R{wEQIdcoxYhOocYBxxk&hpnQ~C3Om@v+U&mqWMA}Vha>$*kIz(AgLySvw z_vkMDKbYB0QG~081kWTr`Z7;SV;;qp1D&dJ3xYHhfDb>zmO)g%~ zF7X1~J@69aRr-ga7PG=fY1=}#Y_gCO=$q{Ol&Z9U_Y(>ay0-+^zKxMfbC=?L{N-m{ z?`^P8(VV@HI7$!G)+zi#8qv^j`L=I%zNnW{IsVYQ;v3k!zVLg~MQ!u6_4)y-_}45~ zvMH8RBoJzjWYTKQ(8NT;Vy_12u%m^)nWGmO3e|pOin^PlAGjG|;J8``d$3>#=pGFc 
znds;4L-F>5qaf^J3P+gvJ5iQ%H6R)Qx)?cnf{X{XwL`DmueOrqAdNzzQE+K2c~XaIUZsslvA(2?TjzMKjngG54>3;zD89fTnu=8Qpx2FQSbLK;Y)(DVWc4N(5Z zzj-5nnqUYGcNZ5D$i9HmEo|Y?h6%EdBN1eR0Qr2NvUvfz>-Ob&2j3`dy2p{eq^Gz9|=wXJssf)Cpdx#M>qolfHWKd@?zZJ2zQVQ0{HZRBRpXUFTg5bAC3ak zg8&SIDFxE`gLA;ZQ8*$HjyMMPPJ`qRgr&PP#SLf?LbreCkt|Q_pUq?c8qQx+zdTR> zhj|7vdN{h0fu=V!-n1NI;a(1|n}`b4en=%Yhgo)9DqVfCj9@@_yzkULS; z%hiJf7)O{q6(1n`paAQ|!*Rf><-kKK3mybnIlurC)S)AE z9c*C7gN!Fxa6dFA0mMAK+)5jCTowi+3xkdj6Od3)!2^SX?!^JYSH=K+0o~;~c69B-%gLXU$4Fp`Fiv@IXz=L22(0y1q4jkoxvY_K~K|}rU%SwQb5dGyz zgA@pO01AacG;!dbWu^RK4mvK|4@3zIogg642goM=lx164wF$@`{w@$- z%jaMW$Tt2INOC+92j&p)3fTaaLXZP<0@(!!UVoQAHu2{ODFw_sZ21V)|0>IKy>cDu z^G|{70-9Cm_;2#a->V{35R8_6A@J8giObPV6Am2ba!m`qyz}1_a^GNPm)rE=;8WIe z%@7VI8Ps6Fm7`Z9qX`^L`0`bAIC?o&L7pAGO2ZM3UZvp#^t#fI2w&dj2S#xQRlo891GT9#|>;Hg=uc>w!a zsd>Ty99+KY1qaS^r3M8v?5a@S@a3;ZKzJw+VpnPTK(TVU^S_Z4e2V_JWi_BKQ-=K$ zR9DIS0W)3cO9oHUY7L5{s|>8fY22!~6gXINfRSM<0T`Ic${kQl2A*n}rau&gS8D%; zPVkL-x#=jF`ITBAutTT@1ECN+Rm)eGqdb0v^1u0|5bZyaPzyYND>VPT4srD7x+>~# z{P9o5DMDQpbV3jv;V395Aym=-$8G>Ptake=d_g2vN&HI}l6PGrB>K8WNTNR+uEg?{ zOj?A#yE6ocz(Wbizno_34`S~!ldCCM!14ckvpPr>g<#(vfX{q9yeS^;PVfMXjI0b& z3hqXs_>koh2+w~m$$0y@N`Pm<&)eCb2r^IqzRSni1@7cXJOZe!jsSWt5AW{nr2%E> z?$MA(BhgqS1_iceW6>zIG!iF{M2buNrI$cIk_!xY3M>o+wbdVpSvc?Svr!`8MTk-#+8^g}|A#F{n~^uVlbgG^{`8}MWQZ3AZj4z6v3-ml%(ZFFKXTPQ&R(9vcsyucu4E06e&6JOUnMaQwFoKqY7$1O5Je7QhzU_5HA5ZMC+a?0O%8Lap~D zDC~NFfWiYjv}Qb%-1_Gd1!1i<{m>{VibKDDc}8Qv_SQ9R;C0$xHbn7rboU_nE&ta9 yGxs2{0tHJyQ*Up`nJ>pST`w1J0DYEEKnx*M9Q`QE?h}Q>qLDB$G0pv2u>TJ!#>K7x diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations-Copy1.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations-Copy1.ipynb deleted file mode 100644 index 41a4801..0000000 --- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations-Copy1.ipynb +++ /dev/null @@ -1,850 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "6f678e62-7bcb-4405-86ae-dce94f494303", - "metadata": { - "id": 
"6f678e62-7bcb-4405-86ae-dce94f494303" - }, - "source": [ - "# Efficient Multi-Head Attention Implementations" - ] - }, - { - "cell_type": "markdown", - "id": "b742938a-4bfc-4527-a1f1-d5963508967d", - "metadata": { - "id": "b742938a-4bfc-4527-a1f1-d5963508967d" - }, - "source": [ - "This code notebook compares different ways to implement causal multi-head attention used in decoder-style LLMs like GPT, Llama, etc." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7898551e-f582-48ac-9f66-3632abe2a93f", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "7898551e-f582-48ac-9f66-3632abe2a93f", - "outputId": "7d088260-3fa1-44f2-bd65-2a46e289f9d4" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "PyTorch version: 2.1.0\n", - "Running on cpu\n" - ] - } - ], - "source": [ - "import torch\n", - "\n", - "torch.manual_seed(123)\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "print(f\"PyTorch version: {torch.__version__}\")\n", - "print(f\"Running on {device}\")\n", - "\n", - "batch_size = 8\n", - "context_len = 1024\n", - "embed_dim = 768\n", - "embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)" - ] - }, - { - "cell_type": "markdown", - "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6", - "metadata": { - "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6" - }, - "source": [ - "## 1) CausalAttention MHA wrapper class from chapter 3" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", - "outputId": "f8a33752-2cd6-4101-8feb-9d1699984719" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper\n", - "\n", - 
"mha_ch03_wrapper = Ch03_MHA_Wrapper(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim//12,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_ch03_wrapper(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "21930804-b327-40b1-8e63-94dcad39ce7b", - "metadata": { - "id": "21930804-b327-40b1-8e63-94dcad39ce7b" - }, - "source": [ - "## 2) The multi-head attention class from chapter 3" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", - "outputId": "b704a040-3547-422c-ecda-df9982a2da35" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "from ch03 import MultiHeadAttention as Ch03_MHA\n", - "\n", - "mha_ch03 = Ch03_MHA(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_ch03(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4", - "metadata": { - "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4" - }, - "source": [ - "## 3) An alternative multi-head attention with combined weights" - ] - }, - { - "cell_type": "markdown", - "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd", - "metadata": { - "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd" - }, - "source": [ - "- The code for the `MultiHeadAttentionAlt` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n", - "- The main difference between the `MultiHeadAttentionAlt` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionAlt` 
uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n", - "\n", - " - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", - " - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", - " - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", - "\n", - "- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n", - "- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", - "outputId": "5d948671-176f-4633-bede-97767e36becc" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "import torch.nn as nn\n", - "\n", - "\n", - "class MultiHeadAttentionCombinedQKV(nn.Module):\n", - " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", - " super().__init__()\n", - "\n", - " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", - "\n", - " self.num_heads = num_heads\n", - " self.block_size = block_size\n", - " self.head_dim = d_out // num_heads\n", - "\n", - " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", - " self.proj = nn.Linear(d_in, d_out)\n", - " self.dropout = nn.Dropout(dropout)\n", - "\n", - " self.register_buffer(\n", - " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " batch_size, num_tokens, embed_dim = x.shape\n", - "\n", - " # (b, num_tokens, embed_dim) --> (b, 
num_tokens, 3 * embed_dim)\n", - " qkv = self.qkv(x)\n", - "\n", - " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", - " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", - "\n", - " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", - " qkv = qkv.permute(2, 0, 3, 1, 4)\n", - "\n", - " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n", - " queries, keys, values = qkv.unbind(0)\n", - "\n", - " # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", - " attn_scores = queries @ keys.transpose(-2, -1)\n", - " attn_scores = attn_scores.masked_fill(\n", - " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n", - " )\n", - "\n", - " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**-0.5, dim=-1)\n", - " attn_weights = self.dropout(attn_weights)\n", - "\n", - " # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n", - " context_vec = attn_weights @ values\n", - "\n", - " # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n", - " context_vec = context_vec.transpose(1, 2)\n", - "\n", - " # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n", - " context_vec = context_vec.reshape(batch_size, num_tokens, embed_dim)\n", - "\n", - " context_vec = self.proj(context_vec)\n", - "\n", - " return context_vec\n", - "\n", - "\n", - "mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_combined_qkv(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "48a042d3-ee78-4c29-bf63-d92fe6706632", - "metadata": { - "id": "48a042d3-ee78-4c29-bf63-d92fe6706632" - }, - "source": [ - "## 4) Multihead attention with 
PyTorch's scaled dot product attention" - ] - }, - { - "cell_type": "markdown", - "id": "f78e346f-3b85-44e6-9feb-f01131381148", - "metadata": { - "id": "f78e346f-3b85-44e6-9feb-f01131381148" - }, - "source": [ - "- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention calld [flash attention](https://arxiv.org/abs/2205.14135)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5", - "metadata": { - "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5" - }, - "outputs": [], - "source": [ - "class MHAPyTorchScaledDotProduct(nn.Module):\n", - " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", - " super().__init__()\n", - "\n", - " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", - "\n", - " self.num_heads = num_heads\n", - " self.block_size = block_size\n", - " self.head_dim = d_out // num_heads\n", - " self.d_out = d_out\n", - "\n", - " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", - " self.proj = nn.Linear(d_in, d_out)\n", - " self.dropout = dropout\n", - "\n", - " self.register_buffer(\n", - " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " batch_size, num_tokens, embed_dim = x.shape\n", - "\n", - " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", - " qkv = self.qkv(x)\n", - "\n", - " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", - " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", - "\n", - " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", - " qkv = qkv.permute(2, 0, 3, 1, 4)\n", - "\n", - " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, 
num_tokens, head_dim)\n", - " queries, keys, values = qkv.unbind(0)\n", - "\n", - " use_dropout = 0. if not self.training else self.dropout\n", - " context_vec = nn.functional.scaled_dot_product_attention(\n", - " queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n", - "\n", - " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", - " context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n", - "\n", - " return context_vec" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", - "outputId": "af9e4855-7f20-4d61-8532-4827df8dfb30" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_pytorch_scaled(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "351c318f-4835-4d74-8d58-a070222447c4", - "metadata": { - "id": "351c318f-4835-4d74-8d58-a070222447c4" - }, - "source": [ - "## 5) Using PyTorch's torch.nn.MultiheadAttention" - ] - }, - { - "cell_type": "markdown", - "id": "74a6d060-6324-48fa-a35c-cb09f2a48965", - "metadata": { - "id": "74a6d060-6324-48fa-a35c-cb09f2a48965" - }, - "source": [ - "- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "3799c7ef-3155-42c6-a829-f95656453ae0", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": 
"3799c7ef-3155-42c6-a829-f95656453ae0", - "outputId": "2a085df8-0445-4818-9978-6dc74469f568" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "import torch.nn as nn\n", - "\n", - "\n", - "class MHAPyTorchClass(nn.Module):\n", - " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False, need_weights=True):\n", - " super().__init__()\n", - "\n", - " self.block_size = block_size\n", - " self.multihead_attn = nn.MultiheadAttention(\n", - " embed_dim=d_out,\n", - " num_heads=num_heads,\n", - " dropout=dropout,\n", - " bias=qkv_bias,\n", - " add_bias_kv=qkv_bias,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.need_weights = need_weights\n", - " self.proj = nn.Linear(d_out, d_out)\n", - " self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n", - "\n", - " def forward(self, x):\n", - " batch_size, num_tokens, _ = x.shape\n", - "\n", - " # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n", - " # No need to manually adjust for num_heads; ensure it's right for the sequence\n", - " if self.block_size >= num_tokens:\n", - " attn_mask = self.mask[:num_tokens, :num_tokens]\n", - " else:\n", - " attn_mask = self.mask[:self.block_size, :self.block_size]\n", - "\n", - " # attn_mask broadcasting will handle batch_size dimension implicitly\n", - " attn_output, _ = self.multihead_attn(\n", - " x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n", - " )\n", - "\n", - " output = self.proj(attn_output)\n", - "\n", - " return output\n", - "\n", - "\n", - "mha_pytorch_class_default = MHAPyTorchClass(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_pytorch_class_default(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": 
"markdown", - "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4", - "metadata": { - "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4" - }, - "source": [ - "## 6) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`" - ] - }, - { - "cell_type": "markdown", - "id": "d2164859-31a0-4537-b4fb-27d57675ba77", - "metadata": { - "id": "d2164859-31a0-4537-b4fb-27d57675ba77" - }, - "source": [ - "- Set `need_weights` (default `True`) to need_weights=False so that MultiheadAttention uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n", - "\n", - "> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n", - " Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n", - " and achieve the best performance for MHA.\n", - " Default: ``True``." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", - "outputId": "234771f4-8a53-4478-8a9b-cf19f79a5e07" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "mha_pytorch_class_noweights = MHAPyTorchClass(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False,\n", - " need_weights=False # NEW!\n", - ").to(device)\n", - "\n", - "out = mha_pytorch_class_noweights(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "8877de71-f84f-4f6d-bc87-7552013b6301", - "metadata": { - "id": "8877de71-f84f-4f6d-bc87-7552013b6301" - }, - "source": [ - "## Quick speed comparison (M3 Macbook Air CPU)" - ] - }, - { - "cell_type": "code", - "execution_count": 
9, - "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", - "outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "194 ms ± 2.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "## 1) CausalAttention MHA wrapper class from chapter 3\n", - "%timeit mha_ch03_wrapper(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", - "outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "198 ms ± 4.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "## 2) The multi-head attention class from chapter 3\n", - "%timeit mha_ch03(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", - "outputId": "92b634f8-43f8-468f-87a1-bb774b64c212" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "234 ms ± 4.26 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "## 3) An alternative multi-head attention with combined weights\n", - "%timeit mha_combined_qkv(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", - "outputId": "80c6e314-0771-470e-b090-628984ce2d85" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "71.7 ms ± 3.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], - "source": [ - "## 4) Multihead attention with PyTorch's scaled dot product attention\n", - "%timeit mha_pytorch_scaled(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", - "outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "211 ms ± 5.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", - "%timeit mha_pytorch_class_default(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", - "metadata": { - "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", - "outputId": "2e86bdb4-7fa0-4051-b000-4a2b591060a2", - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "207 ms ± 18.3 ms per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" - ] - } - ], - "source": [ - "## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n", - "%timeit mha_pytorch_class_noweights(embeddings)" - ] - }, - { - "cell_type": "markdown", - "id": "dabc6575-0316-4640-a729-e616d5c17b73", - "metadata": { - "id": "dabc6575-0316-4640-a729-e616d5c17b73" - }, - "source": [ - "## Speed comparison (Nvidia A100 GPU) with warmup" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000", - "metadata": { - "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000" - }, - "outputs": [], - "source": [ - "# CUDA benchmark code shared by Andrei Aksionov\n", - "# and based on code from\n", - "# https://github.com/cuda-mode/lectures/blob/main/lecture1/pytorch_square.py\n", - "\n", - "import time\n", - "\n", - "def time_pytorch_function(func, *input, num_repeats = 100):\n", - " # CUDA IS ASYNC so can't use python time module\n", - " #start = torch.cuda.Event(enable_timing=True)\n", - " #end = torch.cuda.Event(enable_timing=True)\n", - " start = time.time()\n", - " # Warmup\n", - " #for _ in range(5):\n", - " # func(*input)\n", - " #torch.cuda.synchronize()\n", - "\n", - " #start.record()\n", - " for _ in range(num_repeats):\n", - " func(*input)\n", - " #torch.cuda.synchronize()\n", - " #end.record()\n", - " #torch.cuda.synchronize()\n", - " return (time.time()-start) / num_repeats" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "CDJAPZaszaqx", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 489 - }, - "id": "CDJAPZaszaqx", - "outputId": "f23e9b83-7fd6-4011-9434-0e6934cf762a" - }, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "\n", - "#embeddings_cuda = embeddings.to(torch.device(\"cuda\"))\n", - "\n", - "functions = {\n", - " \"1) MHA wrapper class\": mha_ch03_wrapper,\n", - " \"2) MHA Ch03\": mha_ch03,\n", - " \"3) MHA with combined QKV weights\": mha_combined_qkv,\n", - " \"4) MHA with PyTorch scaled_dot_product_attention\": mha_pytorch_scaled,\n", - " \"5) PyTorch MHA class defaults\": mha_pytorch_class_default,\n", - " \"6) PyTorch MHA with need_weights=False\": mha_pytorch_class_noweights\n", - "}\n", - "execution_times = [time_pytorch_function(fn, embeddings) for name,fn in functions.items()]\n", - "\n", - "\n", - "# Plotting\n", - "\n", - "# Customize further for dark mode aesthetics\n", - "plt.rcParams['figure.facecolor'] = '#121212' # Dark figure background\n", - "plt.rcParams['axes.facecolor'] = '#121212' # Dark axes background\n", - "plt.rcParams['axes.edgecolor'] = 'white' # White axes border\n", - "plt.rcParams['axes.labelcolor'] = 'white' # White labels\n", - "plt.rcParams['text.color'] = 'white' # White text\n", - "plt.rcParams['xtick.color'] = 'white' # White x ticks\n", - "plt.rcParams['ytick.color'] = 'white' # White y ticks\n", - "plt.rcParams['grid.color'] = '#444444' # Lighter grid lines for contrast\n", - "plt.rcParams['lines.linewidth'] = 2 # Thicker plot lines for visibility\n", - "plt.rcParams['lines.markersize'] = 8 # Larger markers for visibility\n", - "\n", - "fig, ax = plt.subplots()\n", - "bars = plt.bar(functions.keys(), execution_times)\n", - "\n", - "plt.ylabel('Execution time (ms)')\n", - "plt.xticks(rotation=45, ha=\"right\")\n", - "\n", - "# Calculate new ylim with a margin\n", - "max_execution_time = max(execution_times)\n", - "upper_ylim = max_execution_time + 0.2 * max_execution_time # Adding a 20% margin\n", - "\n", - "plt.ylim(0, upper_ylim) # Setting new ylim\n", - "\n", - "# Annotate bars with execution 
times\n", - "for bar in bars:\n", - " yval = bar.get_height()\n", - " plt.text(bar.get_x() + bar.get_width()/2, yval + (0.05 * upper_ylim), round(yval, 2), ha='center', va='bottom')\n", - "\n", - "\n", - "plt.tight_layout()\n", - "plt.savefig(\"2.pdf\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d3e1137b-9acc-4cc5-bcbf-0e8533839f06", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "A100", - "machine_shape": "hm", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb index 1eda8cc..d3500e1 100644 --- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb @@ -1,1000 +1,1000 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "6f678e62-7bcb-4405-86ae-dce94f494303", - "metadata": { - "id": "6f678e62-7bcb-4405-86ae-dce94f494303" - }, - "source": [ - "# Efficient Multi-Head Attention Implementations" - ] - }, - { - "cell_type": "markdown", - "id": "b742938a-4bfc-4527-a1f1-d5963508967d", - "metadata": { - "id": "b742938a-4bfc-4527-a1f1-d5963508967d" - }, - "source": [ - "This code notebook compares different ways to implement causal multi-head attention used in decoder-style LLMs like GPT, Llama, etc." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7898551e-f582-48ac-9f66-3632abe2a93f", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "7898551e-f582-48ac-9f66-3632abe2a93f", - "outputId": "7d088260-3fa1-44f2-bd65-2a46e289f9d4" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "PyTorch version: 2.2.1+cu121\n", - "Running on cuda\n" - ] - } - ], - "source": [ - "import torch\n", - "\n", - "torch.manual_seed(123)\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "print(f\"PyTorch version: {torch.__version__}\")\n", - "print(f\"Running on {device}\")\n", - "\n", - "batch_size = 8\n", - "context_len = 1024\n", - "embed_dim = 768\n", - "embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)" - ] - }, - { - "cell_type": "markdown", - "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6", - "metadata": { - "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6" - }, - "source": [ - "## 1) CausalAttention MHA wrapper class from chapter 3" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", - "outputId": "f8a33752-2cd6-4101-8feb-9d1699984719" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper\n", - "\n", - "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim//12,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_ch03_wrapper(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "21930804-b327-40b1-8e63-94dcad39ce7b", - "metadata": { - "id": 
"21930804-b327-40b1-8e63-94dcad39ce7b" - }, - "source": [ - "## 2) The multi-head attention class from chapter 3" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", - "outputId": "b704a040-3547-422c-ecda-df9982a2da35" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "from ch03 import MultiHeadAttention as Ch03_MHA\n", - "\n", - "mha_ch03 = Ch03_MHA(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_ch03(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4", - "metadata": { - "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4" - }, - "source": [ - "## 3) An alternative multi-head attention with combined weights" - ] - }, - { - "cell_type": "markdown", - "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd", - "metadata": { - "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd" - }, - "source": [ - "- The code for the `MultiHeadAttentionAlt` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n", - "- The main difference between the `MultiHeadAttentionAlt` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionAlt` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n", - "\n", - " - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", - " - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", - " - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", - "\n", - "- Here, `self.qkv` combines all three weight matrices 
`self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n", - "- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", - "outputId": "5d948671-176f-4633-bede-97767e36becc" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "import torch.nn as nn\n", - "\n", - "\n", - "class MultiHeadAttentionCombinedQKV(nn.Module):\n", - " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", - " super().__init__()\n", - "\n", - " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", - "\n", - " self.num_heads = num_heads\n", - " self.block_size = block_size\n", - " self.head_dim = d_out // num_heads\n", - "\n", - " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", - " self.proj = nn.Linear(d_in, d_out)\n", - " self.dropout = nn.Dropout(dropout)\n", - "\n", - " self.register_buffer(\n", - " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " batch_size, num_tokens, embed_dim = x.shape\n", - "\n", - " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", - " qkv = self.qkv(x)\n", - "\n", - " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", - " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", - "\n", - " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", - " qkv = qkv.permute(2, 0, 3, 1, 4)\n", - "\n", - " # 
(3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n", - " queries, keys, values = qkv.unbind(0)\n", - "\n", - " # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", - " attn_scores = queries @ keys.transpose(-2, -1)\n", - " attn_scores = attn_scores.masked_fill(\n", - " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n", - " )\n", - "\n", - " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**-0.5, dim=-1)\n", - " attn_weights = self.dropout(attn_weights)\n", - "\n", - " # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n", - " context_vec = attn_weights @ values\n", - "\n", - " # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n", - " context_vec = context_vec.transpose(1, 2)\n", - "\n", - " # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n", - " context_vec = context_vec.reshape(batch_size, num_tokens, embed_dim)\n", - "\n", - " context_vec = self.proj(context_vec)\n", - "\n", - " return context_vec\n", - "\n", - "\n", - "mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_combined_qkv(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "48a042d3-ee78-4c29-bf63-d92fe6706632", - "metadata": { - "id": "48a042d3-ee78-4c29-bf63-d92fe6706632" - }, - "source": [ - "## 4) Multihead attention with PyTorch's scaled dot product attention" - ] - }, - { - "cell_type": "markdown", - "id": "f78e346f-3b85-44e6-9feb-f01131381148", - "metadata": { - "id": "f78e346f-3b85-44e6-9feb-f01131381148" - }, - "source": [ - "- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, 
which implements a memory-optimized version of self-attention calld [flash attention](https://arxiv.org/abs/2205.14135)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5", - "metadata": { - "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5" - }, - "outputs": [], - "source": [ - "class MHAPyTorchScaledDotProduct(nn.Module):\n", - " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", - " super().__init__()\n", - "\n", - " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", - "\n", - " self.num_heads = num_heads\n", - " self.block_size = block_size\n", - " self.head_dim = d_out // num_heads\n", - " self.d_out = d_out\n", - "\n", - " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", - " self.proj = nn.Linear(d_in, d_out)\n", - " self.dropout = dropout\n", - "\n", - " self.register_buffer(\n", - " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " batch_size, num_tokens, embed_dim = x.shape\n", - "\n", - " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", - " qkv = self.qkv(x)\n", - "\n", - " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", - " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", - "\n", - " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", - " qkv = qkv.permute(2, 0, 3, 1, 4)\n", - "\n", - " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", - " queries, keys, values = qkv.unbind(0)\n", - "\n", - " use_dropout = 0. 
if not self.training else self.dropout\n", - " context_vec = nn.functional.scaled_dot_product_attention(\n", - " queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n", - "\n", - " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", - " context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n", - "\n", - " return context_vec" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", - "outputId": "af9e4855-7f20-4d61-8532-4827df8dfb30" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_pytorch_scaled(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "351c318f-4835-4d74-8d58-a070222447c4", - "metadata": { - "id": "351c318f-4835-4d74-8d58-a070222447c4" - }, - "source": [ - "## 5) Using PyTorch's torch.nn.MultiheadAttention" - ] - }, - { - "cell_type": "markdown", - "id": "74a6d060-6324-48fa-a35c-cb09f2a48965", - "metadata": { - "id": "74a6d060-6324-48fa-a35c-cb09f2a48965" - }, - "source": [ - "- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "3799c7ef-3155-42c6-a829-f95656453ae0", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "3799c7ef-3155-42c6-a829-f95656453ae0", - "outputId": "2a085df8-0445-4818-9978-6dc74469f568" - }, - "outputs": [ - { - 
"output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "import torch.nn as nn\n", - "\n", - "\n", - "class MHAPyTorchClass(nn.Module):\n", - " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False, need_weights=True):\n", - " super().__init__()\n", - "\n", - " self.block_size = block_size\n", - " self.multihead_attn = nn.MultiheadAttention(\n", - " embed_dim=d_out,\n", - " num_heads=num_heads,\n", - " dropout=dropout,\n", - " bias=qkv_bias,\n", - " add_bias_kv=qkv_bias,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.need_weights = need_weights\n", - " self.proj = nn.Linear(d_out, d_out)\n", - " self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n", - "\n", - " def forward(self, x):\n", - " batch_size, num_tokens, _ = x.shape\n", - "\n", - " # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n", - " # No need to manually adjust for num_heads; ensure it's right for the sequence\n", - " if self.block_size >= num_tokens:\n", - " attn_mask = self.mask[:num_tokens, :num_tokens]\n", - " else:\n", - " attn_mask = self.mask[:self.block_size, :self.block_size]\n", - "\n", - " # attn_mask broadcasting will handle batch_size dimension implicitly\n", - " attn_output, _ = self.multihead_attn(\n", - " x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n", - " )\n", - "\n", - " output = self.proj(attn_output)\n", - "\n", - " return output\n", - "\n", - "\n", - "mha_pytorch_class_default = MHAPyTorchClass(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_pytorch_class_default(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4", - "metadata": { - "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4" 
- }, - "source": [ - "## 6) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`" - ] - }, - { - "cell_type": "markdown", - "id": "d2164859-31a0-4537-b4fb-27d57675ba77", - "metadata": { - "id": "d2164859-31a0-4537-b4fb-27d57675ba77" - }, - "source": [ - "- Set `need_weights` (default `True`) to need_weights=False so that MultiheadAttention uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n", - "\n", - "> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n", - " Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n", - " and achieve the best performance for MHA.\n", - " Default: ``True``." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", - "outputId": "234771f4-8a53-4478-8a9b-cf19f79a5e07" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "mha_pytorch_class_noweights = MHAPyTorchClass(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False,\n", - " need_weights=False # NEW!\n", - ").to(device)\n", - "\n", - "out = mha_pytorch_class_noweights(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "8877de71-f84f-4f6d-bc87-7552013b6301", - "metadata": { - "id": "8877de71-f84f-4f6d-bc87-7552013b6301" - }, - "source": [ - "## Quick speed comparison (M3 Macbook Air CPU)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" 
- }, - "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", - "outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "200 ms ± 5.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "## 1) CausalAttention MHA wrapper class from chapter 3\n", - "%timeit mha_ch03_wrapper(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", - "outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "198 ms ± 6.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "## 2) The multi-head attention class from chapter 3\n", - "%timeit mha_ch03(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", - "outputId": "92b634f8-43f8-468f-87a1-bb774b64c212" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "236 ms ± 13.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "## 3) An alternative multi-head attention with combined weights\n", - "%timeit mha_combined_qkv(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", - "outputId": "80c6e314-0771-470e-b090-628984ce2d85" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "71.6 ms ± 3.32 ms per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" - ] - } - ], - "source": [ - "## 4) Multihead attention with PyTorch's scaled dot product attention\n", - "%timeit mha_pytorch_scaled(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", - "outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "217 ms ± 4.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", - "%timeit mha_pytorch_class_default(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", - "metadata": { - "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", - "outputId": "2e86bdb4-7fa0-4051-b000-4a2b591060a2", - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "205 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], - "source": [ - "## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n", - "%timeit mha_pytorch_class_noweights(embeddings)" - ] - }, - { - "cell_type": "markdown", - "id": "a78ff594-6cc2-496d-a302-789fa104c3c9", - "metadata": { - "id": "a78ff594-6cc2-496d-a302-789fa104c3c9" - }, - "source": [ - "## Quick speed comparison (Nvidia A100 GPU)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0", - "outputId": "e99a17e9-8139-4b04-dac8-fa1dd5027735" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "8.35 ms ± 1.44 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" - ] - } - ], - "source": [ - "## 1) CausalAttention MHA wrapper class from chapter 3\n", - "%timeit mha_ch03_wrapper(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "8686dd69-3655-40e4-a57b-a2c55532a010", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "8686dd69-3655-40e4-a57b-a2c55532a010", - "outputId": "5553b42c-b709-41a4-8a8b-be36dae408ab" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "6.59 ms ± 231 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" - ] - } - ], - "source": [ - "## 2) The multi-head attention class from chapter 3\n", - "%timeit mha_ch03(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf", - "outputId": "01b0da88-510b-4b21-919a-0a7519a55ed8" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "7.21 ms ± 716 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" - ] - } - ], - "source": [ - "## 3) An alternative multi-head attention with combined weights\n", - "%timeit mha_combined_qkv(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "1075abe2-4839-4fd6-af3e-c09bb3651e26", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "1075abe2-4839-4fd6-af3e-c09bb3651e26", - "outputId": "542706db-5041-45ca-f667-9e1bd1c2c7aa" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "2.38 ms ± 362 ns per loop (mean ± std. dev. 
of 7 runs, 1000 loops each)\n" - ] - } - ], - "source": [ - "## 4) Multihead attention with PyTorch's scaled dot product attention\n", - "%timeit mha_pytorch_scaled(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "868e3670-8edc-47bc-9e06-eb505e44dc9d", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "868e3670-8edc-47bc-9e06-eb505e44dc9d", - "outputId": "13cfc808-2b11-4041-fe67-e5a63abe4f28" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "6.67 ms ± 408 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" - ] - } - ], - "source": [ - "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", - "%timeit mha_pytorch_class_default(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "944870e6-de54-4e3b-a455-b8f21f6f92c8", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "944870e6-de54-4e3b-a455-b8f21f6f92c8", - "outputId": "c52858e7-999c-4782-adc9-731f8d69dfa6" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "4.54 ms ± 7.17 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" - ] - } - ], - "source": [ - "## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n", - "%timeit mha_pytorch_class_noweights(embeddings)" - ] - }, - { - "cell_type": "markdown", - "id": "dabc6575-0316-4640-a729-e616d5c17b73", - "metadata": { - "id": "dabc6575-0316-4640-a729-e616d5c17b73" - }, - "source": [ - "## Speed comparison (Nvidia A100 GPU) with warmup" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000", - "metadata": { - "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000" - }, - "outputs": [], - "source": [ - "# CUDA benchmark code shared by Andrei Aksionov\n", - "# and based on code from\n", - "# https://github.com/cuda-mode/lectures/blob/main/lecture1/pytorch_square.py\n", - "\n", - "def time_pytorch_function(func, *input, num_repeats = 1_000):\n", - " # CUDA IS ASYNC so can't use python time module\n", - " start = torch.cuda.Event(enable_timing=True)\n", - " end = torch.cuda.Event(enable_timing=True)\n", - "\n", - " # Warmup\n", - " for _ in range(5):\n", - " func(*input)\n", - " torch.cuda.synchronize()\n", - "\n", - " start.record()\n", - " for _ in range(num_repeats):\n", - " func(*input)\n", - " torch.cuda.synchronize()\n", - " end.record()\n", - " torch.cuda.synchronize()\n", - " return start.elapsed_time(end) / num_repeats" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "CDJAPZaszaqx", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 489 - }, - "id": "CDJAPZaszaqx", - "outputId": "f23e9b83-7fd6-4011-9434-0e6934cf762a" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "
" - ], - "image/png": "\n" - }, - "metadata": {} - } - ], - "source": [ - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "\n", - "embeddings_cuda = embeddings.to(torch.device(\"cuda\"))\n", - "\n", - "functions = {\n", - " \"1) MHA wrapper class\": mha_ch03_wrapper,\n", - " \"2) MHA Ch03\": mha_ch03,\n", - " \"3) MHA with combined QKV weights\": mha_combined_qkv,\n", - " \"4) MHA with PyTorch scaled_dot_product_attention\": mha_pytorch_scaled,\n", - " \"5) PyTorch MHA class defaults\": mha_pytorch_class_default,\n", - " \"6) PyTorch MHA with need_weights=False\": mha_pytorch_class_noweights\n", - "}\n", - "execution_times = [time_pytorch_function(fn, embeddings_cuda) for name,fn in functions.items()]\n", - "\n", - "\n", - "# Plotting\n", - "\n", - "# Customize further for dark mode aesthetics\n", - "plt.rcParams['figure.facecolor'] = '#121212' # Dark figure background\n", - "plt.rcParams['axes.facecolor'] = '#121212' # Dark axes background\n", - "plt.rcParams['axes.edgecolor'] = 'white' # White axes border\n", - "plt.rcParams['axes.labelcolor'] = 'white' # White labels\n", - "plt.rcParams['text.color'] = 'white' # White text\n", - "plt.rcParams['xtick.color'] = 'white' # White x ticks\n", - "plt.rcParams['ytick.color'] = 'white' # White y ticks\n", - "plt.rcParams['grid.color'] = '#444444' # Lighter grid lines for contrast\n", - "plt.rcParams['lines.linewidth'] = 2 # Thicker plot lines for visibility\n", - "plt.rcParams['lines.markersize'] = 8 # Larger markers for visibility\n", - "\n", - "fig, ax = plt.subplots()\n", - "bars = plt.bar(functions.keys(), execution_times)\n", - "\n", - "plt.ylabel('Execution time (ms)')\n", - "plt.xticks(rotation=45, ha=\"right\")\n", - "\n", - "# Calculate new ylim with a margin\n", - "max_execution_time = max(execution_times)\n", - "upper_ylim = max_execution_time + 0.2 * max_execution_time # Adding a 20% margin\n", - "\n", - "plt.ylim(0, upper_ylim) # Setting new ylim\n", - "\n", - "# Annotate bars with execution times\n", 
- "for bar in bars:\n", - " yval = bar.get_height()\n", - " plt.text(bar.get_x() + bar.get_width()/2, yval + (0.05 * upper_ylim), round(yval, 2), ha='center', va='bottom')\n", - "\n", - "\n", - "plt.tight_layout()\n", - "plt.savefig(\"1.pdf\")\n", - "plt.show()\n" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "A100", - "machine_shape": "hm", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - } + "cells": [ + { + "cell_type": "markdown", + "id": "6f678e62-7bcb-4405-86ae-dce94f494303", + "metadata": { + "id": "6f678e62-7bcb-4405-86ae-dce94f494303" + }, + "source": [ + "# Efficient Multi-Head Attention Implementations" + ] }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file + { + "cell_type": "markdown", + "id": "b742938a-4bfc-4527-a1f1-d5963508967d", + "metadata": { + "id": "b742938a-4bfc-4527-a1f1-d5963508967d" + }, + "source": [ + "This code notebook compares different ways to implement causal multi-head attention used in decoder-style LLMs like GPT, Llama, etc." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7898551e-f582-48ac-9f66-3632abe2a93f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7898551e-f582-48ac-9f66-3632abe2a93f", + "outputId": "7d088260-3fa1-44f2-bd65-2a46e289f9d4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PyTorch version: 2.2.1+cu121\n", + "Running on cuda\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "torch.manual_seed(123)\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"PyTorch version: {torch.__version__}\")\n", + "print(f\"Running on {device}\")\n", + "\n", + "batch_size = 8\n", + "context_len = 1024\n", + "embed_dim = 768\n", + "embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)" + ] + }, + { + "cell_type": "markdown", + "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6", + "metadata": { + "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6" + }, + "source": [ + "## 1) CausalAttention MHA wrapper class from chapter 3" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", + "outputId": "f8a33752-2cd6-4101-8feb-9d1699984719" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper\n", + "\n", + "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim//12,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_ch03_wrapper(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "21930804-b327-40b1-8e63-94dcad39ce7b", + "metadata": { + "id": 
"21930804-b327-40b1-8e63-94dcad39ce7b" + }, + "source": [ + "## 2) The multi-head attention class from chapter 3" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", + "outputId": "b704a040-3547-422c-ecda-df9982a2da35" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "from ch03 import MultiHeadAttention as Ch03_MHA\n", + "\n", + "mha_ch03 = Ch03_MHA(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_ch03(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4", + "metadata": { + "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4" + }, + "source": [ + "## 3) An alternative multi-head attention with combined weights" + ] + }, + { + "cell_type": "markdown", + "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd", + "metadata": { + "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd" + }, + "source": [ + "- The code for the `MultiHeadAttentionAlt` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n", + "- The main difference between the `MultiHeadAttentionAlt` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionAlt` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n", + "\n", + " - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + " - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + " - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + "\n", + "- Here, `self.qkv` combines all three weight matrices 
`self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n", + "- Using `queries, keys, values = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", + "outputId": "5d948671-176f-4633-bede-97767e36becc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "import torch.nn as nn\n", + "\n", + "\n", + "class MultiHeadAttentionCombinedQKV(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", + " super().__init__()\n", + "\n", + " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.block_size = block_size\n", + " self.head_dim = d_out // num_heads\n", + "\n", + " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", + " self.proj = nn.Linear(d_out, d_out) # output projection; its input is the concatenated heads (d_out features)\n", + " self.dropout = nn.Dropout(dropout)\n", + "\n", + " self.register_buffer(\n", + " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, embed_dim = x.shape\n", + "\n", + " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", + " qkv = self.qkv(x)\n", + "\n", + " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", + " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", + "\n", + " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", + " qkv = qkv.permute(2, 0, 3, 1, 4)\n", + "\n", + " # 
(3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", + " queries, keys, values = qkv.unbind(0)\n", + "\n", + " # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", + " attn_scores = queries @ keys.transpose(-2, -1)\n", + " attn_scores = attn_scores.masked_fill(\n", + " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n", + " )\n", + "\n", + " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) # scale scores by 1/sqrt(head_dim)\n", + " attn_weights = self.dropout(attn_weights)\n", + "\n", + " # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n", + " context_vec = attn_weights @ values\n", + "\n", + " # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n", + " context_vec = context_vec.transpose(1, 2)\n", + "\n", + " # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, num_heads * head_dim)\n", + " context_vec = context_vec.reshape(batch_size, num_tokens, self.num_heads * self.head_dim)\n", + "\n", + " context_vec = self.proj(context_vec)\n", + "\n", + " return context_vec\n", + "\n", + "\n", + "mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_combined_qkv(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "48a042d3-ee78-4c29-bf63-d92fe6706632", + "metadata": { + "id": "48a042d3-ee78-4c29-bf63-d92fe6706632" + }, + "source": [ + "## 4) Multihead attention with PyTorch's scaled dot product attention" + ] + }, + { + "cell_type": "markdown", + "id": "f78e346f-3b85-44e6-9feb-f01131381148", + "metadata": { + "id": "f78e346f-3b85-44e6-9feb-f01131381148" + }, + "source": [ + "- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, 
which implements a memory-optimized version of self-attention called [flash attention](https://arxiv.org/abs/2205.14135)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5", + "metadata": { + "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5" + }, + "outputs": [], + "source": [ + "class MHAPyTorchScaledDotProduct(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", + " super().__init__()\n", + "\n", + " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.block_size = block_size\n", + " self.head_dim = d_out // num_heads\n", + " self.d_out = d_out\n", + "\n", + " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", + " self.proj = nn.Linear(d_in, d_out)\n", + " self.dropout = dropout\n", + "\n", + " self.register_buffer(\n", + " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, embed_dim = x.shape\n", + "\n", + " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", + " qkv = self.qkv(x)\n", + "\n", + " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", + " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", + "\n", + " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", + " qkv = qkv.permute(2, 0, 3, 1, 4)\n", + "\n", + " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", + " queries, keys, values = qkv.unbind(0)\n", + "\n", + " use_dropout = 0. 
if not self.training else self.dropout\n", + " context_vec = nn.functional.scaled_dot_product_attention(\n", + " queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n", + "\n", + " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", + " context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n", + "\n", + " return context_vec" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", + "outputId": "af9e4855-7f20-4d61-8532-4827df8dfb30" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_pytorch_scaled(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "351c318f-4835-4d74-8d58-a070222447c4", + "metadata": { + "id": "351c318f-4835-4d74-8d58-a070222447c4" + }, + "source": [ + "## 5) Using PyTorch's torch.nn.MultiheadAttention" + ] + }, + { + "cell_type": "markdown", + "id": "74a6d060-6324-48fa-a35c-cb09f2a48965", + "metadata": { + "id": "74a6d060-6324-48fa-a35c-cb09f2a48965" + }, + "source": [ + "- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3799c7ef-3155-42c6-a829-f95656453ae0", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3799c7ef-3155-42c6-a829-f95656453ae0", + "outputId": "2a085df8-0445-4818-9978-6dc74469f568" + }, + "outputs": [ + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "import torch.nn as nn\n", + "\n", + "\n", + "class MHAPyTorchClass(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False, need_weights=True):\n", + " super().__init__()\n", + "\n", + " self.block_size = block_size\n", + " self.multihead_attn = nn.MultiheadAttention(\n", + " embed_dim=d_out,\n", + " num_heads=num_heads,\n", + " dropout=dropout,\n", + " bias=qkv_bias,\n", + " add_bias_kv=qkv_bias,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.need_weights = need_weights\n", + " self.proj = nn.Linear(d_out, d_out)\n", + " self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, _ = x.shape\n", + "\n", + " # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n", + " # No need to manually adjust for num_heads; ensure it's right for the sequence\n", + " if self.block_size >= num_tokens:\n", + " attn_mask = self.mask[:num_tokens, :num_tokens]\n", + " else:\n", + " attn_mask = self.mask[:self.block_size, :self.block_size]\n", + "\n", + " # attn_mask broadcasting will handle batch_size dimension implicitly\n", + " attn_output, _ = self.multihead_attn(\n", + " x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n", + " )\n", + "\n", + " output = self.proj(attn_output)\n", + "\n", + " return output\n", + "\n", + "\n", + "mha_pytorch_class_default = MHAPyTorchClass(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_pytorch_class_default(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4", + "metadata": { + "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4" + }, + 
"source": [ + "## 6) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`" + ] + }, + { + "cell_type": "markdown", + "id": "d2164859-31a0-4537-b4fb-27d57675ba77", + "metadata": { + "id": "d2164859-31a0-4537-b4fb-27d57675ba77" + }, + "source": [ + "- Set `need_weights` (default `True`) to need_weights=False so that MultiheadAttention uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n", + "\n", + "> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n", + " Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n", + " and achieve the best performance for MHA.\n", + " Default: ``True``." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", + "outputId": "234771f4-8a53-4478-8a9b-cf19f79a5e07" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "mha_pytorch_class_noweights = MHAPyTorchClass(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False,\n", + " need_weights=False # NEW!\n", + ").to(device)\n", + "\n", + "out = mha_pytorch_class_noweights(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "8877de71-f84f-4f6d-bc87-7552013b6301", + "metadata": { + "id": "8877de71-f84f-4f6d-bc87-7552013b6301" + }, + "source": [ + "## Quick speed comparison (M3 Macbook Air CPU)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + 
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", + "outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "200 ms ± 5.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 1) CausalAttention MHA wrapper class from chapter 3\n", + "%timeit mha_ch03_wrapper(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", + "outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "198 ms ± 6.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 2) The multi-head attention class from chapter 3\n", + "%timeit mha_ch03(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", + "outputId": "92b634f8-43f8-468f-87a1-bb774b64c212" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "236 ms ± 13.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 3) An alternative multi-head attention with combined weights\n", + "%timeit mha_combined_qkv(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", + "outputId": "80c6e314-0771-470e-b090-628984ce2d85" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "71.6 ms ± 3.32 ms per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "## 4) Multihead attention with PyTorch's scaled dot product attention\n", + "%timeit mha_pytorch_scaled(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", + "outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "217 ms ± 4.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", + "%timeit mha_pytorch_class_default(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", + "metadata": { + "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", + "outputId": "2e86bdb4-7fa0-4051-b000-4a2b591060a2", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "205 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n", + "%timeit mha_pytorch_class_noweights(embeddings)" + ] + }, + { + "cell_type": "markdown", + "id": "a78ff594-6cc2-496d-a302-789fa104c3c9", + "metadata": { + "id": "a78ff594-6cc2-496d-a302-789fa104c3c9" + }, + "source": [ + "## Quick speed comparison (Nvidia A100 GPU)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0", + "outputId": "e99a17e9-8139-4b04-dac8-fa1dd5027735" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8.35 ms ± 1.44 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "## 1) CausalAttention MHA wrapper class from chapter 3\n", + "%timeit mha_ch03_wrapper(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8686dd69-3655-40e4-a57b-a2c55532a010", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8686dd69-3655-40e4-a57b-a2c55532a010", + "outputId": "5553b42c-b709-41a4-8a8b-be36dae408ab" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6.59 ms ± 231 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "## 2) The multi-head attention class from chapter 3\n", + "%timeit mha_ch03(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf", + "outputId": "01b0da88-510b-4b21-919a-0a7519a55ed8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7.21 ms ± 716 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "## 3) An alternative multi-head attention with combined weights\n", + "%timeit mha_combined_qkv(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "1075abe2-4839-4fd6-af3e-c09bb3651e26", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1075abe2-4839-4fd6-af3e-c09bb3651e26", + "outputId": "542706db-5041-45ca-f667-9e1bd1c2c7aa" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.38 ms ± 362 ns per loop (mean ± std. dev. 
of 7 runs, 1000 loops each)\n" + ] + } + ], + "source": [ + "## 4) Multihead attention with PyTorch's scaled dot product attention\n", + "%timeit mha_pytorch_scaled(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "868e3670-8edc-47bc-9e06-eb505e44dc9d", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "868e3670-8edc-47bc-9e06-eb505e44dc9d", + "outputId": "13cfc808-2b11-4041-fe67-e5a63abe4f28" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6.67 ms ± 408 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", + "%timeit mha_pytorch_class_default(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "944870e6-de54-4e3b-a455-b8f21f6f92c8", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "944870e6-de54-4e3b-a455-b8f21f6f92c8", + "outputId": "c52858e7-999c-4782-adc9-731f8d69dfa6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.54 ms ± 7.17 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n", + "%timeit mha_pytorch_class_noweights(embeddings)" + ] + }, + { + "cell_type": "markdown", + "id": "dabc6575-0316-4640-a729-e616d5c17b73", + "metadata": { + "id": "dabc6575-0316-4640-a729-e616d5c17b73" + }, + "source": [ + "## Speed comparison (Nvidia A100 GPU) with warmup" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000", + "metadata": { + "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000" + }, + "outputs": [], + "source": [ + "# CUDA benchmark code shared by Andrei Aksionov\n", + "# and based on code from\n", + "# https://github.com/cuda-mode/lectures/blob/main/lecture1/pytorch_square.py\n", + "\n", + "def time_pytorch_function(func, *input, num_repeats = 1_000):\n", + " # CUDA IS ASYNC so can't use python time module\n", + " start = torch.cuda.Event(enable_timing=True)\n", + " end = torch.cuda.Event(enable_timing=True)\n", + "\n", + " # Warmup\n", + " for _ in range(5):\n", + " func(*input)\n", + " torch.cuda.synchronize()\n", + "\n", + " start.record()\n", + " for _ in range(num_repeats):\n", + " func(*input)\n", + " torch.cuda.synchronize()\n", + " end.record()\n", + " torch.cuda.synchronize()\n", + " return start.elapsed_time(end) / num_repeats" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "CDJAPZaszaqx", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 489 + }, + "id": "CDJAPZaszaqx", + "outputId": "f23e9b83-7fd6-4011-9434-0e6934cf762a" + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "embeddings_cuda = embeddings.to(torch.device(\"cuda\"))\n", + "\n", + "functions = {\n", + " \"1) MHA wrapper class\": mha_ch03_wrapper,\n", + " \"2) MHA Ch03\": mha_ch03,\n", + " \"3) MHA with combined QKV weights\": mha_combined_qkv,\n", + " \"4) MHA with PyTorch scaled_dot_product_attention\": mha_pytorch_scaled,\n", + " \"5) PyTorch MHA class defaults\": mha_pytorch_class_default,\n", + " \"6) PyTorch MHA with need_weights=False\": mha_pytorch_class_noweights\n", + "}\n", + "execution_times = [time_pytorch_function(fn, embeddings_cuda) for name,fn in functions.items()]\n", + "\n", + "\n", + "# Plotting\n", + "\n", + "# Customize further for dark mode aesthetics\n", + "plt.rcParams['figure.facecolor'] = '#121212' # Dark figure background\n", + "plt.rcParams['axes.facecolor'] = '#121212' # Dark axes background\n", + "plt.rcParams['axes.edgecolor'] = 'white' # White axes border\n", + "plt.rcParams['axes.labelcolor'] = 'white' # White labels\n", + "plt.rcParams['text.color'] = 'white' # White text\n", + "plt.rcParams['xtick.color'] = 'white' # White x ticks\n", + "plt.rcParams['ytick.color'] = 'white' # White y ticks\n", + "plt.rcParams['grid.color'] = '#444444' # Lighter grid lines for contrast\n", + "plt.rcParams['lines.linewidth'] = 2 # Thicker plot lines for visibility\n", + "plt.rcParams['lines.markersize'] = 8 # Larger markers for visibility\n", + "\n", + "fig, ax = plt.subplots()\n", + "bars = plt.bar(functions.keys(), execution_times)\n", + "\n", + "plt.ylabel('Execution time (ms)')\n", + "plt.xticks(rotation=45, ha=\"right\")\n", + "\n", + "# Calculate new ylim with a margin\n", + "max_execution_time = max(execution_times)\n", + "upper_ylim = max_execution_time + 0.2 * max_execution_time # Adding a 20% margin\n", + "\n", + "plt.ylim(0, upper_ylim) # Setting new ylim\n", + "\n", + "# Annotate bars with 
execution times\n", + "for bar in bars:\n", + " yval = bar.get_height()\n", + " plt.text(bar.get_x() + bar.get_width()/2, yval + (0.05 * upper_ylim), round(yval, 2), ha='center', va='bottom')\n", + "\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(\"1.pdf\")\n", + "plt.show()\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}