# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import json
import os
import urllib.request

import psutil
import torch
from torch.utils.data import Dataset
from tqdm import tqdm


def download_and_load_file(file_path, url):
    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    # The book originally contained this unnecessary "else" clause:
    # else:
    #     with open(file_path, "r", encoding="utf-8") as file:
    #         text_data = file.read()

    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data


def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text


class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        # Pre-tokenize texts
        self.encoded_texts = []
        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            full_text = instruction_plus_input + response_text
            self.encoded_texts.append(
                tokenizer.encode(full_text)
            )

    def __getitem__(self, index):
        return self.encoded_texts[index]

    def __len__(self):
        return len(self.data)


def custom_collate_draft_1(
    batch,
    pad_token_id=50256,
    device="cpu"
):
    # Find the longest sequence in the batch
    # and increase the max length by +1, which will add one extra
    # padding token below
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs
    inputs_lst = []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        # Via padded[:-1], we remove the extra padded token
        # that has been added via the +1 setting in batch_max_length
        # (the extra padding token will be relevant in later code)
        inputs = torch.tensor(padded[:-1])
        inputs_lst.append(inputs)

    # Convert the list of inputs into a tensor and transfer it to the target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    return inputs_tensor


def custom_collate_draft_2(
    batch,
    pad_token_id=50256,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets
        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert the lists of inputs and targets into tensors and transfer them
    # to the target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)
    return inputs_tensor, targets_tensor
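

# Illustrative example (not part of the book's original module): for a toy
# batch of token-ID lists, the two draft collate functions above pad every
# sequence to the longest one in the batch, using 50256 (<|endoftext|>) as
# both the end-of-text and the padding token:
#
#     batch = [[0, 1, 2, 3, 4], [5, 6], [7, 8, 9]]
#     custom_collate_draft_1(batch)
#     # tensor([[    0,     1,     2,     3,     4],
#     #         [    5,     6, 50256, 50256, 50256],
#     #         [    7,     8,     9, 50256, 50256]])
#
#     inputs, targets = custom_collate_draft_2(batch)
#     # `targets` equals `inputs` shifted one position to the left, with one
#     # more padding token appended at the end of each row.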


def custom_collate_fn(
    batch,
    pad_token_id=50256,
    ignore_index=-100,
    allowed_max_length=None,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets

        # New: Replace all but the first padding token in targets by ignore_index
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        # New: Optionally truncate to the maximum sequence length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]

        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert the lists of inputs and targets into tensors and transfer them
    # to the target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)

    return inputs_tensor, targets_tensor


def check_if_running(process_name):
    running = False
    for proc in psutil.process_iter(["name"]):
        if process_name in proc.info["name"]:
            running = True
            break
    return running


def query_model(
    prompt,
    model="llama3",
    url="http://localhost:11434/api/chat"
):
    # Create the data payload as a dictionary
    data = {
        "model": model,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "options": {
            # The settings below are required for deterministic responses
            "seed": 123,
            "temperature": 0,
            "num_ctx": 2048
        }
    }

    # Convert the dictionary to a JSON-formatted string and encode it to bytes
    payload = json.dumps(data).encode("utf-8")

    # Create a request object, setting the method to POST and adding the
    # necessary headers
    request = urllib.request.Request(
        url,
        data=payload,
        method="POST"
    )
    request.add_header("Content-Type", "application/json")

    # Send the request and read the streamed response line by line
    response_data = ""
    with urllib.request.urlopen(request) as response:
        while True:
            line = response.readline().decode("utf-8")
            if not line:
                break
            response_json = json.loads(line)
            response_data += response_json["message"]["content"]

    return response_data


def generate_model_scores(json_data, json_key, model="llama3"):
    scores = []
    for entry in tqdm(json_data, desc="Scoring entries"):
        prompt = (
            f"Given the input `{format_input(entry)}` "
            f"and correct output `{entry['output']}`, "
            f"score the model response `{entry[json_key]}`"
            f" on a scale from 0 to 100, where 100 is the best score. "
            f"Respond with the integer number only."
        )
        score = query_model(prompt, model)
        try:
            scores.append(int(score))
        except ValueError:
            print(f"Could not convert score: {score}")
            continue
    return scores
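

# ---------------------------------------------------------------------------
# Usage sketch (not part of the book's original module): a minimal example of
# wiring the pieces above into a DataLoader. It assumes the `tiktoken` package
# is installed; the two instruction entries below are made up purely for
# illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from functools import partial

    import tiktoken
    from torch.utils.data import DataLoader

    tokenizer = tiktoken.get_encoding("gpt2")

    example_data = [
        {"instruction": "Convert the input to uppercase.", "input": "hello", "output": "HELLO"},
        {"instruction": "Name the capital of France.", "input": "", "output": "Paris"},
    ]

    dataset = InstructionDataset(example_data, tokenizer)
    # Fix the keyword arguments so the DataLoader only needs to pass the batch
    customized_collate_fn = partial(custom_collate_fn, device="cpu", allowed_max_length=1024)
    loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=customized_collate_fn)

    for inputs, targets in loader:
        print("inputs shape: ", inputs.shape)   # (batch_size, longest_sequence_in_batch)
        print("targets shape:", targets.shape)  # same shape; shifted by one, padding masked

    # The evaluation helpers expect a locally running Ollama server, e.g.:
    # if check_if_running("ollama"):
    #     print(query_model("What do llamas eat?"))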