Improve DDP on Windows (#376)

* Update DDP-script.py for Windows

* Windows handling

---------

Co-authored-by: Nathan Brown <nathan@nkbrown.us>
Sebastian Raschka 2024-09-29 16:53:48 -05:00 committed by GitHub
parent 58d0ce83a4
commit 505e9a5fa5

DDP-script.py

@@ -11,6 +11,7 @@ from torch.utils.data import Dataset, DataLoader
 # NEW imports:
 import os
+import platform
 import torch.multiprocessing as mp
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -30,11 +31,19 @@ def ddp_setup(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     # any free port on the machine
     os.environ["MASTER_PORT"] = "12345"
+    if platform.system() == "Windows":
+        # Disable libuv because PyTorch for Windows isn't built with support
+        os.environ["USE_LIBUV"] = "0"
 
     # initialize process group
-    # Windows users may have to use "gloo" instead of "nccl" as backend
-    # nccl: NVIDIA Collective Communication Library
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if platform.system() == "Windows":
+        # Windows users may have to use "gloo" instead of "nccl" as backend
+        # gloo: Facebook Collective Communication Library
+        init_process_group(backend="gloo", rank=rank, world_size=world_size)
+    else:
+        # nccl: NVIDIA Collective Communication Library
+        init_process_group(backend="nccl", rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)
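
For context, below is a minimal, self-contained sketch of how the patched ddp_setup can be driven. The ddp_setup function mirrors the diff above; the main worker, the world_size choice, and the mp.spawn entry point are illustrative assumptions rather than code from this commit. On Windows (and macOS) the spawn start method re-imports the script in each worker process, so the launch code has to sit under an if __name__ == "__main__": guard.

# Illustrative sketch only: ddp_setup() follows the patched DDP-script.py;
# main() and the mp.spawn entry point are assumptions for demonstration.
import os
import platform

import torch
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group


def ddp_setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"   # any free port on the machine
    if platform.system() == "Windows":
        # Disable libuv because PyTorch for Windows isn't built with support
        os.environ["USE_LIBUV"] = "0"

    # initialize process group
    if platform.system() == "Windows":
        # gloo: Facebook Collective Communication Library
        init_process_group(backend="gloo", rank=rank, world_size=world_size)
    else:
        # nccl: NVIDIA Collective Communication Library
        init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def main(rank, world_size):
    # hypothetical worker: set up DDP, train, then tear down the process group
    ddp_setup(rank, world_size)
    # ... build the model, wrap it in DistributedDataParallel, run training ...
    destroy_process_group()


if __name__ == "__main__":
    # __main__ guard is required on Windows, where workers are spawned
    # (re-imported) rather than forked
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size,), nprocs=world_size)

The backend switch itself is the key part: the nccl backend is not available in PyTorch builds for Windows, so gloo is the practical choice there, while Linux machines with NVIDIA GPUs keep using nccl.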