Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-08-28 18:40:01 +00:00)

Add more sophisticated Qwen3 tokenizer (#729)
Commit 14fa50dfc8 (parent f596aab0cb)
@@ -487,21 +487,6 @@
 "        \"dtype\": torch.bfloat16,\n",
 "    } \n",
 "\n",
-"elif CHOOSE_MODEL == \"8B\":\n",
-"    QWEN3_CONFIG = {\n",
-"        \"vocab_size\": 151_936,\n",
-"        \"context_length\": 40_960,\n",
-"        \"emb_dim\": 4096, # 60% larger than above\n",
-"        \"n_heads\": 32,\n",
-"        \"n_layers\": 36, # 26% larger than above\n",
-"        \"hidden_dim\": 12288,\n",
-"        \"head_dim\": 128,\n",
-"        \"qk_norm\": True,\n",
-"        \"n_kv_groups\": 8,\n",
-"        \"rope_base\": 1_000_000.0,\n",
-"        \"dtype\": torch.bfloat16,\n",
-"    } \n",
-"\n",
 "elif CHOOSE_MODEL == \"14B\":\n",
 "    QWEN3_CONFIG = {\n",
 "        \"vocab_size\": 151_936,\n",
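Because the notebook source appears here in JSON-escaped form, the removed "8B" branch is easier to read de-escaped into plain Python (same values as in the hunk above; shown only for readability):

    import torch

    # Removed "8B" configuration from the notebook, de-escaped for readability
    QWEN3_CONFIG = {
        "vocab_size": 151_936,
        "context_length": 40_960,
        "emb_dim": 4096,            # 60% larger than above
        "n_heads": 32,
        "n_layers": 36,             # 26% larger than above
        "hidden_dim": 12288,
        "head_dim": 128,
        "qk_norm": True,
        "n_kv_groups": 8,
        "rope_base": 1_000_000.0,
        "dtype": torch.bfloat16,
    }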
@@ -64,7 +64,7 @@ class Llama3Model(nn.Module):
         self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
 
-        # Reusuable utilities
+        # Reusable utilities
         cos, sin = compute_rope_params(
             head_dim=cfg["emb_dim"] // cfg["n_heads"],
             theta_base=cfg["rope_base"],
@@ -5,6 +5,7 @@
 
 import os
 import json
+import re
 import urllib.request
 from pathlib import Path
 
@@ -115,7 +116,7 @@ class Qwen3Model(nn.Module):
         self.final_norm = RMSNorm(cfg["emb_dim"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
 
-        # Reusuable utilities
+        # Reusable utilities
         if cfg["head_dim"] is None:
             head_dim = cfg["emb_dim"] // cfg["n_heads"]
         else:
@@ -408,52 +409,77 @@ def load_weights_into_qwen(model, param_config, params):
         model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
 
 
-class Qwen3Tokenizer():
-    def __init__(self, tokenizer_file_path="tokenizer.json",
-                 repo_id=None, apply_chat_template=True,
-                 add_generation_prompt=False, add_thinking=False):
+class Qwen3Tokenizer:
+    _SPECIALS = [
+        "<|endoftext|>",
+        "<|im_start|>", "<|im_end|>",
+        "<|object_ref_start|>", "<|object_ref_end|>",
+        "<|box_start|>", "<|box_end|>",
+        "<|quad_start|>", "<|quad_end|>",
+        "<|vision_start|>", "<|vision_end|>",
+        "<|vision_pad|>", "<|image_pad|>", "<|video_pad|>",
+    ]
+    _SPLIT_RE = re.compile(r"(<\|[^>]+?\|>)")
+
+    def __init__(self, tokenizer_file_path="tokenizer.json", repo_id=None,
+                 apply_chat_template=True, add_generation_prompt=False, add_thinking=False):
         from tokenizers import Tokenizer
-        self.tokenizer_file_path = tokenizer_file_path
         self.apply_chat_template = apply_chat_template
         self.add_generation_prompt = add_generation_prompt
         self.add_thinking = add_thinking
 
-        tokenizer_file_path_obj = Path(tokenizer_file_path)
-        if not tokenizer_file_path_obj.is_file() and repo_id is not None:
-            _ = download_from_huggingface(
+        tok_file = Path(tokenizer_file_path)
+        if not tok_file.is_file() and repo_id:
+            download_from_huggingface(
                 repo_id=repo_id,
-                filename=str(tokenizer_file_path_obj.name),
-                local_dir=str(tokenizer_file_path_obj.parent.name)
+                filename=tok_file.name,
+                local_dir=str(tok_file.parent),
             )
-        self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
+        self._tok = Tokenizer.from_file(str(tok_file))
+        self._special_to_id = {t: self._tok.token_to_id(t) for t in self._SPECIALS}
 
-    def encode(self, prompt):
-        if self.apply_chat_template:
-            messages = [{"role": "user", "content": prompt}]
-            formatted_prompt = self.format_qwen_chat(
-                messages,
-                add_generation_prompt=self.add_generation_prompt,
-                add_thinking=self.add_thinking
-            )
+        self.pad_token_id = self._special_to_id.get("<|endoftext|>")
+        self.eos_token_id = self.pad_token_id
+
+        if repo_id and "Base" not in repo_id:
+            eos_token = "<|im_end|>"
         else:
-            formatted_prompt = prompt
-        return self.tokenizer.encode(formatted_prompt).ids
+            eos_token = "<|endoftext|>"
+        if eos_token in self._special_to_id:
+            self.eos_token_id = self._special_to_id[eos_token]
 
-    def decode(self, token_ids):
-        return self.tokenizer.decode(token_ids, skip_special_tokens=False)
+    def encode(self, text, chat_wrapped=None):
+        if chat_wrapped is None:
+            chat_wrapped = self.apply_chat_template
 
-    @staticmethod
-    def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
-        prompt = ""
-        for msg in messages:
-            prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
-        if add_generation_prompt:
-            prompt += "<|im_start|>assistant"
-            if add_thinking:
-                prompt += "\n"  # no <think> tags
+        stripped = text.strip()
+        if stripped in self._special_to_id and "\n" not in stripped:
+            return [self._special_to_id[stripped]]
+
+        if chat_wrapped:
+            text = self._wrap_chat(text)
+
+        ids = []
+        for part in filter(None, self._SPLIT_RE.split(text)):
+            if part in self._special_to_id:
+                ids.append(self._special_to_id[part])
             else:
-                prompt += "\n<think>\n\n</think>\n\n"
-        return prompt
+                ids.extend(self._tok.encode(part).ids)
+        return ids
+
+    def decode(self, ids):
+        return self._tok.decode(ids, skip_special_tokens=False)
+
+    def _wrap_chat(self, user_msg):
+        s = f"<|im_start|>user\n{user_msg}<|im_end|>\n"
+        if self.add_generation_prompt:
+            s += "<|im_start|>assistant"
+            if self.add_thinking:
+                s += "\n"
+            else:
+                s += "\n<think>\n\n</think>\n\n"
+        return s
 
 
 def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
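For orientation, a minimal usage sketch of the rewritten tokenizer follows; the file name and repo id are borrowed from the updated tests later in this diff, and the snippet is illustrative rather than part of the commit:

    from llms_from_scratch.qwen3 import Qwen3Tokenizer

    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
        repo_id="Qwen/Qwen3-0.6B",   # tokenizer.json is downloaded if the local file is missing
        add_generation_prompt=True,
        add_thinking=True,
    )
    ids = tokenizer.encode("Give me a short introduction to large language models.")
    text = tokenizer.decode(ids)     # special tokens such as <|im_start|> are kept in the output

The _SPLIT_RE pattern keeps <|...|> markers as separate segments during encoding, so each registered special token maps to a single id, and eos_token_id and pad_token_id are now exposed as attributes.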
@@ -15,6 +15,8 @@ from llms_from_scratch.qwen3 import (
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
 from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
 
+# from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
+# from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
 
 import importlib
 import pytest
@@ -113,7 +115,7 @@ def qwen3_weights_path(tmp_path_factory):
 
 
 @pytest.mark.parametrize("ModelClass", [Qwen3Model, Qwen3ModelKV])
-@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
+@pytest.mark.parametrize("generate_fn", [generate_text_simple])
 def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
 
     torch.manual_seed(123)
@@ -137,7 +139,7 @@ def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
     print("Encoded input text:", input_token_ids)
     print("encoded_tensor.shape:", input_token_ids.shape)
 
-    out = generate_text_simple(
+    out = generate_fn(
         model=model,
         idx=input_token_ids,
         max_new_tokens=5,
@@ -152,6 +154,47 @@ def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
     assert torch.equal(expect, out)
 
 
+def test_model_KV_noKV(qwen3_weights_path):
+
+    torch.manual_seed(123)
+    model_KV = Qwen3ModelKV(QWEN_CONFIG_06_B)
+    model_KV.load_state_dict(torch.load(qwen3_weights_path))
+    model_KV.eval()
+
+    tokenizer = Qwen3Tokenizer(
+        tokenizer_file_path="tokenizer-base.json",
+        repo_id="rasbt/qwen3-from-scratch",
+        add_generation_prompt=False,
+        add_thinking=False
+    )
+
+    prompt = "Give me a short introduction to large language models."
+    input_token_ids = tokenizer.encode(prompt)
+    input_token_ids = torch.tensor([input_token_ids])
+
+    out_noKV = generate_text_simple_cached(
+        model=model_KV,
+        idx=input_token_ids,
+        max_new_tokens=5,
+        context_size=QWEN_CONFIG_06_B["context_length"]
+    )
+    del model_KV
+
+    torch.manual_seed(123)
+    model_noKV = Qwen3Model(QWEN_CONFIG_06_B)
+    model_noKV.load_state_dict(torch.load(qwen3_weights_path))
+    model_noKV.eval()
+
+    out_KV = generate_text_simple(
+        model=model_noKV,
+        idx=input_token_ids,
+        max_new_tokens=5,
+        context_size=QWEN_CONFIG_06_B["context_length"]
+    )
+
+    assert torch.equal(out_noKV, out_KV)
+
+
 def test_rmsnorm_equivalence():
     torch.manual_seed(42)
 
@@ -177,13 +220,16 @@ def test_rmsnorm_equivalence():
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
 def test_tokenizer_equivalence():
     from transformers import AutoTokenizer
-    repo_id = "Qwen/Qwen3-0.6B"
-    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
     prompt = "Give me a short introduction to large language models."
     messages = [
         {"role": "user", "content": prompt},
     ]
 
+    # Reasoning model tokenizer
+    repo_id = "Qwen/Qwen3-0.6B"
+    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
     for states in ((True, True), (False, False)):
         tokenizer = Qwen3Tokenizer(
             tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
@@ -203,3 +249,33 @@ def test_tokenizer_equivalence():
         output_text = tokenizer.decode(input_token_ids)
         out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
         assert output_text == out_text_ref, states
+
+        assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
+        assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id
+
+    # Base model tokenizer
+    repo_id = "Qwen/Qwen3-0.6B-Base"
+    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
+    for states in ((True, True), (False, False)):
+        tokenizer = Qwen3Tokenizer(
+            tokenizer_file_path="Qwen3-0.6B-Base/tokenizer.json",
+            repo_id=repo_id,
+            add_generation_prompt=states[0],
+            add_thinking=states[1]
+        )
+        input_token_ids = tokenizer.encode(prompt)
+        input_token_ids_ref = tokenizer_ref.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=states[0],
+            enable_thinking=states[1],
+        )
+        assert input_token_ids == input_token_ids_ref, states
+
+        output_text = tokenizer.decode(input_token_ids)
+        out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
+        assert output_text == out_text_ref, states
+
+        assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
+        assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id
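The base-model checks above exercise the new eos/pad handling; a small sketch of the resulting difference, assuming the same repo ids as in the test (illustrative only, not part of the commit):

    from llms_from_scratch.qwen3 import Qwen3Tokenizer

    # Per the new __init__ logic: a repo id without "Base" ends on <|im_end|>,
    # while a "Base" repo id falls back to <|endoftext|>; both pad with <|endoftext|>.
    tok_reasoning = Qwen3Tokenizer("Qwen3-0.6B/tokenizer.json", repo_id="Qwen/Qwen3-0.6B")
    tok_base = Qwen3Tokenizer("Qwen3-0.6B-Base/tokenizer.json", repo_id="Qwen/Qwen3-0.6B-Base")
    print(tok_reasoning.eos_token_id, tok_base.eos_token_id)   # differ for this pair of checkpoints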