mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 02:50:15 +00:00
minor readability improvements (#668)
This commit is contained in:
parent
15fa6a84f6
commit
00b8c0a107
@ -271,7 +271,6 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None,
|
||||
mask = input_batch != pad_token_id
|
||||
last_token_pos = mask.sum(dim=1) - 1 # Get position of last real token
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_batch) # Logits of last output token
|
||||
# Select the logits corresponding to the last real token of each sequence
|
||||
batch_size = logits.size(0)
|
||||
@ -643,8 +642,6 @@ if __name__ == "__main__":
|
||||
val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
||||
test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
num_workers = 0
|
||||
|
||||
train_loader = DataLoader(
|
||||
|
Loading…
x
Reference in New Issue
Block a user