Show epochs as integers on x-axis (#241)

* Show epochs as integers on x-axis

* Update ch07/01_main-chapter-code/previous_chapters.py

* remove extra s

* modify exercise plots

* update chapter 7 plot

* resave ch07 for better file diff
This commit is contained in:
Sebastian Raschka 2024-06-23 07:41:25 -05:00 committed by GitHub
parent 36a29e783a
commit def84a039c
5 changed files with 88 additions and 69 deletions

View File

@ -1347,6 +1347,8 @@
], ],
"source": [ "source": [
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import MaxNLocator\n",
"\n",
"\n", "\n",
"def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):\n", "def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):\n",
" fig, ax1 = plt.subplots(figsize=(5, 3))\n", " fig, ax1 = plt.subplots(figsize=(5, 3))\n",
@ -1357,6 +1359,7 @@
" ax1.set_xlabel(\"Epochs\")\n", " ax1.set_xlabel(\"Epochs\")\n",
" ax1.set_ylabel(\"Loss\")\n", " ax1.set_ylabel(\"Loss\")\n",
" ax1.legend(loc=\"upper right\")\n", " ax1.legend(loc=\"upper right\")\n",
" ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) # only show integer labels on x-axis\n",
"\n", "\n",
" # Create a second x-axis for tokens seen\n", " # Create a second x-axis for tokens seen\n",
" ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis\n", " ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis\n",
@ -2455,7 +2458,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.6" "version": "3.11.4"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -12,7 +12,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
##################################### #####################################
# Chapter 2 # Chapter 2
@ -295,6 +295,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, output_dir):
ax1.set_xlabel("Epochs") ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss") ax1.set_ylabel("Loss")
ax1.legend(loc="upper right") ax1.legend(loc="upper right")
ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
# Create a second x-axis for tokens seen # Create a second x-axis for tokens seen
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis

File diff suppressed because one or more lines are too long

View File

@ -15,6 +15,7 @@ import time
import urllib import urllib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import tiktoken import tiktoken
import torch import torch
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
@ -280,6 +281,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, plot_name):
ax1.set_xlabel("Epochs") ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss") ax1.set_ylabel("Loss")
ax1.legend(loc="upper right") ax1.legend(loc="upper right")
ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) # only show integer labels on x-axis
# Create a second x-axis for tokens seen # Create a second x-axis for tokens seen
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis

View File

@ -9,6 +9,7 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np import numpy as np
import tiktoken import tiktoken
import torch import torch
@ -457,6 +458,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
ax1.set_xlabel("Epochs") ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss") ax1.set_ylabel("Loss")
ax1.legend(loc="upper right") ax1.legend(loc="upper right")
ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) # only show integer labels on x-axis
# Create a second x-axis for tokens seen # Create a second x-axis for tokens seen
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis