Update ch06.ipynb (#143)

ouput -> output
This commit is contained in:
Ikko Eltociear Ashimine 2024-05-06 02:18:20 +09:00 committed by GitHub
parent dd31946b2a
commit d361cef65f

View File

@ -1542,7 +1542,7 @@
"source": [ "source": [
"def calc_loss_batch(input_batch, target_batch, model, device):\n", "def calc_loss_batch(input_batch, target_batch, model, device):\n",
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n", " input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
" logits = model(input_batch)[:, -1, :] # Logits of last ouput token\n", " logits = model(input_batch)[:, -1, :] # Logits of last output token\n",
" loss = torch.nn.functional.cross_entropy(logits, target_batch)\n", " loss = torch.nn.functional.cross_entropy(logits, target_batch)\n",
" return loss" " return loss"
] ]
@ -1665,7 +1665,7 @@
" for i, (input_batch, target_batch) in enumerate(data_loader):\n", " for i, (input_batch, target_batch) in enumerate(data_loader):\n",
" if i < num_batches:\n", " if i < num_batches:\n",
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n", " input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
" logits = model(input_batch)[:, -1, :] # Logits of last ouput token\n", " logits = model(input_batch)[:, -1, :] # Logits of last output token\n",
" predicted_labels = torch.argmax(logits, dim=-1)\n", " predicted_labels = torch.argmax(logits, dim=-1)\n",
"\n", "\n",
" num_examples += predicted_labels.shape[0]\n", " num_examples += predicted_labels.shape[0]\n",