# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-5.
# This file can be run as a standalone script.

import torch


#####################################
# Chapter 5
#####################################
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())
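

# Example round trip (illustration only; assumes the tiktoken BPE tokenizer
# used throughout the book, e.g. tokenizer = tiktoken.get_encoding("gpt2")):
#
#   ids = text_to_token_ids("Hello, world", tokenizer)   # shape (1, num_tokens)
#   text = token_ids_to_text(ids, tokenizer)             # "Hello, world"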


def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
    # idx is a (batch_size, num_tokens) tensor of token IDs in the current context

    # For-loop is the same as before: Get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only the top_k largest logits; mask the rest with -inf so
            # that softmax assigns them zero probability
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)
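        # Worked example (illustration only): with logits [[4., 1., 3., 2.]]
        # and top_k=2, top_logits is [[4., 3.]] and min_val is 3., so the
        # filtered logits become [[4., -inf, 3., -inf]]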

        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature
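            # Worked example (illustration only): temperature > 1 flattens the
            # resulting distribution and temperature < 1 sharpens it; e.g.,
            # softmax([2., 0.] / 0.5) = softmax([4., 0.]) moves the probability
            # of the first of the two tokens from ~0.88 to ~0.98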

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Stop generating early if the end-of-sequence token is encountered
        # and an eos_id is specified (note: this check assumes batch_size=1)
        if eos_id is not None and idx_next == eos_id:
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
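

if __name__ == "__main__":
    # Minimal smoke test (illustration only). The book drives generate() with
    # the trained GPTModel from Chapter 4, which is not part of this excerpt;
    # the stand-in model below is a hypothetical placeholder that simply
    # returns random logits of the expected shape
    # (batch_size, num_tokens, vocab_size).
    import torch.nn as nn

    class RandomLogitsModel(nn.Module):
        def __init__(self, vocab_size=50257):
            super().__init__()
            self.vocab_size = vocab_size

        def forward(self, idx):
            batch_size, num_tokens = idx.shape
            return torch.randn(batch_size, num_tokens, self.vocab_size)

    torch.manual_seed(123)
    model = RandomLogitsModel()
    model.eval()

    start_context = torch.tensor([[1, 2, 3]])  # dummy token IDs, shape (1, 3)
    token_ids = generate(
        model, start_context, max_new_tokens=5,
        context_size=10, temperature=1.0, top_k=5
    )
    print(token_ids.shape)  # torch.Size([1, 8])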