Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-08-29 11:00:55 +00:00)
make code more general for larger models
This commit is contained in: parent 3328b29521, commit a63b0f626c
@@ -1316,7 +1316,8 @@
    "metadata": {},
    "source": [
     "- Then, we replace the output layer (`model.out_head`), which originally maps the layer inputs to 50,257 dimensions (the size of the vocabulary)\n",
-    "- Since we finetune the model for binary classification (predicting 2 classes, \"spam\" and \"ham\"), we can replace the output layer as shown below, which will be trainable by default"
+    "- Since we finetune the model for binary classification (predicting 2 classes, \"spam\" and \"ham\"), we can replace the output layer as shown below, which will be trainable by default\n",
+    "- Note that we use `BASE_CONFIG[\"emb_dim\"]` (which is equal to 768 in the `\"gpt2-small (124M)\"` model) to keep the code below more general"
    ]
   },
   {
@@ -1329,7 +1330,7 @@
    "torch.manual_seed(123)\n",
    "\n",
    "num_classes = 2\n",
-    "model.out_head = torch.nn.Linear(in_features=768, out_features=num_classes)"
+    "model.out_head = torch.nn.Linear(in_features=BASE_CONFIG[\"emb_dim\"], out_features=num_classes)"
    ]
   },
   {
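The head replacement this commit generalizes can be sketched in isolation. The sketch below is a minimal stand-in, not the book's actual GPT model: `DummyGPT` and its `out_head` attribute are assumptions mirroring the notebook's structure, and `BASE_CONFIG` here is a hypothetical config dict with only the `"emb_dim"` key the diff relies on (768 for `"gpt2-small (124M)"`, larger for the bigger GPT-2 variants).

```python
import torch

# Hypothetical config mirroring the notebook's BASE_CONFIG; "emb_dim" is 768
# for "gpt2-small (124M)" but differs for larger GPT-2 models.
BASE_CONFIG = {"emb_dim": 768}

class DummyGPT(torch.nn.Module):
    """Stand-in for the pretrained GPT model; only the output head matters here."""
    def __init__(self, cfg):
        super().__init__()
        # Original head maps hidden states to the 50,257-token vocabulary.
        self.out_head = torch.nn.Linear(cfg["emb_dim"], 50257)

model = DummyGPT(BASE_CONFIG)

torch.manual_seed(123)
num_classes = 2  # binary classification: "spam" vs. "ham"

# Replace the vocabulary head with a small classification head; reading
# in_features from BASE_CONFIG["emb_dim"] instead of hard-coding 768 is
# exactly what keeps the code general across model sizes.
model.out_head = torch.nn.Linear(
    in_features=BASE_CONFIG["emb_dim"], out_features=num_classes
)

print(model.out_head.in_features, model.out_head.out_features)
```

Because the new `torch.nn.Linear` is freshly constructed, its parameters have `requires_grad=True` by default, so the classification head is trainable even if the rest of the model is frozen.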