mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-06-26 23:50:03 +00:00
Llama 3 KV Cache (#685)
* Llama 3 KV Cache * skip expensive tests on Gh actions * Update __init__.py
This commit is contained in:
parent
c008f95072
commit
3be0f3202a
@ -216,3 +216,41 @@ The following table shows a performance comparison on an A100 for consequent `ge
|
||||
| --------------- | ---------- | ------- |
|
||||
| Llama3Model | 170 | 3.12 GB |
|
||||
| Llama3ModelFast | 177 | 3.61 GB |
|
||||
|
||||
|
||||
#### Pro tip 3: speed up inference with compilation
|
||||
|
||||
You can significantly boost inference performance using the KV cache `Llama3Model` drop-in replacement when running the model on a CPU. (See my [Understanding and Coding the KV Cache in LLMs from Scratch](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms) article to learn more about KV caches.)
|
||||
|
||||
```python
|
||||
from llms_from_scratch.kv_cache.llama3 import Llama3Model
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple
|
||||
|
||||
model = Llama3Model(LLAMA32_CONFIG)
|
||||
# ...
|
||||
token_ids = generate_text_simple(
|
||||
model=model,
|
||||
idx=text_to_token_ids(PROMPT, tokenizer).to(device),
|
||||
max_new_tokens=MAX_NEW_TOKENS,
|
||||
context_size=LLAMA32_CONFIG["context_length"],
|
||||
)
|
||||
```
|
||||
|
||||
Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage dominates here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements).
|
||||
|
||||
| Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) |
|
||||
|-------------|-------------------|-----------------|------------|-------------------|
|
||||
| Llama3Model | Regular | Mac Mini M4 CPU | 1 | - |
|
||||
| Llama3Model | Regular compiled | Mac Mini M4 CPU | - | - |
|
||||
| Llama3Model | KV cache | Mac Mini M4 CPU | 62 | - |
|
||||
| Llama3Model | KV cache compiled | Mac Mini M4 CPU | - | - |
|
||||
| | | | | |
|
||||
| Llama3Model | Regular | Mac Mini M4 GPU | 15 | - |
|
||||
| Llama3Model | Regular compiled | Mac Mini M4 GPU | - | - |
|
||||
| Llama3Model | KV cache | Mac Mini M4 GPU | 62 | - |
|
||||
| Llama3Model | KV cache compiled | Mac Mini M4 GPU | - | - |
|
||||
| | | | | |
|
||||
| Llama3Model | Regular | Nvidia A100 GPU | 42 | 2.91 GB |
|
||||
| Llama3Model | Regular compiled | Nvidia A100 GPU | 170 | 3.12 GB |
|
||||
| Llama3Model | KV cache | Nvidia A100 GPU | 60 | 18.87 GB |
|
||||
| Llama3Model | KV cache compiled | Nvidia A100 GPU | 59 | 19.12 GB |
|
@ -124,6 +124,9 @@ from llms_from_scratch.llama3 import (
|
||||
ChatFormat,
|
||||
clean_text
|
||||
)
|
||||
|
||||
from llms_from_scratch.kv_cache.llama3 import Llama3Model
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple
|
||||
```
|
||||
|
||||
For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
|
||||
|
4
pkg/llms_from_scratch/kv_cache/__init__.py
Normal file
4
pkg/llms_from_scratch/kv_cache/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
29
pkg/llms_from_scratch/kv_cache/generate.py
Normal file
29
pkg/llms_from_scratch/kv_cache/generate.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
|
||||
model.eval()
|
||||
|
||||
ctx_len = context_size or model.cfg["context_length"]
|
||||
|
||||
with torch.no_grad():
|
||||
if use_cache:
|
||||
model.reset_kv_cache()
|
||||
logits = model(idx[:, -ctx_len:], use_cache=True)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
|
||||
idx = torch.cat([idx, next_idx], dim=1)
|
||||
logits = model(next_idx, use_cache=True)
|
||||
else:
|
||||
for _ in range(max_new_tokens):
|
||||
logits = model(idx[:, -ctx_len:], use_cache=False)
|
||||
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
|
||||
idx = torch.cat([idx, next_idx], dim=1)
|
||||
|
||||
return idx
|
317
pkg/llms_from_scratch/kv_cache/llama3.py
Normal file
317
pkg/llms_from_scratch/kv_cache/llama3.py
Normal file
@ -0,0 +1,317 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from ..llama3 import Llama3Tokenizer, ChatFormat, clean_text # noqa: F401
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
LLAMA32_CONFIG_1B = {
|
||||
"vocab_size": 128_256, # Vocabulary size
|
||||
"context_length": 131_072, # Context length that was used to train the model
|
||||
"window_size": None, # Window size for the KV cache; context_length if None
|
||||
"emb_dim": 2048, # Embedding dimension
|
||||
"n_heads": 32, # Number of attention heads
|
||||
"n_layers": 16, # Number of layers
|
||||
"hidden_dim": 8192, # Size of the intermediate dimension in FeedForward
|
||||
"n_kv_groups": 8, # Key-Value groups for grouped-query attention
|
||||
"rope_base": 500_000.0, # The base in RoPE's "theta"
|
||||
"dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
|
||||
"rope_freq": { # RoPE frequency scaling
|
||||
"factor": 32.0,
|
||||
"low_freq_factor": 1.0,
|
||||
"high_freq_factor": 4.0,
|
||||
"original_context_length": 8192,
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA32_CONFIG_3B = {
|
||||
"vocab_size": 128_256, # Vocabulary size
|
||||
"context_length": 131_072, # Context length that was used to train the model
|
||||
"window_size": None, # Window size for the KV cache; context_length if None
|
||||
"emb_dim": 3072, # Embedding dimension
|
||||
"n_heads": 24, # Number of attention heads
|
||||
"n_layers": 28, # Number of layers
|
||||
"hidden_dim": 8192, # Size of the intermediate dimension in FeedForward
|
||||
"n_kv_groups": 8, # Key-Value groups for grouped-query attention
|
||||
"rope_base": 500_000.0, # The base in RoPE's "theta"
|
||||
"dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
|
||||
"rope_freq": { # RoPE frequency scaling
|
||||
"factor": 32.0,
|
||||
"low_freq_factor": 1.0,
|
||||
"high_freq_factor": 4.0,
|
||||
"original_context_length": 8192,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class Llama3Model(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
|
||||
# Main model parameters
|
||||
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
|
||||
|
||||
self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
|
||||
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
|
||||
)
|
||||
|
||||
self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
|
||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
|
||||
|
||||
# Reusuable utilities
|
||||
cos, sin = compute_rope_params(
|
||||
head_dim=cfg["emb_dim"] // cfg["n_heads"],
|
||||
theta_base=cfg["rope_base"],
|
||||
context_length=cfg["context_length"],
|
||||
freq_config=cfg["rope_freq"]
|
||||
)
|
||||
self.register_buffer("cos", cos, persistent=False)
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
self.cfg = cfg
|
||||
|
||||
def forward(self, in_idx, use_cache=False):
|
||||
tok_embeds = self.tok_emb(in_idx)
|
||||
x = tok_embeds
|
||||
|
||||
for block in self.trf_blocks:
|
||||
x = block(x, self.cos, self.sin, use_cache)
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x.to(self.cfg["dtype"]))
|
||||
return logits
|
||||
|
||||
def reset_kv_cache(self):
|
||||
for blk in self.trf_blocks:
|
||||
blk.att.reset_cache()
|
||||
self.ptr_current_pos = 0
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.att = GroupedQueryAttention(
|
||||
d_in=cfg["emb_dim"],
|
||||
d_out=cfg["emb_dim"],
|
||||
num_heads=cfg["n_heads"],
|
||||
num_kv_groups=cfg["n_kv_groups"],
|
||||
max_seq_len=cfg["context_length"],
|
||||
dtype=cfg["dtype"]
|
||||
)
|
||||
self.ff = FeedForward(cfg)
|
||||
self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
|
||||
self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
|
||||
|
||||
def forward(self, x, cos, sin, use_cache=False):
|
||||
# Shortcut connection for attention block
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = self.att(x, cos, sin, use_cache) # Shape [batch_size, num_tokens, emb_size]
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
# Shortcut connection for feed-forward block
|
||||
shortcut = x
|
||||
x = self.norm2(x)
|
||||
x = self.ff(x)
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
|
||||
self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
|
||||
self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x_fc1 = self.fc1(x)
|
||||
x_fc2 = self.fc2(x)
|
||||
x = nn.functional.silu(x_fc1) * x_fc2
|
||||
return self.fc3(x)
|
||||
|
||||
|
||||
class GroupedQueryAttention(nn.Module):
|
||||
def __init__(
|
||||
self, d_in, d_out, num_heads, num_kv_groups, dtype=None, max_seq_len=None, window_size=None
|
||||
):
|
||||
super().__init__()
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
|
||||
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
|
||||
|
||||
self.d_out = d_out
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = d_out // num_heads
|
||||
|
||||
self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
|
||||
self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
|
||||
self.num_kv_groups = num_kv_groups
|
||||
self.group_size = num_heads // num_kv_groups
|
||||
|
||||
self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
|
||||
self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
|
||||
|
||||
# For optional KV cache
|
||||
self.max_seq_len = max_seq_len
|
||||
self.window_size = window_size or self.max_seq_len
|
||||
self.register_buffer("cache_k", None, persistent=False)
|
||||
self.register_buffer("cache_v", None, persistent=False)
|
||||
self.cache_initialized = False
|
||||
self.ptr = 0
|
||||
|
||||
def forward(self, x, cos, sin, use_cache=False):
|
||||
b, num_tokens, d_in = x.shape
|
||||
|
||||
queries = self.W_query(x) # Shape: (b, num_tokens, d_out)
|
||||
keys_new = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
|
||||
values_new = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
|
||||
|
||||
# Reshape queries, keys, and values
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)
|
||||
values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)
|
||||
|
||||
# Transpose keys, values, and queries
|
||||
queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||
keys_new = keys_new.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
|
||||
values_new = values_new.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
|
||||
|
||||
# For KV cache
|
||||
pos_start = self.ptr
|
||||
pos_end = pos_start + num_tokens
|
||||
cos_slice = cos[pos_start:pos_end]
|
||||
sin_slice = sin[pos_start:pos_end]
|
||||
|
||||
# Apply RoPE
|
||||
keys_new = apply_rope(keys_new, cos_slice, sin_slice)
|
||||
queries = apply_rope(queries, cos_slice, sin_slice)
|
||||
|
||||
# Expand keys and values to match the number of heads
|
||||
# Shape: (b, num_heads, num_tokens, head_dim)
|
||||
keys_new = keys_new.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||
values_new = values_new.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||
# For example, before repeat_interleave along dim=1 (query groups):
|
||||
# [K1, K2]
|
||||
# After repeat_interleave (each query group is repeated group_size times):
|
||||
# [K1, K1, K2, K2]
|
||||
# If we used regular repeat instead of repeat_interleave, we'd get:
|
||||
# [K1, K2, K1, K2]
|
||||
|
||||
if use_cache:
|
||||
if not self.cache_initialized:
|
||||
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=keys_new.dtype)
|
||||
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=values_new.dtype)
|
||||
self.ptr = 0
|
||||
self.cache_initialized = True
|
||||
|
||||
# In-place update
|
||||
end = self.ptr + num_tokens
|
||||
self.cache_k[:, :, self.ptr:end].copy_(keys_new)
|
||||
self.cache_v[:, :, self.ptr:end].copy_(values_new)
|
||||
|
||||
keys = self.cache_k[:, :, max(0, end - self.window_size):end]
|
||||
values = self.cache_v[:, :, max(0, end - self.window_size):end]
|
||||
self.ptr = end
|
||||
else:
|
||||
keys, values = keys_new, values_new
|
||||
|
||||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||||
# Shape: (b, num_heads, num_tokens, num_tokens)
|
||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||
|
||||
# Create causal mask to fill attention scores
|
||||
T_q = queries.shape[-2]
|
||||
T_k = keys.shape[-2]
|
||||
|
||||
if not use_cache or T_q > 1:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones((T_q, T_k), device=x.device, dtype=torch.bool),
|
||||
diagonal=1
|
||||
)
|
||||
attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
|
||||
|
||||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||
assert keys.shape[-1] == self.head_dim
|
||||
|
||||
# Shape: (b, num_tokens, num_heads, head_dim)
|
||||
context_vec = (attn_weights @ values).transpose(1, 2)
|
||||
|
||||
# Combine heads, where self.d_out = self.num_heads * self.head_dim
|
||||
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
|
||||
context_vec = self.out_proj(context_vec) # optional projection
|
||||
|
||||
return context_vec
|
||||
|
||||
def reset_cache(self):
|
||||
if self.cache_k is not None:
|
||||
self.cache_k.zero_()
|
||||
self.cache_v.zero_()
|
||||
self.ptr = 0
|
||||
|
||||
|
||||
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
|
||||
assert head_dim % 2 == 0, "Embedding dimension must be even"
|
||||
|
||||
# Compute the inverse frequencies
|
||||
inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
|
||||
|
||||
# Frequency adjustments
|
||||
if freq_config is not None:
|
||||
low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
|
||||
high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]
|
||||
|
||||
wavelen = 2 * torch.pi / inv_freq
|
||||
|
||||
inv_freq_llama = torch.where(
|
||||
wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
|
||||
)
|
||||
|
||||
smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
|
||||
freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
|
||||
)
|
||||
|
||||
smoothed_inv_freq = (
|
||||
(1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
|
||||
)
|
||||
|
||||
is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
|
||||
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
||||
inv_freq = inv_freq_llama
|
||||
|
||||
# Generate position indices
|
||||
positions = torch.arange(context_length, dtype=dtype)
|
||||
|
||||
# Compute the angles
|
||||
angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
|
||||
|
||||
# Expand angles to match the head_dim
|
||||
angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)
|
||||
|
||||
# Precompute sine and cosine
|
||||
cos = torch.cos(angles)
|
||||
sin = torch.sin(angles)
|
||||
|
||||
return cos, sin
|
||||
|
||||
|
||||
def apply_rope(x, cos, sin):
|
||||
# x: (batch_size, num_heads, seq_len, head_dim)
|
||||
batch_size, num_heads, seq_len, head_dim = x.shape
|
||||
assert head_dim % 2 == 0, "Head dimension must be even"
|
||||
|
||||
# Split x into first half and second half
|
||||
x1 = x[..., : head_dim // 2] # First half
|
||||
x2 = x[..., head_dim // 2:] # Second half
|
||||
|
||||
# Adjust sin and cos shapes
|
||||
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
|
||||
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Apply the rotary transformation
|
||||
rotated = torch.cat((-x2, x1), dim=-1)
|
||||
x_rotated = (x * cos) + (rotated * sin)
|
||||
|
||||
# It's ok to use lower-precision after applying cos and sin rotation
|
||||
return x_rotated.to(dtype=x.dtype)
|
@ -12,8 +12,11 @@ from llms_from_scratch.llama3 import (
|
||||
GroupedQueryAttentionFast,
|
||||
Llama3Model,
|
||||
)
|
||||
from llms_from_scratch.kv_cache.llama3 import Llama3Model as Llama3ModelKV
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import pytest
|
||||
import tiktoken
|
||||
import torch
|
||||
@ -180,8 +183,20 @@ def llama3_weights_path(tmp_path_factory):
|
||||
return path
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ModelClass", [Llama3Model])
|
||||
def test_gpt_model_variants(ModelClass, llama3_weights_path):
|
||||
@pytest.mark.skipif(
|
||||
os.getenv("GITHUB_ACTIONS") == "true",
|
||||
reason="Skipping in GitHub Actions due to compute or memory constraints"
|
||||
)
|
||||
@pytest.mark.parametrize("ModelClass", [Llama3Model, Llama3ModelKV])
|
||||
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
|
||||
def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
|
||||
|
||||
# Skip incompatible combinations
|
||||
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
|
||||
return
|
||||
if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
|
||||
return
|
||||
|
||||
torch.manual_seed(123)
|
||||
model = ModelClass(LLAMA32_CONFIG_1B)
|
||||
model.load_state_dict(torch.load(llama3_weights_path))
|
||||
@ -198,7 +213,7 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
|
||||
print("Encoded input text:", encoded)
|
||||
print("encoded_tensor.shape:", encoded_tensor.shape)
|
||||
|
||||
out = generate_text_simple(
|
||||
out = generate_fn(
|
||||
model=model,
|
||||
idx=encoded_tensor,
|
||||
max_new_tokens=5,
|
||||
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "llms-from-scratch"
|
||||
version = "1.0.9"
|
||||
version = "1.0.12"
|
||||
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
Loading…
x
Reference in New Issue
Block a user