mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-12-05 11:28:14 +00:00
move ex 3.3 solution outside main chapter
This commit is contained in:
parent
da33ce8054
commit
73822b8bfa
@ -1544,48 +1544,7 @@
|
|||||||
"id": "193d3d2b-2578-40ba-b791-ea2d49328e48",
|
"id": "193d3d2b-2578-40ba-b791-ea2d49328e48",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- In the implementation above, the embedding dimension is 4, because we set `d_out=2` as the embedding dimension for the key, query, and value vectors as well as the context vector. And since we have 2 attention heads, we have the output embedding dimension 2*2=4.\n",
|
"- In the implementation above, the embedding dimension is 4, because we set `d_out=2` as the embedding dimension for the key, query, and value vectors as well as the context vector. And since we have 2 attention heads, we have the output embedding dimension 2*2=4."
|
||||||
"\n",
|
|
||||||
"- If we want to have an output dimension of 2, as earlier in single-head attention, we have to change the projection dimension `d_out` to 1:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 36,
|
|
||||||
"id": "dc9a4375-068b-4b2a-aabb-a29347ca5ecd",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"tensor([[[-0.5740, 0.2216],\n",
|
|
||||||
" [-0.7320, 0.0155],\n",
|
|
||||||
" [-0.7774, -0.0546],\n",
|
|
||||||
" [-0.6979, -0.0817],\n",
|
|
||||||
" [-0.6538, -0.0957],\n",
|
|
||||||
" [-0.6424, -0.1065]],\n",
|
|
||||||
"\n",
|
|
||||||
" [[-0.5740, 0.2216],\n",
|
|
||||||
" [-0.7320, 0.0155],\n",
|
|
||||||
" [-0.7774, -0.0546],\n",
|
|
||||||
" [-0.6979, -0.0817],\n",
|
|
||||||
" [-0.6538, -0.0957],\n",
|
|
||||||
" [-0.6424, -0.1065]]], grad_fn=<CatBackward0>)\n",
|
|
||||||
"context_vecs.shape: torch.Size([2, 6, 2])\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"torch.manual_seed(123)\n",
|
|
||||||
"\n",
|
|
||||||
"d_out = 1\n",
|
|
||||||
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
|
|
||||||
"\n",
|
|
||||||
"context_vecs = mha(batch)\n",
|
|
||||||
"\n",
|
|
||||||
"print(context_vecs)\n",
|
|
||||||
"print(\"context_vecs.shape:\", context_vecs.shape)"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1865,7 +1824,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.4"
|
"version": "3.10.12"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user