From 00b8c0a107f4cec0012cf8d1f6b54309846d25dd Mon Sep 17 00:00:00 2001
From: casinca <47400729+casinca@users.noreply.github.com>
Date: Fri, 20 Jun 2025 01:56:49 +0200
Subject: [PATCH] minor readability improvements (#668)

---
 .../additional_experiments.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/ch06/02_bonus_additional-experiments/additional_experiments.py b/ch06/02_bonus_additional-experiments/additional_experiments.py
index 974f21e..72d0da7 100644
--- a/ch06/02_bonus_additional-experiments/additional_experiments.py
+++ b/ch06/02_bonus_additional-experiments/additional_experiments.py
@@ -271,12 +271,11 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None,
             mask = input_batch != pad_token_id
             last_token_pos = mask.sum(dim=1) - 1  # Get position of last real token
 
-            with torch.no_grad():
-                logits = model(input_batch)  # Logits of last output token
-                # Select the logits corresponding to the last real token of each sequence
-                batch_size = logits.size(0)
-                selected_logits = logits[torch.arange(batch_size), last_token_pos]
-                predicted_labels = torch.argmax(selected_logits, dim=-1)
+            logits = model(input_batch)  # Logits of last output token
+            # Select the logits corresponding to the last real token of each sequence
+            batch_size = logits.size(0)
+            selected_logits = logits[torch.arange(batch_size), last_token_pos]
+            predicted_labels = torch.argmax(selected_logits, dim=-1)
 
             num_examples += predicted_labels.shape[0]
             correct_predictions += (predicted_labels == target_batch).sum().item()
@@ -643,8 +642,6 @@ if __name__ == "__main__":
     val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
     test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
 
-    tokenizer = tiktoken.get_encoding("gpt2")
-
     num_workers = 0
 
     train_loader = DataLoader(
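
Note on the indexing pattern in the first hunk: the retained code selects, for each sequence in a right-padded batch, the logits at that sequence's last real (non-pad) token via advanced indexing. Below is a minimal standalone sketch of that pattern; the toy shapes, token values, and the random tensor standing in for model(input_batch) are assumptions for illustration, not taken from the repository.

import torch

# Hypothetical toy setup (not from the repo): a batch of 2 sequences of
# length 5, where the second sequence has two trailing pad tokens.
pad_token_id = 50256  # GPT-2's <|endoftext|> token, commonly reused for padding
input_batch = torch.tensor([
    [11, 22, 33, 44, 55],
    [66, 77, 88, pad_token_id, pad_token_id],
])
num_classes = 2
logits = torch.randn(input_batch.size(0), input_batch.size(1), num_classes)  # stand-in for model(input_batch)

mask = input_batch != pad_token_id
last_token_pos = mask.sum(dim=1) - 1  # tensor([4, 2]): index of last real token per sequence

batch_size = logits.size(0)
# Advanced indexing pairs row i with column last_token_pos[i],
# yielding one logit vector per sequence, shape (batch_size, num_classes)
selected_logits = logits[torch.arange(batch_size), last_token_pos]
predicted_labels = torch.argmax(selected_logits, dim=-1)
print(predicted_labels.shape)  # torch.Size([2])

This also clarifies why the second hunk is safe: it drops a tokenizer = tiktoken.get_encoding("gpt2") line that re-created a tokenizer already constructed earlier and already passed into the SpamDataset instances above it.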