add forward pass

rasbt 2024-01-31 08:00:19 -06:00
parent fcb13fd636
commit d261abce4c


@@ -33,7 +33,7 @@
"metadata": {},
"outputs": [],
"source": [
- "GPT_CONFIG = {\n",
+ "GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"ctx_len\": 1024, # Context length\n",
" \"emb_dim\": 768, # Embedding dimension\n",
@@ -166,7 +166,7 @@
],
"source": [
"torch.manual_seed(123)\n",
- "model = DummyGPTModel(GPT_CONFIG)\n",
+ "model = DummyGPTModel(GPT_CONFIG_124M)\n",
"\n",
"out = model(batch)\n",
"print(\"Output shape:\", out.shape)\n",
@@ -392,7 +392,7 @@
],
"source": [
"torch.manual_seed(123)\n",
- "model = GPTModel(GPT_CONFIG)\n",
+ "model = GPTModel(GPT_CONFIG_124M)\n",
"\n",
"out = model(batch)\n",
"print(\"Output shape:\", out.shape)\n",
@@ -454,13 +454,127 @@
"## 4.6 Implementing the forward pass"
]
},
{
"cell_type": "markdown",
"id": "2cf2f7ac-d531-45c0-a556-41b7d13c992e",
"metadata": {},
"source": [
"- The following `generate_text_simple` function implements greedy decoding, which is a simple and fast method to generate text\n",
"- In greedy decoding, at each step, the model chooses the word (or token) with the highest probability as its next output (the highest logit corresponds to the highest probability, so we don't have to compute the softmax function explicitly)\n",
"- In the next chapter, we will implement a more advanced `generate_text` function"
]
},
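Aside: since softmax is a monotonic transformation, the token with the highest logit is also the token with the highest softmax probability, which is why the argmax-over-logits shortcut works. A minimal, self-contained sketch with made-up logits:

    import torch

    # Softmax preserves the ordering of the logits, so taking the argmax
    # of the raw logits selects the same token as taking the argmax of
    # the softmax probabilities.
    logits = torch.tensor([[1.2, -0.5, 3.1, 0.7]])
    probas = torch.softmax(logits, dim=-1)
    print(torch.argmax(logits, dim=-1))  # tensor([2])
    print(torch.argmax(probas, dim=-1))  # tensor([2])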
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"id": "07700ec8-32e8-4775-9c13-5c43671d6728",
"metadata": {},
"outputs": [],
- "source": []
+ "source": [
"def generate_text_simple(model, idx, max_new_tokens, context_size):\n",
" # idx is (B, T) array of indices in the current context\n",
" for _ in range(max_new_tokens):\n",
" # Crop index to the last block_size tokens\n",
" idx_cond = idx[:, -context_size:]\n",
"\n",
" # Get the predictions\n",
" with torch.no_grad():\n",
" logits = model(idx_cond)\n",
" \n",
" # Focus only on the last time step\n",
" logits = logits[:, -1, :] # becomes (B, C)\n",
"\n",
" idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (B, 1)\n",
"\n",
" # Append sampled index to the running sequence\n",
" idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
"\n",
" return idx"
]
},
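The `idx[:, -context_size:]` cropping above is what keeps generation working once the running sequence grows beyond the model's supported context length; a toy illustration with made-up token IDs:

    import torch

    # Only the most recent context_size tokens are fed to the model;
    # earlier tokens are dropped from the model input.
    idx = torch.tensor([[1, 2, 3, 4, 5, 6]])  # (B=1, T=6), toy IDs
    context_size = 4
    print(idx[:, -context_size:])  # tensor([[3, 4, 5, 6]])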
{
"cell_type": "code",
"execution_count": 14,
"id": "7b1444cd-f7b9-4348-9034-523f7bd20597",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"encoded: [17250, 11, 314, 1101]\n",
"encoded_tensor.shape: torch.Size([1, 4])\n"
]
}
],
"source": [
"start_context = \"Hi, I'm\"\n",
"\n",
"encoded = tokenizer.encode(start_context)\n",
"print(\"encoded:\", encoded)\n",
"\n",
"encoded_tensor = torch.tensor(encoded).unsqueeze(0)\n",
"print(\"encoded_tensor.shape:\", encoded_tensor.shape)"
]
},
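The `tokenizer` above is defined in an earlier part of the notebook; assuming it is tiktoken's GPT-2 BPE encoding, the cell is equivalent to the following self-contained sketch:

    import tiktoken
    import torch

    # Assumption: the notebook's tokenizer is tiktoken's GPT-2 encoding
    tokenizer = tiktoken.get_encoding("gpt2")

    encoded = tokenizer.encode("Hi, I'm")
    print(encoded)  # [17250, 11, 314, 1101]

    # unsqueeze(0) adds the batch dimension the model expects
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)
    print(encoded_tensor.shape)  # torch.Size([1, 4])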
{
"cell_type": "code",
"execution_count": 15,
"id": "1bfb9fc6-dd14-457b-b0ee-cf0e33e2bb33",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[17250, 11, 314, 1101, 19106, 34827, 44404, 13427, 42336, 17502,\n",
" 48541, 27970, 23115, 37903, 42378, 47769, 5746, 4830, 22564, 32971,\n",
" 41836, 32124, 36101, 33912, 23027, 32252, 2951, 27247, 10075, 44911,\n",
" 6399, 42214, 3226, 10305, 35112, 9234, 9466, 27622, 44047, 32440,\n",
" 23270, 31510, 15309, 35186]])\n"
]
}
],
"source": [
"model.eval() # disable dropout\n",
"\n",
"out = generate_text_simple(\n",
" model=model,\n",
" idx=encoded_tensor, \n",
" max_new_tokens=40, \n",
" context_size=GPT_CONFIG_124M[\"ctx_len\"]\n",
")\n",
"\n",
"print(out)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "4e6d40d2-a542-44da-acdc-703a703d80a8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hi, I'mshadow maiden motel acceptanceRN lobbying↑312 Lac Sioux Styarrellà iTunesluxWHAT groomingridges Ved†illon Scy eyes Wong cyber Celebration subsequ dodging OfSD liberated monster */ Ducks contestant competed Partnership 226Alex residue\n"
]
}
],
"source": [
"decoded_text = tokenizer.decode(out.squeeze(0).tolist())\n",
"print(decoded_text)"
]
},
{
"cell_type": "markdown",
"id": "847888d2-0f92-4840-8a87-8836b529cee5",
"metadata": {},
"source": [
"- Note that the model is untrained hence the random output texts above\n",
"- We will train the model in the next chapter"
]
}
],
"metadata": {