minor readability improvements (#668)

commit 00b8c0a107 (parent 15fa6a84f6)
Author: casinca
Date: 2025-06-20 01:56:49 +02:00
Committed by: GitHub

@@ -271,12 +271,11 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None,
             mask = input_batch != pad_token_id
             last_token_pos = mask.sum(dim=1) - 1  # Get position of last real token
             with torch.no_grad():
-                logits = model(input_batch)  # Logits of last output token
-            # Select the logits corresponding to the last real token of each sequence
-            batch_size = logits.size(0)
-            selected_logits = logits[torch.arange(batch_size), last_token_pos]
-            predicted_labels = torch.argmax(selected_logits, dim=-1)
+                logits = model(input_batch)  # Logits of last output token
+                # Select the logits corresponding to the last real token of each sequence
+                batch_size = logits.size(0)
+                selected_logits = logits[torch.arange(batch_size), last_token_pos]
+                predicted_labels = torch.argmax(selected_logits, dim=-1)
             num_examples += predicted_labels.shape[0]
             correct_predictions += (predicted_labels == target_batch).sum().item()
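For context, the hunk above relies on advanced indexing to pick out the logits of the last real (non-pad) token in each right-padded sequence. Here is a minimal, self-contained sketch of that selection; the pad id, tensor shapes, and values are made up for illustration, and a random tensor stands in for `model(input_batch)`:

```python
import torch

pad_token_id = 0  # assumed pad id for this sketch

# Two right-padded sequences of token ids (values are illustrative)
input_batch = torch.tensor([
    [5, 7, 9, 0, 0],   # 3 real tokens
    [4, 2, 6, 8, 0],   # 4 real tokens
])

# Position of the last real (non-pad) token in each row
mask = input_batch != pad_token_id
last_token_pos = mask.sum(dim=1) - 1  # tensor([2, 3])

# Stand-in for model(input_batch): per-token logits over 2 classes
logits = torch.randn(input_batch.size(0), input_batch.size(1), 2)

# Advanced indexing selects one logits row per sequence
batch_size = logits.size(0)
selected_logits = logits[torch.arange(batch_size), last_token_pos]  # shape (2, 2)
predicted_labels = torch.argmax(selected_logits, dim=-1)             # shape (2,)
print(predicted_labels)
```

Pairing `torch.arange(batch_size)` with `last_token_pos` indexes a different sequence position per row in one vectorized step, avoiding a Python loop over the batch.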
@@ -643,8 +642,6 @@ if __name__ == "__main__":
     val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
     test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
-    tokenizer = tiktoken.get_encoding("gpt2")
-
     num_workers = 0
     train_loader = DataLoader(
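The removed `tokenizer = tiktoken.get_encoding("gpt2")` reassignment was redundant: `tokenizer` is already constructed earlier and passed into the `SpamDataset` calls above it. For reference, here is a minimal sketch of how such a dataset is typically wired into a `DataLoader`; the stand-in `TensorDataset`, the tensor shapes, and `batch_size` are assumptions for illustration, not values taken from this script:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset; the real script uses its own SpamDataset class
dummy_dataset = TensorDataset(
    torch.randint(0, 50257, (8, 16)),  # token ids (GPT-2 vocab size)
    torch.randint(0, 2, (8,)),         # binary spam/ham labels
)

num_workers = 0  # matches the script: single-process data loading
train_loader = DataLoader(
    dummy_dataset,
    batch_size=4,       # assumed value for this sketch
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,     # keep batch shapes uniform for training
)

for input_batch, target_batch in train_loader:
    print(input_batch.shape, target_batch.shape)
```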