Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2025-10-26 23:39:53 +00:00.
Improve DDP on Windows (#376)

* Update DDP-script.py for Windows
* Windows handling

Co-authored-by: Nathan Brown <nathan@nkbrown.us>
Commit: 505e9a5fa5 (parent: 58d0ce83a4)
@ -11,6 +11,7 @@ from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
# NEW imports:
|
||||
import os
|
||||
import platform
|
||||
import torch.multiprocessing as mp
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -30,11 +31,19 @@ def ddp_setup(rank, world_size):
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
# any free port on the machine
|
||||
os.environ["MASTER_PORT"] = "12345"
|
||||
if platform.system() == "Windows":
|
||||
# Disable libuv because PyTorch for Windows isn't built with support
|
||||
os.environ["USE_LIBUV"] = "0"
|
||||
|
||||
# initialize process group
|
||||
# Windows users may have to use "gloo" instead of "nccl" as backend
|
||||
# nccl: NVIDIA Collective Communication Library
|
||||
init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
if platform.system() == "Windows":
|
||||
# Windows users may have to use "gloo" instead of "nccl" as backend
|
||||
# gloo: Facebook Collective Communication Library
|
||||
init_process_group(backend="gloo", rank=rank, world_size=world_size)
|
||||
else:
|
||||
# nccl: NVIDIA Collective Communication Library
|
||||
init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user