mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-31 12:00:23 +00:00
Show epochs as integers on x-axis (#241)
* Show epochs as integers on x-axis * Update ch07/01_main-chapter-code/previous_chapters.py * remove extra s * modify exercise plots * update chapter 7 plot * resave ch07 for better file diff
This commit is contained in:
parent
36a29e783a
commit
def84a039c
@ -1347,6 +1347,8 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import matplotlib.pyplot as plt\n",
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"from matplotlib.ticker import MaxNLocator\n",
|
||||||
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):\n",
|
"def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):\n",
|
||||||
" fig, ax1 = plt.subplots(figsize=(5, 3))\n",
|
" fig, ax1 = plt.subplots(figsize=(5, 3))\n",
|
||||||
@ -1357,6 +1359,7 @@
|
|||||||
" ax1.set_xlabel(\"Epochs\")\n",
|
" ax1.set_xlabel(\"Epochs\")\n",
|
||||||
" ax1.set_ylabel(\"Loss\")\n",
|
" ax1.set_ylabel(\"Loss\")\n",
|
||||||
" ax1.legend(loc=\"upper right\")\n",
|
" ax1.legend(loc=\"upper right\")\n",
|
||||||
|
" ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) # only show integer labels on x-axis\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Create a second x-axis for tokens seen\n",
|
" # Create a second x-axis for tokens seen\n",
|
||||||
" ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis\n",
|
" ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis\n",
|
||||||
@ -2455,7 +2458,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.6"
|
"version": "3.11.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -12,7 +12,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.ticker import MaxNLocator
|
||||||
|
|
||||||
#####################################
|
#####################################
|
||||||
# Chapter 2
|
# Chapter 2
|
||||||
@ -295,6 +295,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, output_dir):
|
|||||||
ax1.set_xlabel("Epochs")
|
ax1.set_xlabel("Epochs")
|
||||||
ax1.set_ylabel("Loss")
|
ax1.set_ylabel("Loss")
|
||||||
ax1.legend(loc="upper right")
|
ax1.legend(loc="upper right")
|
||||||
|
ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||||
|
|
||||||
# Create a second x-axis for tokens seen
|
# Create a second x-axis for tokens seen
|
||||||
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis
|
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis
|
||||||
|
File diff suppressed because one or more lines are too long
@ -15,6 +15,7 @@ import time
|
|||||||
import urllib
|
import urllib
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.ticker import MaxNLocator
|
||||||
import tiktoken
|
import tiktoken
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
@ -280,6 +281,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, plot_name):
|
|||||||
ax1.set_xlabel("Epochs")
|
ax1.set_xlabel("Epochs")
|
||||||
ax1.set_ylabel("Loss")
|
ax1.set_ylabel("Loss")
|
||||||
ax1.legend(loc="upper right")
|
ax1.legend(loc="upper right")
|
||||||
|
ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) # only show integer labels on x-axis
|
||||||
|
|
||||||
# Create a second x-axis for tokens seen
|
# Create a second x-axis for tokens seen
|
||||||
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis
|
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.ticker import MaxNLocator
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tiktoken
|
import tiktoken
|
||||||
import torch
|
import torch
|
||||||
@ -457,6 +458,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
|
|||||||
ax1.set_xlabel("Epochs")
|
ax1.set_xlabel("Epochs")
|
||||||
ax1.set_ylabel("Loss")
|
ax1.set_ylabel("Loss")
|
||||||
ax1.legend(loc="upper right")
|
ax1.legend(loc="upper right")
|
||||||
|
ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) # only show integer labels on x-axis
|
||||||
|
|
||||||
# Create a second x-axis for tokens seen
|
# Create a second x-axis for tokens seen
|
||||||
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis
|
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis
|
||||||
|
Loading…
x
Reference in New Issue
Block a user