# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import json
import os
import urllib.request  # The "request" submodule must be imported explicitly

import psutil
import torch
from torch.utils.data import Dataset
from tqdm import tqdm


def download_and_load_file(file_path, url):
    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)

    # The book originally contained this unnecessary "else" clause:
    # else:
    #     with open(file_path, "r", encoding="utf-8") as file:
    #         text_data = file.read()

    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data

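# Example usage (a minimal sketch; the file name and URL below are
# hypothetical placeholders, not part of this module):
#
#     data = download_and_load_file(
#         "instruction-data.json",
#         "https://example.com/instruction-data.json",
#     )
#     print("Number of entries:", len(data))

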
def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text

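# Example usage (a minimal sketch with a made-up entry):
#
#     entry = {
#         "instruction": "Convert the sentence to passive voice.",
#         "input": "The chef cooks the meal.",
#         "output": "The meal is cooked by the chef.",
#     }
#     print(format_input(entry))
#     # Prints the Alpaca-style prompt; the "### Input:" section is
#     # omitted whenever entry["input"] is empty

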
class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        # Pre-tokenize texts
        self.encoded_texts = []
        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            full_text = instruction_plus_input + response_text
            self.encoded_texts.append(
                tokenizer.encode(full_text)
            )

    def __getitem__(self, index):
        return self.encoded_texts[index]

    def __len__(self):
        return len(self.data)

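# Example usage (a sketch; assumes the GPT-2 tokenizer from tiktoken and a
# `data` list as returned by download_and_load_file above):
#
#     import tiktoken
#     tokenizer = tiktoken.get_encoding("gpt2")
#     dataset = InstructionDataset(data, tokenizer)
#     print(len(dataset))      # Number of entries
#     print(dataset[0][:5])    # First five token IDs of the first entry

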
def custom_collate_draft_1(
    batch,
    pad_token_id=50256,
    device="cpu"
):
    # Find the longest sequence in the batch
    # and increase the max length by +1, which will add one extra
    # padding token below
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs
    inputs_lst = []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        # Via padded[:-1], we remove the extra padding token
        # that has been added via the +1 setting in batch_max_length
        # (the extra padding token becomes relevant in the later collate versions)
        inputs = torch.tensor(padded[:-1])
        inputs_lst.append(inputs)

    # Convert list of inputs to tensor and transfer to target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    return inputs_tensor

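# Example usage (a sketch with a toy batch of token-ID lists):
#
#     batch = [[0, 1, 2, 3, 4], [5, 6], [7, 8, 9]]
#     print(custom_collate_draft_1(batch))
#     # Every row is right-padded with token ID 50256 to the length of
#     # the longest sequence in the batch

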
def custom_collate_draft_2(
    batch,
    pad_token_id=50256,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets
        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert lists of inputs and targets to tensors and transfer to target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)
    return inputs_tensor, targets_tensor

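# Example usage (a sketch; note that the targets are the inputs shifted
# by one position):
#
#     inputs, targets = custom_collate_draft_2([[0, 1, 2, 3, 4], [5, 6]])
#     print(inputs[1])   # tensor([    5,     6, 50256, 50256, 50256])
#     print(targets[1])  # tensor([    6, 50256, 50256, 50256, 50256])

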
def custom_collate_fn(
    batch,
    pad_token_id=50256,
    ignore_index=-100,
    allowed_max_length=None,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets

        # New: Replace all but the first padding token in targets by ignore_index
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        # New: Optionally truncate to maximum sequence length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]

        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert lists of inputs and targets to tensors and transfer to target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)

    return inputs_tensor, targets_tensor

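# Example usage (a sketch; fixing the keyword arguments via functools.partial
# yields a collate function that can be passed to a DataLoader; the dataset
# name and batch size here are illustrative):
#
#     from functools import partial
#     from torch.utils.data import DataLoader
#
#     customized_collate_fn = partial(
#         custom_collate_fn, device="cpu", allowed_max_length=1024
#     )
#     train_loader = DataLoader(
#         train_dataset, batch_size=8, collate_fn=customized_collate_fn,
#         shuffle=True, drop_last=True
#     )

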
def check_if_running(process_name):
    running = False
    for proc in psutil.process_iter(["name"]):
        if process_name in proc.info["name"]:
            running = True
            break
    return running

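# Example usage (a sketch; checks whether the Ollama application is running
# before attempting to query it):
#
#     if not check_if_running("ollama"):
#         raise RuntimeError("Ollama not running. Launch ollama before proceeding.")

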
def query_model(
    prompt,
    model="llama3",
    url="http://localhost:11434/api/chat"
):
    # Create the data payload as a dictionary
    data = {
        "model": model,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "options": {     # Settings below are required for deterministic responses
            "seed": 123,
            "temperature": 0,
            "num_ctx": 2048
        }
    }

    # Convert the dictionary to a JSON formatted string and encode it to bytes
    payload = json.dumps(data).encode("utf-8")

    # Create a request object, setting the method to POST and adding necessary headers
    request = urllib.request.Request(
        url,
        data=payload,
        method="POST"
    )
    request.add_header("Content-Type", "application/json")

    # Send the request and capture the response
    response_data = ""
    with urllib.request.urlopen(request) as response:
        # Read and decode the streamed response one JSON line at a time
        while True:
            line = response.readline().decode("utf-8")
            if not line:
                break
            response_json = json.loads(line)
            response_data += response_json["message"]["content"]

    return response_data

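# Example usage (a sketch; assumes a local Ollama server listening on the
# default port 11434 with the llama3 model available):
#
#     result = query_model("What do llamas eat?")
#     print(result)

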
def generate_model_scores(json_data, json_key, model="llama3"):
    scores = []
    for entry in tqdm(json_data, desc="Scoring entries"):
        prompt = (
            f"Given the input `{format_input(entry)}` "
            f"and correct output `{entry['output']}`, "
            f"score the model response `{entry[json_key]}`"
            f" on a scale from 0 to 100, where 100 is the best score. "
            f"Respond with the integer number only."
        )
        score = query_model(prompt, model)
        try:
            scores.append(int(score))
        except ValueError:
            print(f"Could not convert score: {score}")
            continue

    return scores

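# Example usage (a sketch; `test_data` is a hypothetical list of entries,
# each containing an "output" field and a model-response field such as
# "model_response"):
#
#     scores = generate_model_scores(test_data, "model_response")
#     print(f"Average score: {sum(scores) / len(scores):.2f}")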