From 6f6dfb679685ccf4e13b015fdbcd4d9d4cb6a27b Mon Sep 17 00:00:00 2001
From: rasbt
Date: Sat, 21 Sep 2024 14:17:35 -0500
Subject: [PATCH] remove unused function from user interface

---
 ch05/06_user_interface/previous_chapters.py | 26 ---------------------
 1 file changed, 26 deletions(-)

diff --git a/ch05/06_user_interface/previous_chapters.py b/ch05/06_user_interface/previous_chapters.py
index 892838f..dc026ed 100644
--- a/ch05/06_user_interface/previous_chapters.py
+++ b/ch05/06_user_interface/previous_chapters.py
@@ -219,32 +219,6 @@ class GPTModel(nn.Module):
         return logits
 
 
-def generate_text_simple(model, idx, max_new_tokens, context_size):
-    # idx is (B, T) array of indices in the current context
-    for _ in range(max_new_tokens):
-
-        # Crop current context if it exceeds the supported context size
-        # E.g., if LLM supports only 5 tokens, and the context size is 10
-        # then only the last 5 tokens are used as context
-        idx_cond = idx[:, -context_size:]
-
-        # Get the predictions
-        with torch.no_grad():
-            logits = model(idx_cond)
-
-        # Focus only on the last time step
-        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
-        logits = logits[:, -1, :]
-
-        # Get the idx of the vocab entry with the highest logits value
-        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)
-
-        # Append sampled index to the running sequence
-        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
-
-    return idx
-
-
 #####################################
 # Chapter 5
 #####################################
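
Note: the removed generate_text_simple helper implements greedy decoding: each iteration crops the context to the model's supported window, runs a forward pass, and appends the argmax token. For reference, a minimal usage sketch follows. It assumes a GPT-2 tokenizer obtained via tiktoken and a config dict named GPT_CONFIG_124M, as used in the surrounding chapters; treat those names as illustrative, not part of this patch.

    import tiktoken
    import torch

    tokenizer = tiktoken.get_encoding("gpt2")

    # Encode a prompt and add a batch dimension: shape (1, n_tokens)
    encoded = tokenizer.encode("Hello, I am")
    idx = torch.tensor(encoded).unsqueeze(0)

    model.eval()  # disable dropout; greedy decoding is deterministic in eval mode

    out = generate_text_simple(
        model=model,
        idx=idx,
        max_new_tokens=6,
        context_size=GPT_CONFIG_124M["context_length"],  # assumed config dict
    )

    # Drop the batch dimension and decode the token IDs back to text
    print(tokenizer.decode(out.squeeze(0).tolist()))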