mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-08 05:47:30 +00:00
parent
dd31946b2a
commit
d361cef65f
@ -1542,7 +1542,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"def calc_loss_batch(input_batch, target_batch, model, device):\n",
|
"def calc_loss_batch(input_batch, target_batch, model, device):\n",
|
||||||
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
|
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
|
||||||
" logits = model(input_batch)[:, -1, :] # Logits of last ouput token\n",
|
" logits = model(input_batch)[:, -1, :] # Logits of last output token\n",
|
||||||
" loss = torch.nn.functional.cross_entropy(logits, target_batch)\n",
|
" loss = torch.nn.functional.cross_entropy(logits, target_batch)\n",
|
||||||
" return loss"
|
" return loss"
|
||||||
]
|
]
|
||||||
@ -1665,7 +1665,7 @@
|
|||||||
" for i, (input_batch, target_batch) in enumerate(data_loader):\n",
|
" for i, (input_batch, target_batch) in enumerate(data_loader):\n",
|
||||||
" if i < num_batches:\n",
|
" if i < num_batches:\n",
|
||||||
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
|
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
|
||||||
" logits = model(input_batch)[:, -1, :] # Logits of last ouput token\n",
|
" logits = model(input_batch)[:, -1, :] # Logits of last output token\n",
|
||||||
" predicted_labels = torch.argmax(logits, dim=-1)\n",
|
" predicted_labels = torch.argmax(logits, dim=-1)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" num_examples += predicted_labels.shape[0]\n",
|
" num_examples += predicted_labels.shape[0]\n",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user