mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-03 11:20:49 +00:00
Batched KV Cache Inference for Qwen3 (#735)
This commit is contained in:
parent
7dc1dcbe27
commit
a200698698
@ -292,4 +292,72 @@ Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is
|
||||
| Qwen3Model | KV cache | Nvidia A100 GPU | 25 | 1.47 GB |
|
||||
| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 90 | 1.48 GB |
|
||||
|
||||
Note that all settings above have been tested to produce the same text outputs.
|
||||
Note that all settings above have been tested to produce the same text outputs.
|
||||
|
||||
|
||||
|
||||
#### Pro tip 3: batched inference
|
||||
|
||||
We can further increase the throughput via batched inference. While it's not an apples-to-apples comparison, as we are now running inference with a higher number of input sequences, this increases the tokens per second throughput while trading it off against increased memory usage.
|
||||
|
||||
This only requires a small code modification with respect to preparing the prompt. For example, consider this batched prompt below:
|
||||
|
||||
```python
|
||||
from llms_from_scratch.ch04 import generate_text_simple
|
||||
from llms_from_scratch.qwen3 import Qwen3Model, QWEN_CONFIG_06_B
|
||||
# ...
|
||||
|
||||
prompts = [
|
||||
"Give me a short introduction to neural networks.",
|
||||
"Give me a short introduction to machine learning.",
|
||||
"Give me a short introduction to deep learning models.",
|
||||
"Give me a short introduction to natural language processing.",
|
||||
"Give me a short introduction to generative AI systems.",
|
||||
"Give me a short introduction to transformer architectures.",
|
||||
"Give me a short introduction to supervised learning methods.",
|
||||
"Give me a short introduction to unsupervised learning.",
|
||||
]
|
||||
|
||||
tokenized_prompts = [tokenizer.encode(p) for p in prompts]
|
||||
max_len = max(len(t) for t in tokenized_prompts)
|
||||
padded_token_ids = [
|
||||
t + [tokenizer.pad_token_id] * (max_len - len(t)) for t in tokenized_prompts
|
||||
]
|
||||
input_tensor = torch.tensor(padded_token_ids).to(device)
|
||||
|
||||
output_token_ids = generate_text_simple(
|
||||
model=model,
|
||||
idx=input_tensor,
|
||||
max_new_tokens=150,
|
||||
context_size=QWEN_CONFIG_06_B["context_length"],
|
||||
)
|
||||
```
|
||||
|
||||
The code for the KV cache version is similar, except that it requires using these drop-in replacements:
|
||||
|
||||
```python
|
||||
from llms_from_scratch.kv_cache_batched.generate import generate_text_simple
|
||||
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model
|
||||
```
|
||||
|
||||
|
||||
The experiments below are run with a batch size of 8.
|
||||
|
||||
| Model | Mode | Hardware | Batch size | Tokens/sec | GPU Memory (VRAM) |
|
||||
| ---------- | ----------------- | --------------- | ---------- | ---------- | ----------------- |
|
||||
| Qwen3Model | Regular | Mac Mini M4 CPU | 8 | 2 | - |
|
||||
| Qwen3Model | Regular compiled | Mac Mini M4 CPU | 8 | - | - |
|
||||
| Qwen3Model | KV cache | Mac Mini M4 CPU | 8 | 92 | - |
|
||||
| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | 8 | 128 | - |
|
||||
| | | | | | |
|
||||
| Qwen3Model | Regular | Mac Mini M4 GPU | 8 | 36 | - |
|
||||
| Qwen3Model | Regular compiled | Mac Mini M4 GPU | 8 | - | - |
|
||||
| Qwen3Model | KV cache | Mac Mini M4 GPU | 8 | 61 | - |
|
||||
| Qwen3Model | KV cache compiled | Mac Mini M4 GPU | 8 | - | - |
|
||||
| | | | | | |
|
||||
| Qwen3Model | Regular | Nvidia A100 GPU | 8 | 184 | 2.19 GB |
|
||||
| Qwen3Model | Regular compiled | Nvidia A100 GPU | 8 | 351 | 2.19 GB |
|
||||
| Qwen3Model | KV cache | Nvidia A100 GPU | 8 | 140 | 3.13 GB |
|
||||
| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 8 | 280 | 1.75 GB |
|
||||
|
||||
|
||||
|
||||
@ -161,6 +161,10 @@ from llms_from_scratch.qwen3 import (
|
||||
# KV cache drop-in replacements
|
||||
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple
|
||||
|
||||
# KV cache drop-in replacements with batched inference support
|
||||
from llms_from_scratch.kv_cache_batched.generate import generate_text_simple
|
||||
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model
|
||||
```
|
||||
|
||||
For the `llms_from_scratch.qwen3` usage information, please see [this bonus section](../../ch05/11_qwen3/README.md).
|
||||
|
||||
4
pkg/llms_from_scratch/kv_cache_batched/__init__.py
Normal file
4
pkg/llms_from_scratch/kv_cache_batched/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
50
pkg/llms_from_scratch/kv_cache_batched/generate.py
Normal file
50
pkg/llms_from_scratch/kv_cache_batched/generate.py
Normal file
@ -0,0 +1,50 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from .utils import KVCache
|
||||
import torch
|
||||
|
||||
|
||||
def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
|
||||
model.eval()
|
||||
ctx_len = context_size or model.cfg["context_length"]
|
||||
batch_size = idx.size(0)
|
||||
|
||||
with torch.no_grad():
|
||||
if use_cache:
|
||||
# initialize cache and positions
|
||||
cache = KVCache(n_layers=model.cfg["n_layers"], batch_size=batch_size)
|
||||
model.reset_kv_cache(batch_size=batch_size, device=idx.device)
|
||||
|
||||
# initial full-context pass
|
||||
input_ids = idx[:, -ctx_len:]
|
||||
seq_len = input_ids.size(1)
|
||||
start_pos = model.current_pos.clone()
|
||||
logits = model(
|
||||
input_ids,
|
||||
cache=cache,
|
||||
start_pos=start_pos
|
||||
)
|
||||
model.current_pos += seq_len
|
||||
|
||||
# iterative generation
|
||||
for _ in range(max_new_tokens):
|
||||
next_token = logits[:, -1].argmax(dim=-1, keepdim=True) # (B, 1)
|
||||
logits = model(
|
||||
next_token,
|
||||
cache=cache,
|
||||
start_pos=model.current_pos.clone()
|
||||
)
|
||||
model.current_pos += 1
|
||||
idx = torch.cat([idx, next_token], dim=1)
|
||||
else:
|
||||
# no cache
|
||||
for _ in range(max_new_tokens):
|
||||
input_ids = idx[:, -ctx_len:]
|
||||
logits = model(input_ids, cache=None, start_pos=None)
|
||||
next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
|
||||
idx = torch.cat([idx, next_token], dim=1)
|
||||
|
||||
return idx
|
||||
287
pkg/llms_from_scratch/kv_cache_batched/qwen3.py
Normal file
287
pkg/llms_from_scratch/kv_cache_batched/qwen3.py
Normal file
@ -0,0 +1,287 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from .utils import KVCache # noqa: F401
|
||||
from ..qwen3 import ( # noqa: F401
|
||||
QWEN_CONFIG_06_B, QWEN3_CONFIG_1_7B, QWEN3_CONFIG_4B,
|
||||
QWEN3_CONFIG_8B, QWEN3_CONFIG_14B, QWEN3_CONFIG_32B,
|
||||
Qwen3Tokenizer, load_weights_into_qwen,
|
||||
download_from_huggingface,
|
||||
download_from_huggingface_from_snapshots
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Qwen3Model(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
|
||||
# Main model parameters
|
||||
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
|
||||
|
||||
self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
|
||||
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
|
||||
)
|
||||
self.final_norm = RMSNorm(cfg["emb_dim"])
|
||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
|
||||
|
||||
# Reusable utilities
|
||||
if cfg["head_dim"] is None:
|
||||
head_dim = cfg["emb_dim"] // cfg["n_heads"]
|
||||
else:
|
||||
head_dim = cfg["head_dim"]
|
||||
cos, sin = compute_rope_params(
|
||||
head_dim=head_dim,
|
||||
theta_base=cfg["rope_base"],
|
||||
context_length=cfg["context_length"]
|
||||
)
|
||||
self.register_buffer("cos", cos, persistent=False)
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
self.cfg = cfg
|
||||
self.current_pos = None # Batched version tracks positions per sample
|
||||
|
||||
def forward(self, in_idx, cache=None, start_pos=None):
|
||||
B, num_tokens = in_idx.size()
|
||||
tok_embeds = self.tok_emb(in_idx)
|
||||
x = tok_embeds
|
||||
device = x.device
|
||||
|
||||
if cache is not None:
|
||||
pos_start = start_pos
|
||||
pos_end = pos_start + num_tokens
|
||||
max_len = pos_end.max().item()
|
||||
full_mask = torch.triu(
|
||||
torch.ones(max_len, max_len, device=device, dtype=torch.bool), diagonal=1
|
||||
)
|
||||
mask = torch.zeros(B, 1, num_tokens, max_len, device=device, dtype=torch.bool)
|
||||
for i in range(B):
|
||||
ps, pe = pos_start[i].item(), pos_end[i].item()
|
||||
mask[i, 0] = full_mask[ps:pe, :pe]
|
||||
else:
|
||||
pos_start = torch.zeros(B, dtype=torch.long, device=device)
|
||||
mask = torch.triu(
|
||||
torch.ones(num_tokens, num_tokens, device=device, dtype=torch.bool), diagonal=1
|
||||
)[None, None, :, :]
|
||||
|
||||
for i, block in enumerate(self.trf_blocks):
|
||||
blk_cache = [cache.get(i, b_idx) for b_idx in range(B)] if cache is not None else None
|
||||
x, new_blk_cache = block(x, mask, self.cos, self.sin, start_pos=pos_start, cache=blk_cache)
|
||||
if cache is not None:
|
||||
for b_idx in range(B):
|
||||
cache.update(i, b_idx, new_blk_cache[b_idx])
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x.to(self.cfg["dtype"]))
|
||||
return logits
|
||||
|
||||
def reset_kv_cache(self, batch_size, device=None):
|
||||
device = device or next(self.parameters()).device
|
||||
self.current_pos = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.att = GroupedQueryAttention(
|
||||
d_in=cfg["emb_dim"],
|
||||
num_heads=cfg["n_heads"],
|
||||
head_dim=cfg["head_dim"],
|
||||
num_kv_groups=cfg["n_kv_groups"],
|
||||
qk_norm=cfg["qk_norm"],
|
||||
dtype=cfg["dtype"]
|
||||
)
|
||||
self.ff = FeedForward(cfg)
|
||||
self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
|
||||
self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
|
||||
|
||||
def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
|
||||
# Shortcut connection for attention block
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x, next_cache = self.att(x, mask, cos, sin, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size]
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
# Shortcut connection for feed-forward block
|
||||
shortcut = x
|
||||
x = self.norm2(x)
|
||||
x = self.ff(x)
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
return x, next_cache
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
|
||||
self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
|
||||
self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x_fc1 = self.fc1(x)
|
||||
x_fc2 = self.fc2(x)
|
||||
x = nn.functional.silu(x_fc1) * x_fc2
|
||||
return self.fc3(x)
|
||||
|
||||
|
||||
class GroupedQueryAttention(nn.Module):
|
||||
def __init__(self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None):
|
||||
super().__init__()
|
||||
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_groups = num_kv_groups
|
||||
self.group_size = num_heads // num_kv_groups
|
||||
|
||||
if head_dim is None:
|
||||
assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
|
||||
head_dim = d_in // num_heads
|
||||
|
||||
self.head_dim = head_dim
|
||||
self.d_out = num_heads * head_dim
|
||||
|
||||
self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
|
||||
self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
|
||||
self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
|
||||
|
||||
self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)
|
||||
|
||||
if qk_norm:
|
||||
self.q_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
self.k_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
else:
|
||||
self.q_norm = self.k_norm = None
|
||||
|
||||
def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
|
||||
b, num_tokens, _ = x.shape
|
||||
|
||||
# Apply projections
|
||||
queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
|
||||
keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
|
||||
# Reshape
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Optional normalization
|
||||
if self.q_norm:
|
||||
queries = self.q_norm(queries)
|
||||
if self.k_norm:
|
||||
keys = self.k_norm(keys)
|
||||
|
||||
# Apply RoPE
|
||||
queries = apply_rope(queries, cos, sin, offset=start_pos)
|
||||
keys = apply_rope(keys, cos, sin, offset=start_pos)
|
||||
|
||||
# KV caching
|
||||
next_cache = []
|
||||
for i in range(b):
|
||||
prev = cache[i] if cache else None
|
||||
if prev is None:
|
||||
k_cat = keys[i:i+1]
|
||||
v_cat = values[i:i+1]
|
||||
else:
|
||||
prev_k, prev_v = prev
|
||||
k_cat = torch.cat([prev_k, keys[i:i+1]], dim=2)
|
||||
v_cat = torch.cat([prev_v, values[i:i+1]], dim=2)
|
||||
next_cache.append((k_cat, v_cat))
|
||||
|
||||
keys = torch.cat([k for k, _ in next_cache], dim=0)
|
||||
values = torch.cat([v for _, v in next_cache], dim=0)
|
||||
|
||||
# Expand K and V to match number of heads
|
||||
keys = keys.repeat_interleave(self.group_size, dim=1)
|
||||
values = values.repeat_interleave(self.group_size, dim=1)
|
||||
|
||||
# Attention
|
||||
attn_scores = queries @ keys.transpose(2, 3)
|
||||
attn_scores = attn_scores.masked_fill(mask, -torch.inf)
|
||||
|
||||
# attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
|
||||
# PyTorch fails to do the implicit casting, so we have to be intentional with the types
|
||||
scale = torch.tensor(self.head_dim**0.5, dtype=queries.dtype, device=queries.device)
|
||||
attn_weights = torch.softmax(attn_scores / scale, dim=-1).to(values.dtype)
|
||||
|
||||
context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
|
||||
return self.out_proj(context), next_cache
|
||||
|
||||
|
||||
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
|
||||
assert head_dim % 2 == 0, "Embedding dimension must be even"
|
||||
|
||||
# Compute the inverse frequencies
|
||||
inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
|
||||
|
||||
# Generate position indices
|
||||
positions = torch.arange(context_length, dtype=dtype)
|
||||
|
||||
# Compute the angles
|
||||
angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
|
||||
|
||||
# Expand angles to match the head_dim
|
||||
angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)
|
||||
|
||||
# Precompute sine and cosine
|
||||
cos = torch.cos(angles)
|
||||
sin = torch.sin(angles)
|
||||
|
||||
return cos, sin
|
||||
|
||||
|
||||
def apply_rope(x, cos, sin, offset):
|
||||
# x: (batch_size, num_heads, seq_len, head_dim)
|
||||
bsz, n_heads, seq_len, head_dim = x.shape
|
||||
assert head_dim % 2 == 0, "Head dimension must be even"
|
||||
assert offset.shape[0] == bsz, "Offset must have one value per batch item"
|
||||
|
||||
# Prepare cos/sin: (seq_len, head_dim)
|
||||
cos = cos[:cos.shape[0], :].unsqueeze(0).unsqueeze(0) # (1, 1, total_seq_len, head_dim)
|
||||
sin = sin[:sin.shape[0], :].unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Build position indices per batch item
|
||||
position_ids = torch.arange(seq_len, device=offset.device).unsqueeze(0) + offset.unsqueeze(1) # (bsz, seq_len)
|
||||
position_ids = position_ids.clamp(max=cos.shape[2] - 1)
|
||||
|
||||
# Gather cos/sin for each position
|
||||
cos = cos[0, 0, position_ids, :] # (bsz, seq_len, head_dim)
|
||||
sin = sin[0, 0, position_ids, :]
|
||||
|
||||
# Expand for multi-heads
|
||||
cos = cos.unsqueeze(1) # (bsz, 1, seq_len, head_dim)
|
||||
sin = sin.unsqueeze(1)
|
||||
|
||||
x1 = x[..., :head_dim // 2]
|
||||
x2 = x[..., head_dim // 2:]
|
||||
|
||||
rotated = torch.cat((-x2, x1), dim=-1)
|
||||
x_rotated = (x * cos) + (rotated * sin)
|
||||
return x_rotated
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.qwen3_compatible = qwen3_compatible
|
||||
self.scale = nn.Parameter(torch.ones(emb_dim))
|
||||
self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None
|
||||
|
||||
def forward(self, x):
|
||||
input_dtype = x.dtype
|
||||
|
||||
if self.qwen3_compatible:
|
||||
x = x.to(torch.float32)
|
||||
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
norm_x = x * torch.rsqrt(variance + self.eps)
|
||||
norm_x = norm_x * self.scale
|
||||
|
||||
if self.shift is not None:
|
||||
norm_x = norm_x + self.shift
|
||||
|
||||
return norm_x.to(input_dtype)
|
||||
24
pkg/llms_from_scratch/kv_cache_batched/utils.py
Normal file
24
pkg/llms_from_scratch/kv_cache_batched/utils.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
class KVCache:
|
||||
def __init__(self, n_layers, batch_size):
|
||||
self.cache = [
|
||||
[None for _ in range(batch_size)] for _ in range(n_layers)
|
||||
]
|
||||
|
||||
def get(self, layer_idx, batch_idx):
|
||||
return self.cache[layer_idx][batch_idx]
|
||||
|
||||
def update(self, layer_idx, batch_idx, value):
|
||||
self.cache[layer_idx][batch_idx] = value
|
||||
|
||||
def get_layer(self, layer_idx):
|
||||
return self.cache[layer_idx]
|
||||
|
||||
def reset(self):
|
||||
for layer in self.cache:
|
||||
for i in range(len(layer)):
|
||||
layer[i] = None
|
||||
@ -15,8 +15,8 @@ from llms_from_scratch.qwen3 import (
|
||||
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
|
||||
|
||||
# from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
|
||||
# from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
|
||||
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
|
||||
from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
|
||||
|
||||
import importlib
|
||||
import pytest
|
||||
@ -172,7 +172,7 @@ def test_model_KV_noKV(qwen3_weights_path):
|
||||
input_token_ids = tokenizer.encode(prompt)
|
||||
input_token_ids = torch.tensor([input_token_ids])
|
||||
|
||||
out_noKV = generate_text_simple_cached(
|
||||
out_KV = generate_text_simple_cached(
|
||||
model=model_KV,
|
||||
idx=input_token_ids,
|
||||
max_new_tokens=5,
|
||||
@ -185,7 +185,7 @@ def test_model_KV_noKV(qwen3_weights_path):
|
||||
model_noKV.load_state_dict(torch.load(qwen3_weights_path))
|
||||
model_noKV.eval()
|
||||
|
||||
out_KV = generate_text_simple(
|
||||
out_noKV = generate_text_simple(
|
||||
model=model_noKV,
|
||||
idx=input_token_ids,
|
||||
max_new_tokens=5,
|
||||
@ -195,6 +195,69 @@ def test_model_KV_noKV(qwen3_weights_path):
|
||||
assert torch.equal(out_noKV, out_KV)
|
||||
|
||||
|
||||
def test_model_batched_KV(qwen3_weights_path):
|
||||
|
||||
torch.manual_seed(123)
|
||||
model_KV = Qwen3ModelKV(QWEN_CONFIG_06_B)
|
||||
model_KV.load_state_dict(torch.load(qwen3_weights_path))
|
||||
model_KV.eval()
|
||||
|
||||
tokenizer = Qwen3Tokenizer(
|
||||
tokenizer_file_path="tokenizer-base.json",
|
||||
repo_id="rasbt/qwen3-from-scratch",
|
||||
add_generation_prompt=False,
|
||||
add_thinking=False
|
||||
)
|
||||
|
||||
# Batch size 1
|
||||
|
||||
prompt = "Give me a short introduction to large language models."
|
||||
input_token_ids = tokenizer.encode(prompt)
|
||||
input_token_ids = torch.tensor([input_token_ids])
|
||||
|
||||
out_KV = generate_text_simple_cached(
|
||||
model=model_KV,
|
||||
idx=input_token_ids,
|
||||
max_new_tokens=5,
|
||||
context_size=QWEN_CONFIG_06_B["context_length"]
|
||||
)
|
||||
del model_KV
|
||||
|
||||
torch.manual_seed(123)
|
||||
model_KV_batched = Qwen3ModelKVBatched(QWEN_CONFIG_06_B)
|
||||
model_KV_batched.load_state_dict(torch.load(qwen3_weights_path))
|
||||
model_KV_batched.eval()
|
||||
|
||||
out_KV_bs_1 = generate_text_simple_batched(
|
||||
model=model_KV_batched,
|
||||
idx=input_token_ids,
|
||||
max_new_tokens=5,
|
||||
context_size=QWEN_CONFIG_06_B["context_length"]
|
||||
)
|
||||
|
||||
assert torch.equal(out_KV, out_KV_bs_1)
|
||||
|
||||
# Batch size 2
|
||||
|
||||
prompts = [
|
||||
"Give me a short introduction to large language models.",
|
||||
"Give me a short introduction to large language models."
|
||||
]
|
||||
tokenized_prompts = [tokenizer.encode(p) for p in prompts]
|
||||
max_len = max(len(t) for t in tokenized_prompts)
|
||||
padded_token_ids = [
|
||||
t + [tokenizer.pad_token_id] * (max_len - len(t)) for t in tokenized_prompts
|
||||
]
|
||||
input_tensor = torch.tensor(padded_token_ids)
|
||||
out_KV_bs_2 = generate_text_simple_batched(
|
||||
model=model_KV_batched,
|
||||
idx=input_tensor,
|
||||
max_new_tokens=5,
|
||||
context_size=QWEN_CONFIG_06_B["context_length"],
|
||||
)
|
||||
assert torch.equal(out_KV.squeeze(0), out_KV_bs_2[0]), (out_KV.squeeze(0).shape, out_KV_bs_2[0].shape)
|
||||
|
||||
|
||||
def test_rmsnorm_equivalence():
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "llms-from-scratch"
|
||||
version = "1.0.16"
|
||||
version = "1.0.17"
|
||||
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user