Show epochs as integers on x-axis (#241)

* Show epochs as integers on x-axis

* Update ch07/01_main-chapter-code/previous_chapters.py

* remove extra s

* modify exercise plots

* update chapter 7 plot

* resave ch07 for better file diff
This commit is contained in:
Sebastian Raschka 2024-06-23 07:41:25 -05:00 committed by GitHub
parent 36a29e783a
commit def84a039c
5 changed files with 88 additions and 69 deletions

View File

@ -1347,6 +1347,8 @@
],
"source": [
"import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import MaxNLocator\n",
"\n",
"\n",
"def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):\n",
" fig, ax1 = plt.subplots(figsize=(5, 3))\n",
@ -1357,6 +1359,7 @@
" ax1.set_xlabel(\"Epochs\")\n",
" ax1.set_ylabel(\"Loss\")\n",
" ax1.legend(loc=\"upper right\")\n",
" ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) # only show integer labels on x-axis\n",
"\n",
" # Create a second x-axis for tokens seen\n",
" ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis\n",
@ -2455,7 +2458,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.11.4"
}
},
"nbformat": 4,

View File

@ -12,7 +12,7 @@ import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
#####################################
# Chapter 2
@ -295,6 +295,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, output_dir):
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss")
ax1.legend(loc="upper right")
ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
# Create a second x-axis for tokens seen
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis

File diff suppressed because one or more lines are too long

View File

@ -15,6 +15,7 @@ import time
import urllib
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader
@ -280,6 +281,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, plot_name):
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss")
ax1.legend(loc="upper right")
ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) # only show integer labels on x-axis
# Create a second x-axis for tokens seen
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis

View File

@ -9,6 +9,7 @@
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import tiktoken
import torch
@ -457,6 +458,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss")
ax1.legend(loc="upper right")
ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) # only show integer labels on x-axis
# Create a second x-axis for tokens seen
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis