diff --git a/pdelfin/train/utils.py b/pdelfin/train/utils.py index 0636d5d..5ed1b96 100644 --- a/pdelfin/train/utils.py +++ b/pdelfin/train/utils.py @@ -10,6 +10,8 @@ from logging import Logger from tempfile import TemporaryDirectory from typing import Dict, Generator, List, Optional, TypeVar +from functools import partial + import torch import torch.nn.functional as F from accelerate import Accelerator