# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
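
# Note: a Chainlit app like this one is typically launched from the terminal with
# `chainlit run app.py` (assuming this script is saved as app.py); running it with
# plain `python` will not start the chat interface.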

from pathlib import Path
import sys

import tiktoken
import torch
import chainlit

from previous_chapters import (
    generate,
    GPTModel,
    text_to_token_ids,
    token_ids_to_text,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_model_and_tokenizer():
    """
    Load the GPT-2 model with the finetuned weights generated in chapter 7.
    This requires that you first run the chapter 7 code, which creates the necessary gpt2-medium355M-sft.pth file.
    """

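    # The settings below correspond to the GPT-2 medium (355M-parameter) configuration
    # that is instruction-finetuned in chapter 7.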
    GPT_CONFIG_355M = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "emb_dim": 1024,         # Embedding dimension
        "n_heads": 16,           # Number of attention heads
        "n_layers": 24,          # Number of layers
        "drop_rate": 0.0,        # Dropout rate
        "qkv_bias": True         # Query-key-value bias
    }

    tokenizer = tiktoken.get_encoding("gpt2")

    model_path = Path("..") / "01_main-chapter-code" / "gpt2-medium355M-sft.pth"
    if not model_path.exists():
        print(
            f"Could not find the {model_path} file. Please run the chapter 7 code "
            "(ch07.ipynb) to generate the gpt2-medium355M-sft.pth file."
        )
        sys.exit()

    checkpoint = torch.load(model_path, weights_only=True)
    model = GPTModel(GPT_CONFIG_355M)
    model.load_state_dict(checkpoint)
    model.to(device)

    return tokenizer, model, GPT_CONFIG_355M


def extract_response(response_text, input_text):
    """Strip the echoed prompt and the "### Response:" marker from the generated text."""
    return response_text[len(input_text):].replace("### Response:", "").strip()


# Obtain the tokenizer and model needed for the chainlit function below
tokenizer, model, model_config = get_model_and_tokenizer()


@chainlit.on_message
async def main(message: chainlit.Message):
    """
    The main Chainlit function.
    """

    torch.manual_seed(123)

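    # Wrap the user message in the same Alpaca-style instruction prompt format that
    # was used for finetuning in chapter 7, so the model sees inputs in the format
    # it was trained on.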
    prompt = f"""Below is an instruction that describes a task. Write a response
    that appropriately completes the request.

    ### Instruction:
    {message.content}
    """

    token_ids = generate(  # the generate function uses `with torch.no_grad()` internally already
        model=model,
        idx=text_to_token_ids(prompt, tokenizer).to(device),  # The user text is provided via `message.content`
        max_new_tokens=35,
        context_size=model_config["context_length"],
        eos_id=50256  # <|endoftext|> token ID of the GPT-2 tokenizer; generation stops early once it is produced
    )

    text = token_ids_to_text(token_ids, tokenizer)
    response = extract_response(text, prompt)

    await chainlit.Message(
        content=f"{response}",  # This returns the model response to the interface
    ).send()