Merge pull request #173 from rasbt/device-setting

Fix device setting
This commit is contained in:
Sebastian Raschka 2024-05-22 18:59:58 -04:00 committed by GitHub
commit 9587b58cf7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -350,9 +350,7 @@ if __name__ == "__main__":
}
model = GPTModel(BASE_CONFIG)
model.eval()
device = "cpu"
model.to(device)
# Code as it is used in the main chapter
else:
@ -380,10 +378,7 @@ if __name__ == "__main__":
model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
########################################
# Modify and pretrained model
@ -396,6 +391,7 @@ if __name__ == "__main__":
num_classes = 2
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
model.to(device)
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True