mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-10 09:43:05 +00:00
commit
9587b58cf7
@ -350,9 +350,7 @@ if __name__ == "__main__":
|
||||
}
|
||||
model = GPTModel(BASE_CONFIG)
|
||||
model.eval()
|
||||
|
||||
device = "cpu"
|
||||
model.to(device)
|
||||
|
||||
# Code as it is used in the main chapter
|
||||
else:
|
||||
@ -380,10 +378,7 @@ if __name__ == "__main__":
|
||||
|
||||
model = GPTModel(BASE_CONFIG)
|
||||
load_weights_into_gpt(model, params)
|
||||
model.eval()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(device)
|
||||
|
||||
########################################
|
||||
# Modify the pretrained model
|
||||
@ -396,6 +391,7 @@ if __name__ == "__main__":
|
||||
|
||||
num_classes = 2
|
||||
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
|
||||
model.to(device)
|
||||
|
||||
for param in model.trf_blocks[-1].parameters():
|
||||
param.requires_grad = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user