mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-18 05:31:40 +00:00
parent dc1b1a05b0
commit 0467c8289b
.gitignore (vendored): 2 additions
@@ -34,6 +34,8 @@ ch05/01_main-chapter-code/model.pth
 ch05/01_main-chapter-code/model_and_optimizer.pth
 ch05/03_bonus_pretraining_on_gutenberg/model_checkpoints
 ch05/06_user_interface/gpt2
+ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b
+ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b-chat
 ch06/01_main-chapter-code/gpt2
 ch06/02_bonus_additional-experiments/gpt2
README.md: 1 addition
@@ -116,6 +116,7 @@ Several folders contain optional materials as a bonus for interested readers:
 - [Adding Bells and Whistles to the Training Loop](ch05/04_learning_rate_schedulers)
 - [Optimizing Hyperparameters for Pretraining](ch05/05_bonus_hparam_tuning)
 - [Building a User Interface to Interact With the Pretrained LLM](ch05/06_user_interface)
+- [Converting GPT to Llama](ch05/07_gpt_to_llama)
 - **Chapter 6:**
 - [Additional experiments finetuning different layers and using larger models](ch06/02_bonus_additional-experiments)
 - [Finetuning different models on 50k IMDB movie review dataset](ch06/03_bonus_imdb-classification)
ch05/07_gpt_to_llama/README.md (new file): 7 additions
@@ -0,0 +1,7 @@
# Converting GPT to Llama

This folder contains code for converting the GPT implementation from chapters 4 and 5 to Meta AI's Llama architecture:

- [converting-gpt-to-llama2.ipynb](converting-gpt-to-llama2.ipynb): converts GPT to Llama 2 7B step by step and loads pretrained weights from Meta AI
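The two extra dependencies pinned in requirements-extra.txt below hint at how the notebook obtains the pretrained weights: huggingface_hub for downloading the gated Meta AI checkpoint and sentencepiece for Llama 2's tokenizer. As a rough, hypothetical sketch (not taken from the notebook itself), the download step might look like the following; the repo ID matches the cache folders added to .gitignore above, while the file names are assumptions:

```python
# Hypothetical sketch: fetch the Llama 2 7B checkpoint and tokenizer from the
# Hugging Face Hub. Access to meta-llama/Llama-2-7b is gated, so an approved
# account and access token are required.
from huggingface_hub import hf_hub_download, login
import sentencepiece as spm

login()  # prompts for a Hugging Face access token (the Llama 2 repos are gated)

weights_file = hf_hub_download(
    repo_id="meta-llama/Llama-2-7b",
    filename="consolidated.00.pth",  # assumed name of the original Meta AI checkpoint file
)
tokenizer_file = hf_hub_download(
    repo_id="meta-llama/Llama-2-7b",
    filename="tokenizer.model",      # Llama 2 ships a SentencePiece tokenizer model
)

# Unlike GPT-2's BPE tokenizer, Llama 2 uses SentencePiece
sp = spm.SentencePieceProcessor(model_file=tokenizer_file)
print(sp.encode("Every effort moves you"))  # token IDs under the Llama 2 vocabulary
```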
ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb (new file): 1568 additions
(File diff suppressed because it is too large.)
ch05/07_gpt_to_llama/previous_chapters.py (new file): 63 additions
@@ -0,0 +1,63 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-4.
# This file can be run as a standalone script.

import torch


#####################################
# Chapter 5
#####################################
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # For-loop is the same as before: Get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, context_len)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
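For orientation, here is a hypothetical usage sketch of the three helpers above; it is not part of this commit. It assumes tiktoken's GPT-2 tokenizer (which the book's GPT code uses) and substitutes a tiny random-logits model for a trained GPTModel so the snippet runs on its own:

```python
import tiktoken
import torch

from previous_chapters import generate, text_to_token_ids, token_ids_to_text


class DummyModel(torch.nn.Module):
    """Stand-in for a trained GPTModel: returns random logits of shape
    (batch_size, num_tokens, vocab_size)."""
    def __init__(self, vocab_size=50257):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, idx):
        batch_size, num_tokens = idx.shape
        return torch.randn(batch_size, num_tokens, self.vocab_size)


tokenizer = tiktoken.get_encoding("gpt2")
torch.manual_seed(123)

token_ids = generate(
    model=DummyModel(),
    idx=text_to_token_ids("Every effort moves you", tokenizer),
    max_new_tokens=10,
    context_size=1024,   # maximum context length the model accepts
    top_k=25,            # keep only the 25 highest-scoring logits before sampling
    temperature=1.4,     # > 0.0 switches from greedy argmax to multinomial sampling
)
print(token_ids_to_text(token_ids, tokenizer))
```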
ch05/07_gpt_to_llama/requirements-extra.txt (new file): 2 additions
@@ -0,0 +1,2 @@
huggingface_hub>=0.24.7
sentencepiece>=0.1.99
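These are additional dependencies on top of the book's main requirements; they can be installed with, for example, `pip install -r requirements-extra.txt` from inside the folder.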
ch05/README.md: 1 addition
@@ -11,3 +11,4 @@
 - [04_learning_rate_schedulers](04_learning_rate_schedulers) contains code implementing a more sophisticated training function including learning rate schedulers and gradient clipping
 - [05_bonus_hparam_tuning](05_bonus_hparam_tuning) contains an optional hyperparameter tuning script
 - [06_user_interface](06_user_interface) implements an interactive user interface to interact with the pretrained LLM
+- [07_gpt_to_llama](07_gpt_to_llama) contains a step-by-step guide for converting a GPT architecture implementation to Llama and loading pretrained weights from Meta AI