mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-11-04 03:40:21 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			309 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			309 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
 | 
						|
# Source for "Build a Large Language Model From Scratch"
 | 
						|
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
 | 
						|
# Code: https://github.com/rasbt/LLMs-from-scratch
 | 
						|
 | 
						|
import json
 | 
						|
import numpy as np
 | 
						|
import os
 | 
						|
import urllib.request
 | 
						|
 | 
						|
# import requests
 | 
						|
import tensorflow as tf
 | 
						|
import tiktoken
 | 
						|
import torch
 | 
						|
from tqdm import tqdm
 | 
						|
 | 
						|
# Import from local files
 | 
						|
from previous_chapters import GPTModel
 | 
						|
 | 
						|
 | 
						|
def text_to_token_ids(text, tokenizer):
    """Tokenize ``text`` and return the token ids as a batched tensor.

    Args:
        text: The string to encode.
        tokenizer: Any object with an ``encode(text)`` method that returns a
            sequence of integer token ids.

    Returns:
        A tensor built from the encoded ids with an extra leading (batch)
        dimension, i.e. ``(1, num_tokens)`` for a flat list of ids.
    """
    token_ids = torch.tensor(tokenizer.encode(text))
    return token_ids.unsqueeze(0)  # prepend the batch dimension
 | 
						|
 | 
						|
 | 
						|
def token_ids_to_text(token_ids, tokenizer):
    """Decode a batched tensor of token ids back into a string.

    Args:
        token_ids: Tensor of token ids whose leading dimension is a batch
            dimension of size 1 (it is squeezed away before decoding).
        tokenizer: Any object with a ``decode(list_of_ids)`` method.

    Returns:
        Whatever ``tokenizer.decode`` returns for the flattened id list.
    """
    ids = token_ids.squeeze(0).tolist()  # drop the batch dimension
    return tokenizer.decode(ids)
 | 
						|
 | 
						|
 | 
						|
def download_and_load_gpt2(model_size, models_dir):
    """Download the GPT-2 checkpoint files for ``model_size`` and load them.

    Args:
        model_size: One of ``"124M"``, ``"355M"``, ``"774M"``, ``"1558M"``.
        models_dir: Directory under which a ``<model_size>`` sub-directory is
            created (if needed) to hold the downloaded files.

    Returns:
        A ``(settings, params)`` tuple: ``settings`` is the parsed content of
        ``hparams.json`` and ``params`` is the nested dictionary produced by
        ``load_gpt2_params_from_tf_ckpt``.

    Raises:
        ValueError: If ``model_size`` is not one of the allowed sizes.
    """
    # Validate model size before touching the file system or the network
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")

    # Define paths
    model_dir = os.path.join(models_dir, model_size)
    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    filenames = [
        "checkpoint", "encoder.json", "hparams.json",
        "model.ckpt.data-00000-of-00001", "model.ckpt.index",
        "model.ckpt.meta", "vocab.bpe"
    ]

    # Download files
    os.makedirs(model_dir, exist_ok=True)
    for filename in filenames:
        # URLs always use "/" as separator; os.path.join would insert the
        # platform separator ("\\" on Windows) and produce an invalid URL.
        file_url = f"{base_url}/{model_size}/{filename}"
        file_path = os.path.join(model_dir, filename)
        download_file(file_url, file_path)

    # Load settings and params
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    # Use a context manager so the file handle is closed deterministically
    with open(os.path.join(model_dir, "hparams.json"), encoding="utf-8") as hparams_file:
        settings = json.load(hparams_file)
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)

    return settings, params
 | 
						|
 | 
						|
 | 
						|
"""
 | 
						|
def download_file(url, destination):
 | 
						|
    # Send a GET request to download the file in streaming mode
 | 
						|
    response = requests.get(url, stream=True)
 | 
						|
 | 
						|
    # Get the total file size from headers, defaulting to 0 if not present
 | 
						|
    file_size = int(response.headers.get("content-length", 0))
 | 
						|
 | 
						|
    # Check if file exists and has the same size
 | 
						|
    if os.path.exists(destination):
 | 
						|
        file_size_local = os.path.getsize(destination)
 | 
						|
        if file_size == file_size_local:
 | 
						|
            print(f"File already exists and is up-to-date: {destination}")
 | 
						|
            return
 | 
						|
 | 
						|
    # Define the block size for reading the file
 | 
						|
    block_size = 1024  # 1 Kilobyte
 | 
						|
 | 
						|
    # Initialize the progress bar with total file size
 | 
						|
    progress_bar_description = url.split("/")[-1]  # Extract filename from URL
 | 
						|
    with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
 | 
						|
        # Open the destination file in binary write mode
 | 
						|
        with open(destination, "wb") as file:
 | 
						|
            # Iterate over the file data in chunks
 | 
						|
            for chunk in response.iter_content(block_size):
 | 
						|
                progress_bar.update(len(chunk))  # Update progress bar
 | 
						|
                file.write(chunk)  # Write the chunk to the file
 | 
						|
"""
 | 
						|
 | 
						|
 | 
						|
