Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-08-15 04:01:44 +00:00)
add forward pass

parent fcb13fd636
commit d261abce4c
@@ -33,7 +33,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "GPT_CONFIG = {\n",
+    "GPT_CONFIG_124M = {\n",
     "    \"vocab_size\": 50257,  # Vocabulary size\n",
     "    \"ctx_len\": 1024,      # Context length\n",
     "    \"emb_dim\": 768,       # Embedding dimension\n",
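Note: the hunk above shows only the first three entries of the renamed config. For reference, a minimal sketch of the full `GPT_CONFIG_124M` dictionary, assuming the remaining keys follow the GPT-2 124M architecture (everything below `emb_dim` is an assumption and does not appear in this diff):

```python
# Sketch of the full config; only the first three keys appear in the hunk above.
GPT_CONFIG_124M = {
    "vocab_size": 50257,   # Vocabulary size (GPT-2 BPE)
    "ctx_len": 1024,       # Context length
    "emb_dim": 768,        # Embedding dimension
    "n_heads": 12,         # Number of attention heads (assumed)
    "n_layers": 12,        # Number of transformer blocks (assumed)
    "drop_rate": 0.1,      # Dropout rate (assumed)
    "qkv_bias": False,     # Query/key/value projection bias (assumed)
}
```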
@@ -166,7 +166,7 @@
    ],
    "source": [
     "torch.manual_seed(123)\n",
-    "model = DummyGPTModel(GPT_CONFIG)\n",
+    "model = DummyGPTModel(GPT_CONFIG_124M)\n",
     "\n",
     "out = model(batch)\n",
     "print(\"Output shape:\", out.shape)\n",
@@ -392,7 +392,7 @@
    ],
    "source": [
     "torch.manual_seed(123)\n",
-    "model = GPTModel(GPT_CONFIG)\n",
+    "model = GPTModel(GPT_CONFIG_124M)\n",
     "\n",
     "out = model(batch)\n",
     "print(\"Output shape:\", out.shape)\n",
@@ -454,13 +454,127 @@
     "## 4.6 Implementing the forward pass"
    ]
   },
   {
+   "cell_type": "markdown",
+   "id": "2cf2f7ac-d531-45c0-a556-41b7d13c992e",
+   "metadata": {},
+   "source": [
+    "- The following `generate_text_simple` function implements greedy decoding, which is a simple and fast method to generate text\n",
+    "- In greedy decoding, at each step, the model chooses the word (or token) with the highest probability as its next output (the highest logit corresponds to the highest probability, so we don't have to compute the softmax function explicitly)\n",
+    "- In the next chapter, we will implement a more advanced `generate_text` function"
+   ]
+  },
+  {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 13,
    "id": "07700ec8-32e8-4775-9c13-5c43671d6728",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "def generate_text_simple(model, idx, max_new_tokens, context_size):\n",
+    "    # idx is a (B, T) array of token indices in the current context\n",
+    "    for _ in range(max_new_tokens):\n",
+    "        # Crop the context to the last context_size tokens\n",
+    "        idx_cond = idx[:, -context_size:]\n",
+    "\n",
+    "        # Get the predictions\n",
+    "        with torch.no_grad():\n",
+    "            logits = model(idx_cond)\n",
+    "\n",
+    "        # Focus only on the last time step\n",
+    "        logits = logits[:, -1, :]  # becomes (B, C)\n",
+    "\n",
+    "        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)\n",
+    "\n",
+    "        # Append the selected index to the running sequence\n",
+    "        idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)\n",
+    "\n",
+    "    return idx"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "7b1444cd-f7b9-4348-9034-523f7bd20597",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "encoded: [17250, 11, 314, 1101]\n",
+      "encoded_tensor.shape: torch.Size([1, 4])\n"
+     ]
+    }
+   ],
+   "source": [
+    "start_context = \"Hi, I'm\"\n",
+    "\n",
+    "encoded = tokenizer.encode(start_context)\n",
+    "print(\"encoded:\", encoded)\n",
+    "\n",
+    "encoded_tensor = torch.tensor(encoded).unsqueeze(0)\n",
+    "print(\"encoded_tensor.shape:\", encoded_tensor.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "1bfb9fc6-dd14-457b-b0ee-cf0e33e2bb33",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[17250,    11,   314,  1101, 19106, 34827, 44404, 13427, 42336, 17502,\n",
+      "         48541, 27970, 23115, 37903, 42378, 47769,  5746,  4830, 22564, 32971,\n",
+      "         41836, 32124, 36101, 33912, 23027, 32252,  2951, 27247, 10075, 44911,\n",
+      "          6399, 42214,  3226, 10305, 35112,  9234,  9466, 27622, 44047, 32440,\n",
+      "         23270, 31510, 15309, 35186]])\n"
+     ]
+    }
+   ],
+   "source": [
+    "model.eval()  # disable dropout\n",
+    "\n",
+    "out = generate_text_simple(\n",
+    "    model=model,\n",
+    "    idx=encoded_tensor,\n",
+    "    max_new_tokens=40,\n",
+    "    context_size=GPT_CONFIG_124M[\"ctx_len\"]\n",
+    ")\n",
+    "\n",
+    "print(out)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "4e6d40d2-a542-44da-acdc-703a703d80a8",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Hi, I'mshadow maiden motel acceptanceRN lobbying↑312 Lac Sioux Styarrellà iTunesluxWHAT groomingridges Ved†illon Scy eyes Wong cyber Celebration subsequ dodging OfSD liberated monster */ Ducks contestant competed Partnership 226Alex residue\n"
+     ]
+    }
+   ],
+   "source": [
+    "decoded_text = tokenizer.decode(out.squeeze(0).tolist())\n",
+    "print(decoded_text)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "847888d2-0f92-4840-8a87-8836b529cee5",
+   "metadata": {},
+   "source": [
+    "- Note that the model is untrained, hence the random output text above\n",
+    "- We will train the model in the next chapter"
+   ]
+  }
  ],
 "metadata": {
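Note: the diff does not show where `tokenizer` is defined. Given the token IDs in the cell output (`[17250, 11, 314, 1101]` for "Hi, I'm"), it is presumably tiktoken's GPT-2 encoding, as used earlier in the book. A sketch of the encode/decode round trip under that assumption:

```python
import tiktoken
import torch

# Assumption: the notebook's tokenizer is tiktoken's GPT-2 BPE encoding
tokenizer = tiktoken.get_encoding("gpt2")

encoded = tokenizer.encode("Hi, I'm")
print("encoded:", encoded)  # [17250, 11, 314, 1101], matching the cell output

# Add a batch dimension, as the notebook does before calling the model
encoded_tensor = torch.tensor(encoded).unsqueeze(0)
print("encoded_tensor.shape:", encoded_tensor.shape)  # torch.Size([1, 4])

# Decoding reverses the mapping
print(tokenizer.decode(encoded))  # Hi, I'm
```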