diff --git a/ch03/01_main-chapter-code/ch03.ipynb b/ch03/01_main-chapter-code/ch03.ipynb
index 717af34..1de293c 100644
--- a/ch03/01_main-chapter-code/ch03.ipynb
+++ b/ch03/01_main-chapter-code/ch03.ipynb
@@ -1544,48 +1544,7 @@
    "id": "193d3d2b-2578-40ba-b791-ea2d49328e48",
    "metadata": {},
    "source": [
-    "- In the implementation above, the embedding dimension is 4, because we `d_out=2` as the embedding dimension for the key, query, and value vectors as well as the context vector. And since we have 2 attention heads, we have the output embedding dimension 2*2=4.\n",
-    "\n",
-    "- If we want to have an output dimension of 2, as earlier in single-head attention, we can have to change the projection dimension `d_out` to 1:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 36,
-   "id": "dc9a4375-068b-4b2a-aabb-a29347ca5ecd",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "tensor([[[-0.5740,  0.2216],\n",
-      "         [-0.7320,  0.0155],\n",
-      "         [-0.7774, -0.0546],\n",
-      "         [-0.6979, -0.0817],\n",
-      "         [-0.6538, -0.0957],\n",
-      "         [-0.6424, -0.1065]],\n",
-      "\n",
-      "        [[-0.5740,  0.2216],\n",
-      "         [-0.7320,  0.0155],\n",
-      "         [-0.7774, -0.0546],\n",
-      "         [-0.6979, -0.0817],\n",
-      "         [-0.6538, -0.0957],\n",
-      "         [-0.6424, -0.1065]]], grad_fn=<CatBackward0>)\n",
-      "context_vecs.shape: torch.Size([2, 6, 2])\n"
-     ]
-    }
-   ],
-   "source": [
-    "torch.manual_seed(123)\n",
-    "\n",
-    "d_out = 1\n",
-    "mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
-    "\n",
-    "context_vecs = mha(batch)\n",
-    "\n",
-    "print(context_vecs)\n",
-    "print(\"context_vecs.shape:\", context_vecs.shape)"
+    "- In the implementation above, the embedding dimension is 4, because we set `d_out=2` as the embedding dimension for the key, query, and value vectors as well as the context vector. And since we have 2 attention heads, the output embedding dimension is 2*2=4."
    ]
   },
   {
@@ -1865,7 +1824,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.12"
   }
  },
  "nbformat": 4,
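For context on the surviving bullet, which states that the wrapper's output embedding dimension equals the number of heads times `d_out`: the sketch below illustrates that concatenation idea in a minimal, self-contained form. It is not the notebook's actual `MultiHeadAttentionWrapper`/`CausalAttention` code, which this diff does not show; it omits the causal mask, dropout, and `block_size` argument, and the names `SimpleHead` and `HeadConcatWrapper` are illustrative only.

```python
import torch
import torch.nn as nn


class SimpleHead(nn.Module):
    """Toy single attention head that projects inputs to d_out dimensions.
    (No causal mask or dropout, unlike the notebook's CausalAttention.)"""
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)
        # Scaled dot-product attention over the token dimension
        attn_weights = torch.softmax(
            queries @ keys.transpose(-2, -1) / keys.shape[-1] ** 0.5, dim=-1
        )
        return attn_weights @ values  # shape: (batch, num_tokens, d_out)


class HeadConcatWrapper(nn.Module):
    """Runs num_heads independent heads and concatenates their outputs
    along the last dimension, so the output size is num_heads * d_out."""
    def __init__(self, d_in, d_out, num_heads):
        super().__init__()
        self.heads = nn.ModuleList(SimpleHead(d_in, d_out) for _ in range(num_heads))

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)


torch.manual_seed(123)
batch = torch.randn(2, 6, 3)  # (batch_size, num_tokens, d_in=3); shapes are illustrative
print(HeadConcatWrapper(3, 2, num_heads=2)(batch).shape)  # torch.Size([2, 6, 4]) -> 2 heads * d_out=2
print(HeadConcatWrapper(3, 1, num_heads=2)(batch).shape)  # torch.Size([2, 6, 2]) -> 2 heads * d_out=1
```

The second call mirrors what the removed cell demonstrated: with `num_heads=2`, setting `d_out=1` brings the concatenated context vectors back to 2 dimensions, matching the earlier single-head output size.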