def download_file(url, destination):
    """Download ``url`` to ``destination`` with a progress bar.

    The transfer is skipped when ``destination`` already exists and its size
    equals the ``Content-Length`` reported by the server (a missing header is
    treated as 0).

    Args:
        url: Address of the file to fetch.
        destination: Local path the content is written to (binary mode).
    """
    with urllib.request.urlopen(url) as response:
        # Size announced by the server; 0 when the header is absent
        remote_size = int(response.headers.get("Content-Length", 0))

        # Nothing to do if a local copy of identical size is already there
        if os.path.exists(destination):
            local_size = os.path.getsize(destination)
            if remote_size == local_size:
                print(f"File already exists and is up-to-date: {destination}")
                return

        chunk_bytes = 1024  # 1 Kilobyte per read
        bar_label = os.path.basename(url)  # file name part of the URL

        with tqdm(total=remote_size, unit="iB", unit_scale=True, desc=bar_label) as progress_bar:
            with open(destination, "wb") as out_file:
                # read() returns b"" at end of stream, which ends the iteration
                for chunk in iter(lambda: response.read(chunk_bytes), b""):
                    out_file.write(chunk)
                    progress_bar.update(len(chunk))
 | 
						|
 | 
						|
 | 
						|
def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
    """Read every variable of a TensorFlow checkpoint into a nested dict.

    Args:
        ckpt_path: Checkpoint path passed to ``tf.train.list_variables`` and
            ``tf.train.load_variable``.
        settings: Mapping that provides ``"n_layer"``, the number of
            per-layer dictionaries to create under ``"blocks"``.

    Returns:
        A dictionary ``{"blocks": [dict, ...], ...}``. Variables whose first
        name component (after the leading scope) starts with ``"h"`` are
        stored in ``blocks[<layer number>]``; all others are stored at the top
        level. Remaining name components become nested keys.
    """
    # One empty dictionary per transformer layer
    params = {"blocks": [{} for _ in range(settings["n_layer"])]}

    for name, _ in tf.train.list_variables(ckpt_path):
        # Load the values and drop singleton dimensions
        array = np.squeeze(tf.train.load_variable(ckpt_path, name))

        # Drop the first path component (the 'model/' prefix)
        parts = name.split("/")[1:]

        # Names like "h<N>/..." belong to layer N; everything else is global
        scope = parts[0]
        if scope.startswith("h"):
            node = params["blocks"][int(scope[1:])]
        else:
            node = params

        # Walk (creating as needed) the intermediate dictionaries
        for key in parts[1:-1]:
            node = node.setdefault(key, {})

        # The final component names the array itself
        node[parts[-1]] = array

    return params
 | 
						|
 | 
						|
 | 
						|
def assign(left, right):
    """Return ``right`` as a new ``torch.nn.Parameter`` after a shape check.

    Args:
        left: Object with a ``shape`` attribute; only its shape is used.
        right: Array-like with a ``shape`` attribute holding the new values.

    Returns:
        ``torch.nn.Parameter(torch.tensor(right))``.

    Raises:
        ValueError: If ``left.shape`` and ``right.shape`` differ.
    """
    if left.shape == right.shape:
        return torch.nn.Parameter(torch.tensor(right))
    raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
 | 
						|
 | 
						|
 | 
						|
def load_weights_into_gpt(gpt, params):
    """Copy the arrays in ``params`` into the parameters of ``gpt``.

    Every target parameter is replaced by the result of ``assign``, which
    raises ``ValueError`` when the shapes of target and source differ.

    Args:
        gpt: Model exposing ``pos_emb``, ``tok_emb``, ``trf_blocks``,
            ``final_norm`` and ``out_head`` with the attributes used below.
        params: Nested dictionary with the keys ``"wpe"``, ``"wte"``,
            ``"blocks"``, ``"g"`` and ``"b"`` as read below.
    """
    # Positional and token embeddings
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])

    for b, layer in enumerate(params["blocks"]):
        block = gpt.trf_blocks[b]
        att = block.att

        # The fused "c_attn" weight is cut into three equal parts along its
        # last axis: query, key, value (in that order).
        # NOTE(review): weight matrices are transposed before assignment,
        # presumably because the checkpoint stores them as (in, out) while the
        # model's linear layers store (out, in) - confirm against GPTModel.
        q_w, k_w, v_w = np.split(layer["attn"]["c_attn"]["w"], 3, axis=-1)
        att.W_query.weight = assign(att.W_query.weight, q_w.T)
        att.W_key.weight = assign(att.W_key.weight, k_w.T)
        att.W_value.weight = assign(att.W_value.weight, v_w.T)

        # Same three-way split for the fused bias
        q_b, k_b, v_b = np.split(layer["attn"]["c_attn"]["b"], 3, axis=-1)
        att.W_query.bias = assign(att.W_query.bias, q_b)
        att.W_key.bias = assign(att.W_key.bias, k_b)
        att.W_value.bias = assign(att.W_value.bias, v_b)

        # Attention output projection
        att.out_proj.weight = assign(
            att.out_proj.weight, layer["attn"]["c_proj"]["w"].T)
        att.out_proj.bias = assign(
            att.out_proj.bias, layer["attn"]["c_proj"]["b"])

        # Feed-forward: layers[0] <- "c_fc", layers[2] <- "c_proj"
        ff_layers = block.ff.layers
        ff_layers[0].weight = assign(
            ff_layers[0].weight, layer["mlp"]["c_fc"]["w"].T)
        ff_layers[0].bias = assign(
            ff_layers[0].bias, layer["mlp"]["c_fc"]["b"])
        ff_layers[2].weight = assign(
            ff_layers[2].weight, layer["mlp"]["c_proj"]["w"].T)
        ff_layers[2].bias = assign(
            ff_layers[2].bias, layer["mlp"]["c_proj"]["b"])

        # Layer norms: "g" -> scale, "b" -> shift
        for norm_name, ln_key in (("norm1", "ln_1"), ("norm2", "ln_2")):
            norm = getattr(block, norm_name)
            norm.scale = assign(norm.scale, layer[ln_key]["g"])
            norm.shift = assign(norm.shift, layer[ln_key]["b"])

    # Final layer norm
    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
    # The output head is loaded from the token-embedding matrix "wte"
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
 | 
						|
 | 
						|
 | 
						|
