# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-5.
# This file can be run as a standalone script.

import torch


#####################################
# Chapter 5
#####################################
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


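# A minimal round-trip sketch (added for illustration; not part of the original
# file), assuming the tiktoken GPT-2 tokenizer used throughout the book:
#
#   import tiktoken
#   tokenizer = tiktoken.get_encoding("gpt2")
#   token_ids = text_to_token_ids("Every effort moves you", tokenizer)  # shape: (1, num_tokens)
#   print(token_ids_to_text(token_ids, tokenizer))  # "Every effort moves you"

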
def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # For-loop is the same as before: Get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]  # (batch_size, vocab_size)

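        # E.g., with context_size=5, an 8-token running sequence is cropped to
        # its last 5 tokens before the forward pass; only the final position's
        # logits are needed to predict the next token.
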
        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values; the -1: slice keeps the last dim so that
            # min_val broadcasts across the vocabulary dimension
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1:]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

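        # E.g. (hypothetical numbers): with logits [4.0, 1.0, 3.0] and top_k=2,
        # min_val is 3.0, so the filtered logits become [4.0, -inf, 3.0] and
        # the second token can no longer be sampled.
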
        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

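        # E.g., temperatures below 1.0 sharpen the distribution toward the
        # most likely tokens, while temperatures above 1.0 flatten it toward
        # uniform sampling; temperature=0.0 falls through to argmax below.
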
        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Stop generating early if the end-of-sequence token is encountered
        # (the check assumes a batch size of 1, as used throughout the book)
        if eos_id is not None and idx_next == eos_id:
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
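

if __name__ == "__main__":
    # Quick smoke test -- an illustrative sketch added here, not from the book.
    # A dummy module that returns random logits stands in for the trained
    # GPTModel from Chapter 4, so the sampling loop can be exercised on its own.
    torch.manual_seed(123)

    class RandomLogitsModel(torch.nn.Module):
        # Hypothetical stand-in: maps (batch_size, num_tokens) token ids to
        # random logits of shape (batch_size, num_tokens, vocab_size)
        def __init__(self, vocab_size=50257):
            super().__init__()
            self.vocab_size = vocab_size

        def forward(self, idx):
            batch_size, num_tokens = idx.shape
            return torch.randn(batch_size, num_tokens, self.vocab_size)

    start_ids = torch.tensor([[1, 2, 3]])  # arbitrary token ids, batch size 1
    out = generate(RandomLogitsModel(), start_ids, max_new_tokens=5,
                   context_size=10, temperature=1.0, top_k=5)
    print(out.shape)  # expected: torch.Size([1, 8]) -- 3 prompt + 5 new tokens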