diff --git a/ch04/01_main-chapter-code/ch04.ipynb b/ch04/01_main-chapter-code/ch04.ipynb index 7111d55..696ee2b 100644 --- a/ch04/01_main-chapter-code/ch04.ipynb +++ b/ch04/01_main-chapter-code/ch04.ipynb @@ -1315,7 +1315,7 @@ "outputs": [], "source": [ "def generate_text_simple(model, idx, max_new_tokens, context_size):\n", - " # idx is (B, T) array of indices in the current context\n", + " # idx is (batch, n_tokens) array of indices in the current context\n", " for _ in range(max_new_tokens):\n", " \n", " # Crop current context if it exceeds the supported context size\n", @@ -1328,7 +1328,7 @@ " logits = model(idx_cond)\n", " \n", " # Focus only on the last time step\n", - " # (batch, n_token, vocab_size) becomes (batch, vocab_size)\n", + " # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)\n", " logits = logits[:, -1, :] \n", "\n", " # Apply softmax to get probabilities\n",