mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 02:50:15 +00:00
minor readability improvements (#668)
This commit is contained in:
parent
15fa6a84f6
commit
00b8c0a107
@ -271,12 +271,11 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None,
|
|||||||
mask = input_batch != pad_token_id
|
mask = input_batch != pad_token_id
|
||||||
last_token_pos = mask.sum(dim=1) - 1 # Get position of last real token
|
last_token_pos = mask.sum(dim=1) - 1 # Get position of last real token
|
||||||
|
|
||||||
with torch.no_grad():
|
logits = model(input_batch) # Logits of last output token
|
||||||
logits = model(input_batch) # Logits of last output token
|
# Select the logits corresponding to the last real token of each sequence
|
||||||
# Select the logits corresponding to the last real token of each sequence
|
batch_size = logits.size(0)
|
||||||
batch_size = logits.size(0)
|
selected_logits = logits[torch.arange(batch_size), last_token_pos]
|
||||||
selected_logits = logits[torch.arange(batch_size), last_token_pos]
|
predicted_labels = torch.argmax(selected_logits, dim=-1)
|
||||||
predicted_labels = torch.argmax(selected_logits, dim=-1)
|
|
||||||
|
|
||||||
num_examples += predicted_labels.shape[0]
|
num_examples += predicted_labels.shape[0]
|
||||||
correct_predictions += (predicted_labels == target_batch).sum().item()
|
correct_predictions += (predicted_labels == target_batch).sum().item()
|
||||||
@ -643,8 +642,6 @@ if __name__ == "__main__":
|
|||||||
val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
||||||
test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
||||||
|
|
||||||
tokenizer = tiktoken.get_encoding("gpt2")
|
|
||||||
|
|
||||||
num_workers = 0
|
num_workers = 0
|
||||||
|
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user