diff --git a/appendix-D/01_main-chapter-code/previous_chapters.py b/appendix-D/01_main-chapter-code/previous_chapters.py
index 275eac9..4292f64 100644
--- a/appendix-D/01_main-chapter-code/previous_chapters.py
+++ b/appendix-D/01_main-chapter-code/previous_chapters.py
@@ -255,7 +255,9 @@ def calc_loss_batch(input_batch, target_batch, model, device):
 
 def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
-    if num_batches is None:
+    if len(data_loader) == 0:
+        return float("nan")
+    elif num_batches is None:
         num_batches = len(data_loader)
     else:
         num_batches = min(num_batches, len(data_loader))
diff --git a/ch05/01_main-chapter-code/ch05.ipynb b/ch05/01_main-chapter-code/ch05.ipynb
index afe1300..4b01496 100644
--- a/ch05/01_main-chapter-code/ch05.ipynb
+++ b/ch05/01_main-chapter-code/ch05.ipynb
@@ -1090,7 +1090,9 @@
     "\n",
     "def calc_loss_loader(data_loader, model, device, num_batches=None):\n",
     "    total_loss = 0.\n",
-    "    if num_batches is None:\n",
+    "    if len(data_loader) == 0:\n",
+    "        return float(\"nan\")\n",
+    "    elif num_batches is None:\n",
     "        num_batches = len(data_loader)\n",
     "    else:\n",
     "        # Reduce the number of batches to match the total number of batches in the data loader\n",
diff --git a/ch05/01_main-chapter-code/gpt_train.py b/ch05/01_main-chapter-code/gpt_train.py
index b333291..d025d76 100644
--- a/ch05/01_main-chapter-code/gpt_train.py
+++ b/ch05/01_main-chapter-code/gpt_train.py
@@ -34,7 +34,9 @@ def calc_loss_batch(input_batch, target_batch, model, device):
 
 def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
-    if num_batches is None:
+    if len(data_loader) == 0:
+        return float("nan")
+    elif num_batches is None:
         num_batches = len(data_loader)
     else:
         num_batches = min(num_batches, len(data_loader))
diff --git a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py
index b79348d..a5ded8e 100644
--- a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py
+++ b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py
@@ -249,7 +249,9 @@ def calc_loss_batch(input_batch, target_batch, model, device):
 
 def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
-    if num_batches is None:
+    if len(data_loader) == 0:
+        return float("nan")
+    elif num_batches is None:
         num_batches = len(data_loader)
     else:
         num_batches = min(num_batches, len(data_loader))
diff --git a/ch05/05_bonus_hparam_tuning/hparam_search.py b/ch05/05_bonus_hparam_tuning/hparam_search.py
index c207357..94cfb0d 100644
--- a/ch05/05_bonus_hparam_tuning/hparam_search.py
+++ b/ch05/05_bonus_hparam_tuning/hparam_search.py
@@ -26,7 +26,9 @@ HPARAM_GRID = {
 
 def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
-    if num_batches is None:
+    if len(data_loader) == 0:
+        return float("nan")
+    elif num_batches is None:
         num_batches = len(data_loader)
     else:
         num_batches = min(num_batches, len(data_loader))