# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-5.
# This file can be run as a standalone script.

import torch

#####################################
# Chapter 5
#####################################

def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())
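
# A quick round-trip sketch (an assumption, not part of this file: the tiktoken
# package and its "gpt2" encoding, which the book uses for the GPT-2 vocabulary):
#
#   import tiktoken
#   tokenizer = tiktoken.get_encoding("gpt2")
#   ids = text_to_token_ids("Every effort moves you", tokenizer)  # shape: (1, num_tokens)
#   token_ids_to_text(ids, tokenizer)                             # 'Every effort moves you'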


def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # For-loop is the same as before: get logits, and only focus on the last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]
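
        # Shape note (assuming the GPT model from earlier chapters): model(idx_cond)
        # returns (batch_size, num_tokens, vocab_size); keeping only the last time
        # step leaves (batch_size, vocab_size), the logits for the next token.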

        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1:]  # keep the dim so min_val broadcasts across the vocab axis
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
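            # Illustration (hypothetical numbers): with logits [2.0, 1.0, 0.5, 0.1]
            # and top_k=2, min_val is 1.0, so the logits become [2.0, 1.0, -inf, -inf]
            # and sampling is restricted to the two highest-scoring tokens.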

        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature
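            # Illustration (hypothetical numbers): a temperature of 0.5 doubles every
            # logit, sharpening the distribution toward the most likely tokens, while
            # a temperature above 1 flattens it and makes sampling more random.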

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Stop generating early if the end-of-sequence token is encountered
        # (guard against eos_id=None; for batches, stop once every sequence emits it)
        if eos_id is not None and (idx_next == eos_id).all():
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
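

# A minimal usage sketch (not from the original file): DummyGPT is a hypothetical
# stand-in for the GPTModel built in Chapters 2-4, returning random logits so the
# script runs standalone; the tiktoken "gpt2" encoding is assumed to be installed.
if __name__ == "__main__":
    import tiktoken

    class DummyGPT(torch.nn.Module):
        """Mimics the (batch_size, num_tokens, vocab_size) output of a trained GPT."""
        def __init__(self, vocab_size=50257):
            super().__init__()
            self.vocab_size = vocab_size

        def forward(self, idx):
            return torch.rand(idx.shape[0], idx.shape[1], self.vocab_size)

    torch.manual_seed(123)
    tokenizer = tiktoken.get_encoding("gpt2")
    model = DummyGPT()

    token_ids = generate(
        model=model,
        idx=text_to_token_ids("Every effort moves you", tokenizer),
        max_new_tokens=10,
        context_size=256,
        top_k=25,
        temperature=1.4
    )
    print("Output text:\n", token_ids_to_text(token_ids, tokenizer))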