diff --git a/appendix-A/01_main-chapter-code/DDP-script.py b/appendix-A/01_main-chapter-code/DDP-script.py
index 09c54b0..d9528c5 100644
--- a/appendix-A/01_main-chapter-code/DDP-script.py
+++ b/appendix-A/01_main-chapter-code/DDP-script.py
@@ -11,6 +11,7 @@ from torch.utils.data import Dataset, DataLoader
 
 # NEW imports:
 import os
+import platform
 import torch.multiprocessing as mp
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -30,11 +31,19 @@ def ddp_setup(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     # any free port on the machine
     os.environ["MASTER_PORT"] = "12345"
+    if platform.system() == "Windows":
+        # Disable libuv because PyTorch for Windows isn't built with support
+        os.environ["USE_LIBUV"] = "0"
 
     # initialize process group
-    # Windows users may have to use "gloo" instead of "nccl" as backend
-    # nccl: NVIDIA Collective Communication Library
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if platform.system() == "Windows":
+        # Windows users may have to use "gloo" instead of "nccl" as backend
+        # gloo: Facebook Collective Communication Library
+        init_process_group(backend="gloo", rank=rank, world_size=world_size)
+    else:
+        # nccl: NVIDIA Collective Communication Library
+        init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
     torch.cuda.set_device(rank)
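
For context, here is a minimal sketch of how the patched `ddp_setup` is typically driven, with one spawned process per GPU. The `main` entry point and the `nn.Linear` model are placeholders for illustration, not the script's actual contents (the real script also builds DataLoaders with a `DistributedSampler` and runs a training loop):

```python
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import destroy_process_group


def main(rank, world_size):
    # rank is supplied automatically by mp.spawn as the first argument
    ddp_setup(rank, world_size)  # the function modified in this diff

    # placeholder model; wrap in DDP and pin it to this rank's GPU
    model = nn.Linear(10, 1).to(rank)
    model = DDP(model, device_ids=[rank])

    # ... training loop would go here ...

    destroy_process_group()  # clean shutdown of the process group


if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # one process per GPU
    mp.spawn(main, args=(world_size,), nprocs=world_size)
```

Note that `torch.cuda.set_device(rank)` at the end of `ddp_setup` assumes a CUDA device is available for each rank; the gloo branch makes the backend work on Windows but still expects GPUs on that machine.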