def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
    """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

    Args:
        model: Callable mapping a ``(batch_size, num_tokens)`` tensor of token
            ids to logits of shape ``(batch_size, num_tokens, vocab_size)``.
        idx: Tensor of token ids, shape ``(batch_size, num_tokens)``.
        max_new_tokens: Maximum number of tokens to append.
        context_size: Only the last ``context_size`` tokens are fed to the
            model at each step.
        temperature: If ``> 0``, logits are divided by it and the next token
            is sampled from the softmax distribution; otherwise the arg-max
            token is chosen (greedy decoding).
        top_k: If given, only the ``top_k`` largest logits per row are kept;
            all others are set to ``-inf``. Values larger than the vocabulary
            size are clamped.
        eos_id: If given, generation stops as soon as every sequence in the
            batch produces this token; the token itself is not appended.

    Returns:
        Tensor of shape ``(batch_size, num_tokens + n)`` with
        ``0 <= n <= max_new_tokens``.
    """
    for _ in range(max_new_tokens):
        # Crop the context and only keep the logits of the last time step
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]  # (batch_size, vocab_size)

        # Filter logits with top_k sampling
        if top_k is not None:
            k = min(top_k, logits.size(-1))
            top_logits, _ = torch.topk(logits, k)
            # Keep the trailing dimension so the per-row threshold broadcasts
            # against (batch_size, vocab_size) for any batch size
            min_val = top_logits[:, -1:]  # (batch_size, 1)
            logits = logits.masked_fill(logits < min_val, float("-inf"))

        if temperature > 0.0:
            # Temperature scaling followed by sampling
            probs = torch.softmax(logits / temperature, dim=-1)  # (batch_size, vocab_size)
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
        else:
            # Greedy: index of the vocab entry with the highest logit
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Stop early on end-of-sequence; .all() keeps the condition a single
        # boolean when batch_size > 1
        if eos_id is not None and bool((idx_next == eos_id).all()):
            break

        # Append the new token to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
 | 
						|
 | 
						|
 | 
						|
def main(gpt_config, input_prompt, model_size):
    """Load pretrained GPT-2 weights and print a sampled continuation.

    Args:
        gpt_config: Configuration dictionary passed to ``GPTModel``; its
            ``"context_length"`` entry is also used as the generation
            context size.
        input_prompt: Text the generation starts from.
        model_size: Checkpoint size string forwarded to
            ``download_and_load_gpt2`` (e.g. ``"124M"``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only the weights are needed here; the hparams settings are not used
    _, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

    gpt = GPTModel(gpt_config)
    load_weights_into_gpt(gpt, params)
    gpt.to(device)
    gpt.eval()

    tokenizer = tiktoken.get_encoding("gpt2")
    torch.manual_seed(123)

    token_ids = generate(
        model=gpt,
        # The prompt ids must live on the same device as the model,
        # otherwise the forward pass fails when running on CUDA
        idx=text_to_token_ids(input_prompt, tokenizer).to(device),
        max_new_tokens=25,
        context_size=gpt_config["context_length"],
        top_k=50,
        temperature=1.0
    )

    print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":

    torch.manual_seed(123)

    # What to run: prompt text and which pretrained checkpoint to use
    INPUT_PROMPT = "Every effort moves you"
    CHOOSE_MODEL = "gpt2-small (124M)"

    # Settings shared by all model sizes
    BASE_CONFIG = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "drop_rate": 0.0,        # Dropout rate
        "qkv_bias": True         # Query-key-value bias
    }

    # Size-specific settings, keyed by the CHOOSE_MODEL label
    model_configs = {
        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
    }
    BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

    # Take the last space-separated token of the label and strip the
    # parentheses: "gpt2-small (124M)" -> "124M"
    model_size = (
        CHOOSE_MODEL.split(" ")[-1]
        .lstrip("(")
        .rstrip(")")
    )

    main(BASE_CONFIG, INPUT_PROMPT, model_size)
 |