From d16863c7db81fb3de1131e4e9d4b84c32c450c5f Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Tue, 11 Feb 2025 17:01:09 -0600
Subject: [PATCH] Add torchrun bonus code (#524)

---
 .../DDP-script-torchrun.py                    | 220 ++++++++++++++++++
 appendix-A/01_main-chapter-code/DDP-script.py |  33 ++-
 appendix-A/01_main-chapter-code/README.md     |  12 +
 appendix-A/README.md                          |  11 +
 4 files changed, 268 insertions(+), 8 deletions(-)
 create mode 100644 appendix-A/01_main-chapter-code/DDP-script-torchrun.py
 create mode 100644 appendix-A/01_main-chapter-code/README.md
 create mode 100644 appendix-A/README.md

diff --git a/appendix-A/01_main-chapter-code/DDP-script-torchrun.py b/appendix-A/01_main-chapter-code/DDP-script-torchrun.py
new file mode 100644
index 0000000..8fe44b1
--- /dev/null
+++ b/appendix-A/01_main-chapter-code/DDP-script-torchrun.py
@@ -0,0 +1,220 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+# Appendix A: Introduction to PyTorch (Part 3)
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+
+# NEW imports:
+import os
+import platform
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed import init_process_group, destroy_process_group
+
+
+# NEW: function to initialize a distributed process group (1 process / GPU)
+# this allows communication among processes
+def ddp_setup(rank, world_size):
+    """
+    Arguments:
+        rank: a unique process ID
+        world_size: total number of processes in the group
+    """
+    # Only set MASTER_ADDR and MASTER_PORT if not already defined by torchrun
+    if "MASTER_ADDR" not in os.environ:
+        os.environ["MASTER_ADDR"] = "localhost"
+    if "MASTER_PORT" not in os.environ:
+        os.environ["MASTER_PORT"] = "12345"
+
+    # initialize process group
+    if platform.system() == "Windows":
+        # Disable libuv because PyTorch for Windows isn't built with support
+        os.environ["USE_LIBUV"] = "0"
+        # Windows users may have to use "gloo" instead of "nccl" as backend
+        # gloo: Facebook Collective Communication Library
+        init_process_group(backend="gloo", rank=rank, world_size=world_size)
+    else:
+        # nccl: NVIDIA Collective Communication Library
+        init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+    torch.cuda.set_device(rank)
+
+
+class ToyDataset(Dataset):
+    def __init__(self, X, y):
+        self.features = X
+        self.labels = y
+
+    def __getitem__(self, index):
+        one_x = self.features[index]
+        one_y = self.labels[index]
+        return one_x, one_y
+
+    def __len__(self):
+        return self.labels.shape[0]
+
+
+class NeuralNetwork(torch.nn.Module):
+    def __init__(self, num_inputs, num_outputs):
+        super().__init__()
+
+        self.layers = torch.nn.Sequential(
+            # 1st hidden layer
+            torch.nn.Linear(num_inputs, 30),
+            torch.nn.ReLU(),
+
+            # 2nd hidden layer
+            torch.nn.Linear(30, 20),
+            torch.nn.ReLU(),
+
+            # output layer
+            torch.nn.Linear(20, num_outputs),
+        )
+
+    def forward(self, x):
+        logits = self.layers(x)
+        return logits
+
+
+def prepare_dataset():
+    X_train = torch.tensor([
+        [-1.2, 3.1],
+        [-0.9, 2.9],
+        [-0.5, 2.6],
+        [2.3, -1.1],
+        [2.7, -1.5]
+    ])
+    y_train = torch.tensor([0, 0, 0, 1, 1])
+
+    X_test = torch.tensor([
+        [-0.8, 2.8],
+        [2.6, -1.6],
+    ])
+    y_test = torch.tensor([0, 1])
+
+    # Uncomment these lines to increase the dataset size to run this script on up to 8 GPUs:
+    # factor = 4
+    # X_train = torch.cat([X_train + torch.randn_like(X_train) * 0.1 for _ in range(factor)])
+    # y_train = y_train.repeat(factor)
+    # X_test = torch.cat([X_test + torch.randn_like(X_test) * 0.1 for _ in range(factor)])
+    # y_test = y_test.repeat(factor)
+
+    train_ds = ToyDataset(X_train, y_train)
+    test_ds = ToyDataset(X_test, y_test)
+
+    train_loader = DataLoader(
+        dataset=train_ds,
+        batch_size=2,
+        shuffle=False,  # NEW: False because of DistributedSampler below
+        pin_memory=True,
+        drop_last=True,
+        # NEW: chunk batches across GPUs without overlapping samples:
+        sampler=DistributedSampler(train_ds)  # NEW
+    )
+    test_loader = DataLoader(
+        dataset=test_ds,
+        batch_size=2,
+        shuffle=False,
+    )
+    return train_loader, test_loader
+
+
+# NEW: wrapper
+def main(rank, world_size, num_epochs):
+
+    ddp_setup(rank, world_size)  # NEW: initialize process groups
+
+    train_loader, test_loader = prepare_dataset()
+    model = NeuralNetwork(num_inputs=2, num_outputs=2)
+    model.to(rank)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
+
+    model = DDP(model, device_ids=[rank])  # NEW: wrap model with DDP
+    # the core model is now accessible as model.module
+
+    for epoch in range(num_epochs):
+        # NEW: Set sampler to ensure each epoch has a different shuffle order
+        train_loader.sampler.set_epoch(epoch)
+
+        model.train()
+        for features, labels in train_loader:
+
+            features, labels = features.to(rank), labels.to(rank)  # New: use rank
+            logits = model(features)
+            loss = F.cross_entropy(logits, labels)  # Loss function
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # LOGGING
+            print(f"[GPU{rank}] Epoch: {epoch+1:03d}/{num_epochs:03d}"
+                  f" | Batchsize {labels.shape[0]:03d}"
+                  f" | Train/Val Loss: {loss:.2f}")
+
+    model.eval()
+
+    try:
+        train_acc = compute_accuracy(model, train_loader, device=rank)
+        print(f"[GPU{rank}] Training accuracy", train_acc)
+        test_acc = compute_accuracy(model, test_loader, device=rank)
+        print(f"[GPU{rank}] Test accuracy", test_acc)
+
+    ####################################################
+    # NEW (not in the book):
+    except ZeroDivisionError as e:
+        raise ZeroDivisionError(
+            f"{e}\n\nThis script is designed for 2 GPUs. You can run it as:\n"
+            "torchrun --nproc_per_node=2 DDP-script-torchrun.py\n"
+            f"Or, to run it on {torch.cuda.device_count()} GPUs, uncomment the code on lines 103 to 107."
+        )
+    ####################################################
+
+    destroy_process_group()  # NEW: cleanly exit distributed mode
+
+
+def compute_accuracy(model, dataloader, device):
+    model = model.eval()
+    correct = 0.0
+    total_examples = 0
+
+    for idx, (features, labels) in enumerate(dataloader):
+        features, labels = features.to(device), labels.to(device)
+
+        with torch.no_grad():
+            logits = model(features)
+        predictions = torch.argmax(logits, dim=1)
+        compare = labels == predictions
+        correct += torch.sum(compare)
+        total_examples += len(compare)
+
+    return (correct / total_examples).item()
+
+
+if __name__ == "__main__":
+    # NEW: Use environment variables set by torchrun if available, otherwise default to single-process.
+    if "WORLD_SIZE" in os.environ:
+        world_size = int(os.environ["WORLD_SIZE"])
+    else:
+        world_size = 1
+
+    if "LOCAL_RANK" in os.environ:
+        rank = int(os.environ["LOCAL_RANK"])
+    elif "RANK" in os.environ:
+        rank = int(os.environ["RANK"])
+    else:
+        rank = 0
+
+    # Only print on rank 0 to avoid duplicate prints from each GPU process
+    if rank == 0:
+        print("PyTorch version:", torch.__version__)
+        print("CUDA available:", torch.cuda.is_available())
+        print("Number of GPUs available:", torch.cuda.device_count())
+
+    torch.manual_seed(123)
+    num_epochs = 3
+    main(rank, world_size, num_epochs)
diff --git a/appendix-A/01_main-chapter-code/DDP-script.py b/appendix-A/01_main-chapter-code/DDP-script.py
index 557c7b4..9350466 100644
--- a/appendix-A/01_main-chapter-code/DDP-script.py
+++ b/appendix-A/01_main-chapter-code/DDP-script.py
@@ -31,12 +31,11 @@ def ddp_setup(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     # any free port on the machine
     os.environ["MASTER_PORT"] = "12345"
-    if platform.system() == "Windows":
-        # Disable libuv because PyTorch for Windows isn't built with support
-        os.environ["USE_LIBUV"] = "0"
 
     # initialize process group
     if platform.system() == "Windows":
+        # Disable libuv because PyTorch for Windows isn't built with support
+        os.environ["USE_LIBUV"] = "0"
         # Windows users may have to use "gloo" instead of "nccl" as backend
         # gloo: Facebook Collective Communication Library
         init_process_group(backend="gloo", rank=rank, world_size=world_size)
@@ -99,6 +98,13 @@ def prepare_dataset():
     ])
     y_test = torch.tensor([0, 1])
 
+    # Uncomment these lines to increase the dataset size to run this script on up to 8 GPUs:
+    # factor = 4
+    # X_train = torch.cat([X_train + torch.randn_like(X_train) * 0.1 for _ in range(factor)])
+    # y_train = y_train.repeat(factor)
+    # X_test = torch.cat([X_test + torch.randn_like(X_test) * 0.1 for _ in range(factor)])
+    # y_test = y_test.repeat(factor)
+
     train_ds = ToyDataset(X_train, y_train)
     test_ds = ToyDataset(X_test, y_test)
 
@@ -153,10 +159,22 @@ def main(rank, world_size, num_epochs):
                   f" | Train/Val Loss: {loss:.2f}")
 
     model.eval()
-    train_acc = compute_accuracy(model, train_loader, device=rank)
-    print(f"[GPU{rank}] Training accuracy", train_acc)
-    test_acc = compute_accuracy(model, test_loader, device=rank)
-    print(f"[GPU{rank}] Test accuracy", test_acc)
+
+    try:
+        train_acc = compute_accuracy(model, train_loader, device=rank)
+        print(f"[GPU{rank}] Training accuracy", train_acc)
+        test_acc = compute_accuracy(model, test_loader, device=rank)
+        print(f"[GPU{rank}] Test accuracy", test_acc)
+
+    ####################################################
+    # NEW (not in the book):
+    except ZeroDivisionError as e:
+        raise ZeroDivisionError(
+            f"{e}\n\nThis script is designed for 2 GPUs. You can run it as:\n"
+            "CUDA_VISIBLE_DEVICES=0,1 python DDP-script.py\n"
+            f"Or, to run it on {torch.cuda.device_count()} GPUs, uncomment the code on lines 103 to 107."
+        )
+    ####################################################
 
     destroy_process_group()  # NEW: cleanly exit distributed mode
 
@@ -184,7 +202,6 @@ if __name__ == "__main__":
     print("PyTorch version:", torch.__version__)
     print("CUDA available:", torch.cuda.is_available())
     print("Number of GPUs available:", torch.cuda.device_count())
-
     torch.manual_seed(123)
 
     # NEW: spawn new processes
diff --git a/appendix-A/01_main-chapter-code/README.md b/appendix-A/01_main-chapter-code/README.md
new file mode 100644
index 0000000..2a6787e
--- /dev/null
+++ b/appendix-A/01_main-chapter-code/README.md
@@ -0,0 +1,12 @@
+# Appendix A: Introduction to PyTorch
+
+### Main Chapter Code
+
+- [code-part1.ipynb](code-part1.ipynb) contains all the section A.1 to A.8 code as it appears in the chapter
+- [code-part2.ipynb](code-part2.ipynb) contains all the section A.9 GPU code as it appears in the chapter
+- [DDP-script.py](DDP-script.py) contains the script to demonstrate multi-GPU usage (note that Jupyter Notebooks only support single GPUs, so this is a script, not a notebook). You can run it as `python DDP-script.py`. If your machine has more than 2 GPUs, run it as `CUDA_VISIBLE_DEVICES=0,1 python DDP-script.py`.
+- [exercise-solutions.ipynb](exercise-solutions.ipynb) contains the exercise solutions for this chapter
+
+### Optional Code
+
+- [DDP-script-torchrun.py](DDP-script-torchrun.py) is an optional version of the `DDP-script.py` script that runs via the PyTorch `torchrun` command instead of spawning and managing multiple processes ourselves via `multiprocessing.spawn`. The `torchrun` command has the advantage of automatically handling distributed initialization, including multi-node coordination, which slightly simplifies the setup process. You can use this script via `torchrun --nproc_per_node=2 DDP-script-torchrun.py`.
diff --git a/appendix-A/README.md b/appendix-A/README.md
new file mode 100644
index 0000000..75db721
--- /dev/null
+++ b/appendix-A/README.md
@@ -0,0 +1,11 @@
+# Appendix A: Introduction to PyTorch
+
+&nbsp;
+## Main Chapter Code
+
+- [01_main-chapter-code](01_main-chapter-code) contains the main chapter code
+
+&nbsp;
+## Bonus Materials
+
+- [02_setup-recommendations](02_setup-recommendations) contains Python installation and setup recommendations.
\ No newline at end of file
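
For readers comparing the two launch styles described in the READMEs above, here is a minimal, illustrative sketch (not part of the patch) of how each script obtains its rank and world size. The `launch_*` helpers and the stub `main` are hypothetical stand-ins for the `main(rank, world_size, num_epochs)` function that both scripts define.

```python
# Illustrative sketch only: contrasts the mp.spawn launch style of DDP-script.py
# with the torchrun launch style of DDP-script-torchrun.py.
import os

import torch
import torch.multiprocessing as mp


def main(rank, world_size, num_epochs):
    # Stub standing in for the real training function defined in the scripts above.
    print(f"rank={rank}, world_size={world_size}, num_epochs={num_epochs}")


def launch_with_spawn(num_epochs=3):
    # DDP-script.py style: this process creates one worker per GPU itself;
    # mp.spawn passes the worker index (the rank) as the first argument to main.
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size, num_epochs), nprocs=world_size)


def launch_under_torchrun(num_epochs=3):
    # DDP-script-torchrun.py style: torchrun has already started one process per GPU
    # and exported WORLD_SIZE and LOCAL_RANK, so each worker simply reads them.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("LOCAL_RANK", 0))
    main(rank, world_size, num_epochs)
```

The first variant is started with a plain `python` call (for example, `CUDA_VISIBLE_DEVICES=0,1 python DDP-script.py`), while the second relies on the launcher to create the processes (for example, `torchrun --nproc_per_node=2 DDP-script-torchrun.py`).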