mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 02:50:15 +00:00
Add mean pooling experiment to classifier bonus experiments (#406)
* Add mean pooling experiment to classifier bonus experiments * formatting * add average embeddings option * pep8
This commit is contained in:
parent
c4bac22bff
commit
3c3dae0967
@ -28,6 +28,7 @@ For example,
|
||||
| 15 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
|
||||
| 16 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 |
|
||||
| 17 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) and `ignore_index` for padding | 96.63% | 99.33% | 95.00% | 0.28 min | A100 |
|
||||
| 18 | gpt2-small (124M) | pretrained | last + pooled embeddings | last_block | longest train ex. (120) | 97.79% | 99.33% | 96.33% | 0.32 min | A100 |
|
||||
|
||||
|
||||
|
||||
@ -52,6 +53,7 @@ You can use the following code to reproduce the experiments:
|
||||
- Row 15: `python additional_experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
|
||||
- Row 16: `python additional_experiments.py --disable_causal_mask`
|
||||
- Row 17: `python additional_experiments.py --ignore_index 50256`
|
||||
- Row 18: `python additional_experiments.py --average embeddings`
|
||||
|
||||
I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes (for the default setting) in case you don't have access to a GPU.
|
||||
|
||||
@ -70,3 +72,4 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
|
||||
9. **Padding vs no padding (Row 1 vs. 14 and 15)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 15, we additionally enable gradient accumulation with 8 steps to achieve the same batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy.
|
||||
10. **Disabling the causal attention mask (Row 1 vs. 16)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend all other tokens. The model accuracy is slightly improved compared to the GPT model with causal mask.
|
||||
11. **Ignoring the padding indices in the loss and backpropagation (Row 1 vs. 17)**: Setting `--ignore_index 50256` excludes the `|endoftext|` padding tokens in the `cross_entropy` loss function in PyTorch. In this case, it does not have any effect because we replaced the output layers so that the token IDs are either 0 or 1 for the binary classification example. However, this setting is useful when instruction finetuning models in chapter 7.
|
||||
13. **Averaging the embeddings over all tokens (Row 1 vs. 18)**: Setting `--average_embeddings` will average the embeddings over all tokens. If this option is not used (the default), only the output embeddings at the chosen token position (specified by `--trainable_token_pos`) are considered; for example, the embeddings of the last token. Enabling `--average_embeddings` will mean-pool the embeddings of all tokens into the position chosen by `--trainable_token_pos` (the last token by default). As we can see, this improves the performance from 95.00% to 96.33% with only a minimal increase in run time (0.28 min to 0.32 min) and might be worthwhile considering in practice.
|
@ -181,15 +181,24 @@ def instantiate_model(choose_model, load_weights):
|
||||
|
||||
|
||||
def calc_loss_batch(input_batch, target_batch, model, device,
|
||||
trainable_token_pos=-1, ignore_index=-100):
|
||||
trainable_token_pos=-1, ignore_index=-100, average_embeddings=False):
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token
|
||||
|
||||
model_output = model(input_batch)
|
||||
if average_embeddings:
|
||||
# Average over the sequence dimension (dim=1)
|
||||
logits = model_output.mean(dim=1)
|
||||
else:
|
||||
# Select embeddings at the specified token position
|
||||
logits = model_output[:, trainable_token_pos, :]
|
||||
|
||||
loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
|
||||
return loss
|
||||
|
||||
|
||||
def calc_loss_loader(data_loader, model, device,
|
||||
num_batches=None, trainable_token_pos=-1, ignore_index=-100):
|
||||
num_batches=None, trainable_token_pos=-1,
|
||||
ignore_index=-100, average_embeddings=False):
|
||||
total_loss = 0.
|
||||
if len(data_loader) == 0:
|
||||
return float("nan")
|
||||
@ -203,7 +212,8 @@ def calc_loss_loader(data_loader, model, device,
|
||||
if i < num_batches:
|
||||
loss = calc_loss_batch(
|
||||
input_batch, target_batch, model, device,
|
||||
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
|
||||
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
|
||||
average_embeddings=average_embeddings
|
||||
)
|
||||
total_loss += loss.item()
|
||||
else:
|
||||
@ -212,7 +222,8 @@ def calc_loss_loader(data_loader, model, device,
|
||||
|
||||
|
||||
@torch.no_grad() # Disable gradient tracking for efficiency
|
||||
def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token_pos=-1):
|
||||
def calc_accuracy_loader(data_loader, model, device, num_batches=None,
|
||||
trainable_token_pos=-1, average_embeddings=False):
|
||||
model.eval()
|
||||
correct_predictions, num_examples = 0, 0
|
||||
|
||||
@ -223,7 +234,15 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
|
||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||
if i < num_batches:
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token
|
||||
|
||||
model_output = model(input_batch)
|
||||
if average_embeddings:
|
||||
# Average over the sequence dimension (dim=1)
|
||||
logits = model_output.mean(dim=1)
|
||||
else:
|
||||
# Select embeddings at the specified token position
|
||||
logits = model_output[:, trainable_token_pos, :]
|
||||
|
||||
predicted_labels = torch.argmax(logits, dim=-1)
|
||||
|
||||
num_examples += predicted_labels.shape[0]
|
||||
@ -234,16 +253,19 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
|
||||
|
||||
|
||||
def evaluate_model(model, train_loader, val_loader, device,
|
||||
eval_iter, trainable_token_pos=-1, ignore_index=-100):
|
||||
eval_iter, trainable_token_pos=-1,
|
||||
ignore_index=-100, average_embeddings=False):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
train_loss = calc_loss_loader(
|
||||
train_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
|
||||
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
|
||||
average_embeddings=average_embeddings
|
||||
)
|
||||
val_loss = calc_loss_loader(
|
||||
val_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
|
||||
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
|
||||
average_embeddings=average_embeddings
|
||||
)
|
||||
model.train()
|
||||
return train_loss, val_loss
|
||||
@ -251,7 +273,7 @@ def evaluate_model(model, train_loader, val_loader, device,
|
||||
|
||||
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
|
||||
eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
|
||||
accumulation_steps=1, ignore_index=-100):
|
||||
accumulation_steps=1, ignore_index=-100, average_embeddings=False):
|
||||
# Initialize lists to track losses and tokens seen
|
||||
train_losses, val_losses, train_accs, val_accs = [], [], [], []
|
||||
examples_seen, global_step = 0, -1
|
||||
@ -263,7 +285,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
|
||||
loss = calc_loss_batch(
|
||||
input_batch, target_batch, model, device,
|
||||
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
|
||||
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
|
||||
average_embeddings=average_embeddings
|
||||
)
|
||||
|
||||
# Use gradient accumulation if accumulation_steps > 1
|
||||
@ -286,7 +309,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
if global_step % eval_freq == 0:
|
||||
train_loss, val_loss = evaluate_model(
|
||||
model, train_loader, val_loader, device, eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
|
||||
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
|
||||
average_embeddings=average_embeddings
|
||||
)
|
||||
train_losses.append(train_loss)
|
||||
val_losses.append(val_loss)
|
||||
@ -297,8 +321,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
break
|
||||
|
||||
# New: Calculate accuracy after each epoch
|
||||
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos)
|
||||
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos)
|
||||
train_accuracy = calc_accuracy_loader(
|
||||
train_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
val_accuracy = calc_accuracy_loader(
|
||||
val_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
|
||||
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
|
||||
train_accs.append(train_accuracy)
|
||||
@ -359,13 +389,22 @@ if __name__ == "__main__":
|
||||
"Which token position to train. Options: 'first', 'last'."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--average_embeddings",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=(
|
||||
"Average the output embeddings from all tokens instead of using"
|
||||
" only the embedding at the token position specified by `--trainable_token_pos`."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context_length",
|
||||
type=str,
|
||||
default="longest_training_example",
|
||||
help=(
|
||||
"The context length of the data inputs."
|
||||
"Options: 'longest_training_example', 'model_context_length' or integer value."
|
||||
" Options: 'longest_training_example', 'model_context_length' or integer value."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -409,7 +448,6 @@ if __name__ == "__main__":
|
||||
"The batch size used for training."
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--accumulation_steps",
|
||||
type=int,
|
||||
@ -422,7 +460,6 @@ if __name__ == "__main__":
|
||||
" the latter setting uses more iterations."
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable_causal_mask",
|
||||
action='store_true',
|
||||
@ -431,7 +468,6 @@ if __name__ == "__main__":
|
||||
"Disables the causal attention mask."
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ignore_index",
|
||||
type=int,
|
||||
@ -589,7 +625,7 @@ if __name__ == "__main__":
|
||||
model, train_loader, val_loader, optimizer, device,
|
||||
num_epochs=args.num_epochs, eval_freq=50, eval_iter=5,
|
||||
max_steps=None, trainable_token_pos=args.trainable_token_pos,
|
||||
accumulation_steps=args.accumulation_steps
|
||||
accumulation_steps=args.accumulation_steps, average_embeddings=args.average_embeddings
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
@ -600,9 +636,18 @@ if __name__ == "__main__":
|
||||
# Evaluate model
|
||||
###############################
|
||||
|
||||
train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token_pos=args.trainable_token_pos)
|
||||
val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token_pos=args.trainable_token_pos)
|
||||
test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token_pos=args.trainable_token_pos)
|
||||
train_accuracy = calc_accuracy_loader(
|
||||
train_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
val_accuracy = calc_accuracy_loader(
|
||||
val_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
test_accuracy = calc_accuracy_loader(
|
||||
test_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
|
||||
print(f"Training accuracy: {train_accuracy*100:.2f}%")
|
||||
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
|
||||
|
@ -81,14 +81,25 @@ def instantiate_model(choose_model, load_weights):
|
||||
return model
|
||||
|
||||
|
||||
def calc_loss_batch(input_batch, target_batch, model, device, trainable_token=-1):
|
||||
def calc_loss_batch(input_batch, target_batch, model, device,
|
||||
trainable_token_pos=-1, average_embeddings=False):
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
logits = model(input_batch)[:, trainable_token, :] # Logits of last output token
|
||||
|
||||
model_output = model(input_batch)
|
||||
if average_embeddings:
|
||||
# Average over the sequence dimension (dim=1)
|
||||
logits = model_output.mean(dim=1)
|
||||
else:
|
||||
# Select embeddings at the specified token position
|
||||
logits = model_output[:, trainable_token_pos, :]
|
||||
|
||||
loss = torch.nn.functional.cross_entropy(logits, target_batch)
|
||||
return loss
|
||||
|
||||
|
||||
def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
|
||||
def calc_loss_loader(data_loader, model, device,
|
||||
num_batches=None, trainable_token_pos=-1,
|
||||
average_embeddings=False):
|
||||
total_loss = 0.
|
||||
if len(data_loader) == 0:
|
||||
return float("nan")
|
||||
@ -100,7 +111,10 @@ def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_tok
|
||||
num_batches = min(num_batches, len(data_loader))
|
||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||
if i < num_batches:
|
||||
loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
|
||||
loss = calc_loss_batch(
|
||||
input_batch, target_batch, model, device,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
total_loss += loss.item()
|
||||
else:
|
||||
break
|
||||
@ -108,7 +122,9 @@ def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_tok
|
||||
|
||||
|
||||
@torch.no_grad() # Disable gradient tracking for efficiency
|
||||
def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
|
||||
def calc_accuracy_loader(data_loader, model, device,
|
||||
num_batches=None, trainable_token_pos=-1,
|
||||
average_embeddings=False):
|
||||
model.eval()
|
||||
correct_predictions, num_examples = 0, 0
|
||||
|
||||
@ -119,7 +135,15 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
|
||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||
if i < num_batches:
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
logits = model(input_batch)[:, trainable_token, :] # Logits of last output token
|
||||
|
||||
model_output = model(input_batch)
|
||||
if average_embeddings:
|
||||
# Average over the sequence dimension (dim=1)
|
||||
logits = model_output.mean(dim=1)
|
||||
else:
|
||||
# Select embeddings at the specified token position
|
||||
logits = model_output[:, trainable_token_pos, :]
|
||||
|
||||
predicted_labels = torch.argmax(logits, dim=-1)
|
||||
|
||||
num_examples += predicted_labels.shape[0]
|
||||
@ -129,17 +153,25 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
|
||||
return correct_predictions / num_examples
|
||||
|
||||
|
||||
def evaluate_model(model, train_loader, val_loader, device, eval_iter, trainable_token=-1):
|
||||
def evaluate_model(model, train_loader, val_loader, device, eval_iter,
|
||||
trainable_token_pos=-1, average_embeddings=False):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
|
||||
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
|
||||
train_loss = calc_loss_loader(
|
||||
train_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
val_loss = calc_loss_loader(
|
||||
val_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
model.train()
|
||||
return train_loss, val_loss
|
||||
|
||||
|
||||
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
|
||||
eval_freq, eval_iter, max_steps=None, trainable_token=-1):
|
||||
eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
|
||||
average_embeddings=False):
|
||||
# Initialize lists to track losses and tokens seen
|
||||
train_losses, val_losses, train_accs, val_accs = [], [], [], []
|
||||
examples_seen, global_step = 0, -1
|
||||
@ -150,7 +182,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
|
||||
for input_batch, target_batch in train_loader:
|
||||
optimizer.zero_grad() # Reset loss gradients from previous batch iteration
|
||||
loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
|
||||
loss = calc_loss_batch(input_batch, target_batch, model, device,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings)
|
||||
loss.backward() # Calculate loss gradients
|
||||
optimizer.step() # Update model weights using loss gradients
|
||||
examples_seen += input_batch.shape[0] # New: track examples instead of tokens
|
||||
@ -159,7 +192,9 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
# Optional evaluation step
|
||||
if global_step % eval_freq == 0:
|
||||
train_loss, val_loss = evaluate_model(
|
||||
model, train_loader, val_loader, device, eval_iter, trainable_token=trainable_token)
|
||||
model, train_loader, val_loader, device, eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
train_losses.append(train_loss)
|
||||
val_losses.append(val_loss)
|
||||
print(f"Ep {epoch+1} (Step {global_step:06d}): "
|
||||
@ -169,8 +204,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
break
|
||||
|
||||
# New: Calculate accuracy after each epoch
|
||||
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
|
||||
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
|
||||
train_accuracy = calc_accuracy_loader(
|
||||
train_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
val_accuracy = calc_accuracy_loader(
|
||||
val_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
|
||||
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
|
||||
train_accs.append(train_accuracy)
|
||||
@ -211,13 +252,22 @@ if __name__ == "__main__":
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trainable_token",
|
||||
"--trainable_token_pos",
|
||||
type=str,
|
||||
default="last",
|
||||
help=(
|
||||
"Which token to train. Options: 'first', 'last'."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--average_embeddings",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=(
|
||||
"Average the output embeddings from all tokens instead of using"
|
||||
" only the embedding at the token position specified by `--trainable_token_pos`."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context_length",
|
||||
type=str,
|
||||
@ -245,12 +295,12 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.trainable_token == "first":
|
||||
args.trainable_token = 0
|
||||
elif args.trainable_token == "last":
|
||||
args.trainable_token = -1
|
||||
if args.trainable_token_pos == "first":
|
||||
args.trainable_token_pos = 0
|
||||
elif args.trainable_token_pos == "last":
|
||||
args.trainable_token_pos = -1
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_token argument")
|
||||
raise ValueError("Invalid --trainable_token_pos argument")
|
||||
|
||||
###############################
|
||||
# Load model
|
||||
@ -358,7 +408,8 @@ if __name__ == "__main__":
|
||||
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
|
||||
model, train_loader, val_loader, optimizer, device,
|
||||
num_epochs=args.num_epochs, eval_freq=50, eval_iter=20,
|
||||
max_steps=None, trainable_token=args.trainable_token
|
||||
max_steps=None, trainable_token_pos=args.trainable_token_pos,
|
||||
average_embeddings=args.average_embeddings
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
@ -371,9 +422,18 @@ if __name__ == "__main__":
|
||||
|
||||
print("\nEvaluating on the full datasets ...\n")
|
||||
|
||||
train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token=args.trainable_token)
|
||||
val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token=args.trainable_token)
|
||||
test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token=args.trainable_token)
|
||||
train_accuracy = calc_accuracy_loader(
|
||||
train_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
val_accuracy = calc_accuracy_loader(
|
||||
val_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
test_accuracy = calc_accuracy_loader(
|
||||
test_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
|
||||
print(f"Training accuracy: {train_accuracy*100:.2f}%")
|
||||
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
|
||||
|
Loading…
x
Reference in New Issue
Block a user