mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-18 19:18:18 +00:00

Add torchrun bonus code (#524)
parent 83b47adf0d
commit d16863c7db

220  appendix-A/01_main-chapter-code/DDP-script-torchrun.py  Normal file
@@ -0,0 +1,220 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

# Appendix A: Introduction to PyTorch (Part 3)

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# NEW imports:
import os
import platform
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group


# NEW: function to initialize a distributed process group (1 process / GPU)
# this allows communication among processes
def ddp_setup(rank, world_size):
    """
    Arguments:
        rank: a unique process ID
        world_size: total number of processes in the group
    """
    # Only set MASTER_ADDR and MASTER_PORT if not already defined by torchrun
    if "MASTER_ADDR" not in os.environ:
        os.environ["MASTER_ADDR"] = "localhost"
    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12345"

    # initialize process group
    if platform.system() == "Windows":
        # Disable libuv because PyTorch for Windows isn't built with support
        os.environ["USE_LIBUV"] = "0"
        # Windows users may have to use "gloo" instead of "nccl" as backend
        # gloo: Facebook Collective Communication Library
        init_process_group(backend="gloo", rank=rank, world_size=world_size)
    else:
        # nccl: NVIDIA Collective Communication Library
        init_process_group(backend="nccl", rank=rank, world_size=world_size)

    torch.cuda.set_device(rank)


class ToyDataset(Dataset):
    def __init__(self, X, y):
        self.features = X
        self.labels = y

    def __getitem__(self, index):
        one_x = self.features[index]
        one_y = self.labels[index]
        return one_x, one_y

    def __len__(self):
        return self.labels.shape[0]


class NeuralNetwork(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()

        self.layers = torch.nn.Sequential(
            # 1st hidden layer
            torch.nn.Linear(num_inputs, 30),
            torch.nn.ReLU(),

            # 2nd hidden layer
            torch.nn.Linear(30, 20),
            torch.nn.ReLU(),

            # output layer
            torch.nn.Linear(20, num_outputs),
        )

    def forward(self, x):
        logits = self.layers(x)
        return logits


def prepare_dataset():
    X_train = torch.tensor([
        [-1.2, 3.1],
        [-0.9, 2.9],
        [-0.5, 2.6],
        [2.3, -1.1],
        [2.7, -1.5]
    ])
    y_train = torch.tensor([0, 0, 0, 1, 1])

    X_test = torch.tensor([
        [-0.8, 2.8],
        [2.6, -1.6],
    ])
    y_test = torch.tensor([0, 1])

    # Uncomment these lines to increase the dataset size to run this script on up to 8 GPUs:
    # factor = 4
    # X_train = torch.cat([X_train + torch.randn_like(X_train) * 0.1 for _ in range(factor)])
    # y_train = y_train.repeat(factor)
    # X_test = torch.cat([X_test + torch.randn_like(X_test) * 0.1 for _ in range(factor)])
    # y_test = y_test.repeat(factor)

    train_ds = ToyDataset(X_train, y_train)
    test_ds = ToyDataset(X_test, y_test)

    train_loader = DataLoader(
        dataset=train_ds,
        batch_size=2,
        shuffle=False,  # NEW: False because of DistributedSampler below
        pin_memory=True,
        drop_last=True,
        # NEW: chunk batches across GPUs without overlapping samples:
        sampler=DistributedSampler(train_ds)  # NEW
    )
    test_loader = DataLoader(
        dataset=test_ds,
        batch_size=2,
        shuffle=False,
    )
    return train_loader, test_loader


# NEW: wrapper
def main(rank, world_size, num_epochs):

    ddp_setup(rank, world_size)  # NEW: initialize process groups

    train_loader, test_loader = prepare_dataset()
    model = NeuralNetwork(num_inputs=2, num_outputs=2)
    model.to(rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

    model = DDP(model, device_ids=[rank])  # NEW: wrap model with DDP
    # the core model is now accessible as model.module

    for epoch in range(num_epochs):
        # NEW: Set sampler to ensure each epoch has a different shuffle order
        train_loader.sampler.set_epoch(epoch)

        model.train()
        for features, labels in train_loader:

            features, labels = features.to(rank), labels.to(rank)  # New: use rank
            logits = model(features)
            loss = F.cross_entropy(logits, labels)  # Loss function

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # LOGGING
            print(f"[GPU{rank}] Epoch: {epoch+1:03d}/{num_epochs:03d}"
                  f" | Batchsize {labels.shape[0]:03d}"
                  f" | Train/Val Loss: {loss:.2f}")

    model.eval()

    try:
        train_acc = compute_accuracy(model, train_loader, device=rank)
        print(f"[GPU{rank}] Training accuracy", train_acc)
        test_acc = compute_accuracy(model, test_loader, device=rank)
        print(f"[GPU{rank}] Test accuracy", test_acc)

    ####################################################
    # NEW (not in the book):
    except ZeroDivisionError as e:
        raise ZeroDivisionError(
            f"{e}\n\nThis script is designed for 2 GPUs. You can run it as:\n"
            "torchrun --nproc_per_node=2 DDP-script-torchrun.py\n"
            f"Or, to run it on {torch.cuda.device_count()} GPUs, uncomment the code on lines 103 to 107."
        )
    ####################################################

    destroy_process_group()  # NEW: cleanly exit distributed mode


def compute_accuracy(model, dataloader, device):
    model = model.eval()
    correct = 0.0
    total_examples = 0

    for idx, (features, labels) in enumerate(dataloader):
        features, labels = features.to(device), labels.to(device)

        with torch.no_grad():
            logits = model(features)
        predictions = torch.argmax(logits, dim=1)
        compare = labels == predictions
        correct += torch.sum(compare)
        total_examples += len(compare)
    return (correct / total_examples).item()


if __name__ == "__main__":
    # NEW: Use environment variables set by torchrun if available, otherwise default to single-process.
    if "WORLD_SIZE" in os.environ:
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        world_size = 1

    if "LOCAL_RANK" in os.environ:
        rank = int(os.environ["LOCAL_RANK"])
    elif "RANK" in os.environ:
        rank = int(os.environ["RANK"])
    else:
        rank = 0

    # Only print on rank 0 to avoid duplicate prints from each GPU process
    if rank == 0:
        print("PyTorch version:", torch.__version__)
        print("CUDA available:", torch.cuda.is_available())
        print("Number of GPUs available:", torch.cuda.device_count())

    torch.manual_seed(123)
    num_epochs = 3
    main(rank, world_size, num_epochs)
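Aside (not part of the commit above): the `sampler=DistributedSampler(train_ds)` argument in `prepare_dataset` is what keeps each GPU's batches disjoint. A minimal sketch of that partitioning, runnable on CPU because `num_replicas` and `rank` are passed explicitly instead of being read from an initialized process group:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# toy dataset with 8 samples, pretending there are 2 processes (ranks 0 and 1)
ds = TensorDataset(torch.arange(8))
for rank in range(2):
    sampler = DistributedSampler(ds, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))  # rank 0 -> [0, 2, 4, 6], rank 1 -> [1, 3, 5, 7]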
appendix-A/01_main-chapter-code/DDP-script.py

@@ -31,12 +31,11 @@ def ddp_setup(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     # any free port on the machine
     os.environ["MASTER_PORT"] = "12345"
-    if platform.system() == "Windows":
-        # Disable libuv because PyTorch for Windows isn't built with support
-        os.environ["USE_LIBUV"] = "0"
 
     # initialize process group
     if platform.system() == "Windows":
+        # Disable libuv because PyTorch for Windows isn't built with support
+        os.environ["USE_LIBUV"] = "0"
         # Windows users may have to use "gloo" instead of "nccl" as backend
         # gloo: Facebook Collective Communication Library
         init_process_group(backend="gloo", rank=rank, world_size=world_size)
@@ -99,6 +98,13 @@ def prepare_dataset():
     ])
     y_test = torch.tensor([0, 1])
 
+    # Uncomment these lines to increase the dataset size to run this script on up to 8 GPUs:
+    # factor = 4
+    # X_train = torch.cat([X_train + torch.randn_like(X_train) * 0.1 for _ in range(factor)])
+    # y_train = y_train.repeat(factor)
+    # X_test = torch.cat([X_test + torch.randn_like(X_test) * 0.1 for _ in range(factor)])
+    # y_test = y_test.repeat(factor)
+
     train_ds = ToyDataset(X_train, y_train)
     test_ds = ToyDataset(X_test, y_test)
 
@@ -153,10 +159,22 @@ def main(rank, world_size, num_epochs):
                   f" | Train/Val Loss: {loss:.2f}")
 
     model.eval()
-    train_acc = compute_accuracy(model, train_loader, device=rank)
-    print(f"[GPU{rank}] Training accuracy", train_acc)
-    test_acc = compute_accuracy(model, test_loader, device=rank)
-    print(f"[GPU{rank}] Test accuracy", test_acc)
+
+    try:
+        train_acc = compute_accuracy(model, train_loader, device=rank)
+        print(f"[GPU{rank}] Training accuracy", train_acc)
+        test_acc = compute_accuracy(model, test_loader, device=rank)
+        print(f"[GPU{rank}] Test accuracy", test_acc)
+
+    ####################################################
+    # NEW (not in the book):
+    except ZeroDivisionError as e:
+        raise ZeroDivisionError(
+            f"{e}\n\nThis script is designed for 2 GPUs. You can run it as:\n"
+            "CUDA_VISIBLE_DEVICES=0,1 python DDP-script.py\n"
+            f"Or, to run it on {torch.cuda.device_count()} GPUs, uncomment the code on lines 103 to 107."
+        )
+    ####################################################
 
     destroy_process_group()  # NEW: cleanly exit distributed mode
 
@@ -184,7 +202,6 @@ if __name__ == "__main__":
     print("PyTorch version:", torch.__version__)
     print("CUDA available:", torch.cuda.is_available())
     print("Number of GPUs available:", torch.cuda.device_count())
-
     torch.manual_seed(123)
 
     # NEW: spawn new processes
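Aside (not part of the commit above): the `try`/`except ZeroDivisionError` added in the hunk above guards the case where a rank ends up with no complete batch, so `compute_accuracy` would divide by `total_examples == 0`. A rough illustration of how that can happen, assuming `batch_size=2` and `drop_last=True` as in `prepare_dataset`, with a stand-in 5-sample dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# 5 training samples spread across 8 ranks: each rank receives only 1 (padded) sample,
# and batch_size=2 with drop_last=True then yields zero batches on that rank
ds = TensorDataset(torch.zeros(5, 2), torch.zeros(5, dtype=torch.long))
sampler = DistributedSampler(ds, num_replicas=8, rank=0, shuffle=False)
loader = DataLoader(ds, batch_size=2, drop_last=True, sampler=sampler)
print(len(loader))  # 0 -- iterating this loader produces no batches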
12  appendix-A/01_main-chapter-code/README.md  Normal file

@@ -0,0 +1,12 @@
# Appendix A: Introduction to PyTorch

### Main Chapter Code

- [code-part1.ipynb](code-part1.ipynb) contains all the section A.1 to A.8 code as it appears in the chapter
- [code-part2.ipynb](code-part2.ipynb) contains all the section A.9 GPU code as it appears in the chapter
- [DDP-script.py](DDP-script.py) contains the script to demonstrate multi-GPU usage (note that Jupyter notebooks only support single GPUs, so this is a script, not a notebook). You can run it as `python DDP-script.py`. If your machine has more than 2 GPUs, run it as `CUDA_VISIBLE_DEVICES=0,1 python DDP-script.py`.
- [exercise-solutions.ipynb](exercise-solutions.ipynb) contains the exercise solutions for this chapter

### Optional Code

- [DDP-script-torchrun.py](DDP-script-torchrun.py) is an optional version of the `DDP-script.py` script that runs via the PyTorch `torchrun` command instead of spawning and managing multiple processes ourselves via `multiprocessing.spawn`. The `torchrun` command has the advantage of automatically handling distributed initialization, including multi-node coordination, which slightly simplifies the setup. You can run this script via `torchrun --nproc_per_node=2 DDP-script-torchrun.py`.
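Aside (not part of the commit above): the last bullet's point about `torchrun` boils down to who assigns ranks. `torchrun` launches the worker processes itself and exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` into each worker's environment, so the script only has to read them, as in the `__main__` block of `DDP-script-torchrun.py`. A minimal sketch of that fallback logic (the helper name `detect_launch` is illustrative, not from the repository):

import os

def detect_launch():
    # Under torchrun these variables are set per process; under a plain
    # `python` invocation they are absent and we fall back to a single process.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", 0)))
    return rank, world_size

if __name__ == "__main__":
    print(detect_launch())  # (0, 1) under `python`, one line per worker under torchrun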
11  appendix-A/README.md  Normal file

@@ -0,0 +1,11 @@
# Appendix A: Introduction to PyTorch


## Main Chapter Code

- [01_main-chapter-code](01_main-chapter-code) contains the main chapter code


## Bonus Materials

- [02_setup-recommendations](02_setup-recommendations) contains Python installation and setup recommendations.