diff --git a/ch06/02_bonus_additional-experiments/additional_experiments.py b/ch06/02_bonus_additional-experiments/additional_experiments.py index 974f21e..72d0da7 100644 --- a/ch06/02_bonus_additional-experiments/additional_experiments.py +++ b/ch06/02_bonus_additional-experiments/additional_experiments.py @@ -271,12 +271,11 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, mask = input_batch != pad_token_id last_token_pos = mask.sum(dim=1) - 1 # Get position of last real token - with torch.no_grad(): - logits = model(input_batch) # Logits of last output token - # Select the logits corresponding to the last real token of each sequence - batch_size = logits.size(0) - selected_logits = logits[torch.arange(batch_size), last_token_pos] - predicted_labels = torch.argmax(selected_logits, dim=-1) + logits = model(input_batch) # Logits of last output token + # Select the logits corresponding to the last real token of each sequence + batch_size = logits.size(0) + selected_logits = logits[torch.arange(batch_size), last_token_pos] + predicted_labels = torch.argmax(selected_logits, dim=-1) num_examples += predicted_labels.shape[0] correct_predictions += (predicted_labels == target_batch).sum().item() @@ -643,8 +642,6 @@ if __name__ == "__main__": val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding) test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding) - tokenizer = tiktoken.get_encoding("gpt2") - num_workers = 0 train_loader = DataLoader(