mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-03 19:30:26 +00:00
remove unused function from user interface

This commit is contained in:
parent 52ee1c7cdb
commit 6f6dfb6796
@@ -219,32 +219,6 @@ class GPTModel(nn.Module):
         return logits
 
 
-def generate_text_simple(model, idx, max_new_tokens, context_size):
-    # idx is (B, T) array of indices in the current context
-    for _ in range(max_new_tokens):
-
-        # Crop current context if it exceeds the supported context size
-        # E.g., if LLM supports only 5 tokens, and the context size is 10
-        # then only the last 5 tokens are used as context
-        idx_cond = idx[:, -context_size:]
-
-        # Get the predictions
-        with torch.no_grad():
-            logits = model(idx_cond)
-
-        # Focus only on the last time step
-        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
-        logits = logits[:, -1, :]
-
-        # Get the idx of the vocab entry with the highest logits value
-        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)
-
-        # Append sampled index to the running sequence
-        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
-
-    return idx
-
-
 #####################################
 # Chapter 5
 #####################################
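The removed generate_text_simple helper implements plain greedy decoding: crop the running sequence to the supported context window, run the model, take the argmax of the last position's logits, and append that token. Below is a minimal, self-contained sketch of how such a function is typically driven; the DummyLM stand-in model and the example token ids are assumptions for illustration only and are not part of this commit or the repository (which uses its own GPTModel class).

import torch
import torch.nn as nn


def generate_text_simple(model, idx, max_new_tokens, context_size):
    # Same greedy-decoding logic as the helper removed in this commit:
    # crop the context, predict, take the argmax token, append, repeat.
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]         # keep only the last `context_size` tokens
        with torch.no_grad():
            logits = model(idx_cond)              # (batch, n_tokens, vocab_size)
        logits = logits[:, -1, :]                 # logits for the last position only
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)
        idx = torch.cat((idx, idx_next), dim=1)   # (batch, n_tokens+1)
    return idx


class DummyLM(nn.Module):
    # Stand-in language model (an assumption for this demo, not the repo's
    # GPTModel). It simply maps token ids to logits so the loop runs end to end.
    def __init__(self, vocab_size=50257, emb_dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.out = nn.Linear(emb_dim, vocab_size)

    def forward(self, idx):
        return self.out(self.emb(idx))            # (batch, n_tokens, vocab_size)


torch.manual_seed(123)
model = DummyLM()
model.eval()

prompt = torch.tensor([[15496, 11, 314, 716]])    # example token ids, shape (1, 4)
out = generate_text_simple(model, prompt, max_new_tokens=6, context_size=1024)
print(out.shape)  # torch.Size([1, 10]): 4 prompt tokens + 6 greedily generated tokens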