# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-5.
# This file can be run as a standalone script.

import torch


#####################################
# Chapter 5
#####################################
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())
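
# Example round trip (illustrative values, assuming the GPT-2 tokenizer from tiktoken):
#   text_to_token_ids("Hello", tokenizer)                 -> tensor([[15496]])
#   token_ids_to_text(torch.tensor([[15496]]), tokenizer) -> "Hello"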


def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # For-loop is the same as before: Get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1:]  # keep the last dim so the comparison below broadcasts for any batch size
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
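            # For example, with top_k=3 and a logits row of [4.0, 2.5, 3.1, 0.9, 1.7]
            # (made-up values for illustration), min_val is 2.5 and the row becomes
            # [4.0, 2.5, 3.1, -inf, -inf], so only the 3 highest-scoring tokens
            # remain candidates for sampling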

        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature
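            # Temperatures below 1.0 sharpen the distribution and temperatures
            # above 1.0 flatten it: e.g., logits [2.0, 1.0] become [4.0, 2.0] at
            # temperature=0.5 but [1.0, 0.5] at temperature=2.0 (illustrative
            # values, not from the book)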

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified (assumes a batch size of 1)
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
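

# The demo below is not part of the original file: a minimal, hypothetical
# smoke test for `generate` that uses a stand-in model emitting random logits,
# so it runs without the GPTModel from Chapter 4. It assumes the `tiktoken`
# package is installed for the GPT-2 tokenizer.
if __name__ == "__main__":
    import tiktoken

    class DummyModel(torch.nn.Module):
        # Stand-in for GPTModel: returns random logits of shape
        # (batch_size, num_tokens, vocab_size)
        def __init__(self, vocab_size=50257):
            super().__init__()
            self.vocab_size = vocab_size

        def forward(self, idx):
            batch_size, num_tokens = idx.shape
            return torch.randn(batch_size, num_tokens, self.vocab_size)

    torch.manual_seed(123)
    tokenizer = tiktoken.get_encoding("gpt2")
    model = DummyModel()
    model.eval()

    token_ids = generate(
        model=model,
        idx=text_to_token_ids("Every effort moves you", tokenizer),
        max_new_tokens=10,
        context_size=256,
        top_k=25,
        temperature=1.4,
    )
    print("Output text:\n", token_ids_to_text(token_ids, tokenizer))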
