mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-12-11 15:02:29 +00:00
Update DDP-script.py
Fix for-loop
This commit is contained in:
parent
c9dccb0c40
commit
c071ea73f9
@ -121,7 +121,7 @@ def main(rank, world_size, num_epochs):
|
||||
for epoch in range(num_epochs):
|
||||
|
||||
model.train()
|
||||
for features, labels in enumerate(train_loader):
|
||||
for features, labels in train_loader:
|
||||
|
||||
features, labels = features.to(rank), labels.to(rank) # New: use rank
|
||||
logits = model(features)
|
||||
@ -175,4 +175,3 @@ if __name__ == "__main__":
|
||||
world_size = torch.cuda.device_count()
|
||||
mp.spawn(main, args=(world_size, num_epochs), nprocs=world_size)
|
||||
# nprocs=world_size spawns one process per GPU
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user