Return nan if val loader is empty (#124)

This commit is contained in:
Sebastian Raschka 2024-04-20 08:02:30 -05:00 committed by GitHub
parent 7740d556a0
commit c70ddff558
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 15 additions and 5 deletions

View File

@ -255,7 +255,9 @@ def calc_loss_batch(input_batch, target_batch, model, device):
def calc_loss_loader(data_loader, model, device, num_batches=None):
total_loss = 0.
if num_batches is None:
if len(data_loader) == 0:
return float("nan")
elif num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))

View File

@ -1090,7 +1090,9 @@
"\n",
"def calc_loss_loader(data_loader, model, device, num_batches=None):\n",
" total_loss = 0.\n",
" if num_batches is None:\n",
" if len(data_loader) == 0:\n",
" return float(\"nan\")\n",
" elif num_batches is None:\n",
" num_batches = len(data_loader)\n",
" else:\n",
" # Reduce the number of batches to match the total number of batches in the data loader\n",

View File

@ -34,7 +34,9 @@ def calc_loss_batch(input_batch, target_batch, model, device):
def calc_loss_loader(data_loader, model, device, num_batches=None):
total_loss = 0.
if num_batches is None:
if len(data_loader) == 0:
return float("nan")
elif num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))

View File

@ -249,7 +249,9 @@ def calc_loss_batch(input_batch, target_batch, model, device):
def calc_loss_loader(data_loader, model, device, num_batches=None):
total_loss = 0.
if num_batches is None:
if len(data_loader) == 0:
return float("nan")
elif num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))

View File

@ -26,7 +26,9 @@ HPARAM_GRID = {
def calc_loss_loader(data_loader, model, device, num_batches=None):
total_loss = 0.
if num_batches is None:
if len(data_loader) == 0:
return float("nan")
elif num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))