Update DDP-script.py

Fix for-loop
This commit is contained in:
Sebastian Raschka 2024-03-01 18:31:05 -06:00 committed by GitHub
parent c9dccb0c40
commit c071ea73f9

View File

@ -121,7 +121,7 @@ def main(rank, world_size, num_epochs):
for epoch in range(num_epochs):
model.train()
for features, labels in enumerate(train_loader):
for features, labels in train_loader:
features, labels = features.to(rank), labels.to(rank) # New: use rank
logits = model(features)
@ -175,4 +175,3 @@ if __name__ == "__main__":
world_size = torch.cuda.device_count()
mp.spawn(main, args=(world_size, num_epochs), nprocs=world_size)
# nprocs=world_size spawns one process per GPU