mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-13 11:12:09 +00:00
commit
9587b58cf7
@ -350,9 +350,7 @@ if __name__ == "__main__":
|
|||||||
}
|
}
|
||||||
model = GPTModel(BASE_CONFIG)
|
model = GPTModel(BASE_CONFIG)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
model.to(device)
|
|
||||||
|
|
||||||
# Code as it is used in the main chapter
|
# Code as it is used in the main chapter
|
||||||
else:
|
else:
|
||||||
@ -380,10 +378,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
model = GPTModel(BASE_CONFIG)
|
model = GPTModel(BASE_CONFIG)
|
||||||
load_weights_into_gpt(model, params)
|
load_weights_into_gpt(model, params)
|
||||||
model.eval()
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
model.to(device)
|
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
# Modify and pretrained model
|
# Modify and pretrained model
|
||||||
@ -396,6 +391,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
num_classes = 2
|
num_classes = 2
|
||||||
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
|
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
for param in model.trf_blocks[-1].parameters():
|
for param in model.trf_blocks[-1].parameters():
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user