Improve DDP on Windows (#376)

* Update DDP-script.py for Windows

* Windows handling

---------

Co-authored-by: Nathan Brown <nathan@nkbrown.us>
This commit is contained in:
Sebastian Raschka 2024-09-29 16:53:48 -05:00 committed by GitHub
parent 58d0ce83a4
commit 505e9a5fa5

View File

@ -11,6 +11,7 @@ from torch.utils.data import Dataset, DataLoader
# NEW imports:
import os
import platform
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -30,11 +31,19 @@ def ddp_setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
# any free port on the machine
os.environ["MASTER_PORT"] = "12345"
if platform.system() == "Windows":
# Disable libuv because PyTorch for Windows isn't built with support
os.environ["USE_LIBUV"] = "0"
# initialize process group
# Windows users may have to use "gloo" instead of "nccl" as backend
# nccl: NVIDIA Collective Communication Library
init_process_group(backend="nccl", rank=rank, world_size=world_size)
if platform.system() == "Windows":
# Windows users may have to use "gloo" instead of "nccl" as backend
# gloo: Facebook Collective Communication Library
init_process_group(backend="gloo", rank=rank, world_size=world_size)
else:
# nccl: NVIDIA Collective Communication Library
init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)