From 14fa50dfc855aa6000e8da42b95324873f8b0598 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Wed, 9 Jul 2025 13:16:26 -0500 Subject: [PATCH] Add more sophisticated Qwen3 tokenizer (#729) --- ch05/11_qwen3/standalone-qwen3.ipynb | 15 ---- pkg/llms_from_scratch/llama3.py | 2 +- pkg/llms_from_scratch/qwen3.py | 96 ++++++++++++++--------- pkg/llms_from_scratch/tests/test_qwen3.py | 84 +++++++++++++++++++- 4 files changed, 142 insertions(+), 55 deletions(-) diff --git a/ch05/11_qwen3/standalone-qwen3.ipynb b/ch05/11_qwen3/standalone-qwen3.ipynb index 848cc60..5b30dd1 100644 --- a/ch05/11_qwen3/standalone-qwen3.ipynb +++ b/ch05/11_qwen3/standalone-qwen3.ipynb @@ -487,21 +487,6 @@ " \"dtype\": torch.bfloat16,\n", " } \n", "\n", - "elif CHOOSE_MODEL == \"8B\":\n", - " QWEN3_CONFIG = {\n", - " \"vocab_size\": 151_936,\n", - " \"context_length\": 40_960,\n", - " \"emb_dim\": 4096, # 60% larger than above\n", - " \"n_heads\": 32,\n", - " \"n_layers\": 36, # 26% larger than above\n", - " \"hidden_dim\": 12288,\n", - " \"head_dim\": 128,\n", - " \"qk_norm\": True,\n", - " \"n_kv_groups\": 8,\n", - " \"rope_base\": 1_000_000.0,\n", - " \"dtype\": torch.bfloat16,\n", - " } \n", - "\n", "elif CHOOSE_MODEL == \"14B\":\n", " QWEN3_CONFIG = {\n", " \"vocab_size\": 151_936,\n", diff --git a/pkg/llms_from_scratch/llama3.py b/pkg/llms_from_scratch/llama3.py index 88509e1..ddd4cde 100644 --- a/pkg/llms_from_scratch/llama3.py +++ b/pkg/llms_from_scratch/llama3.py @@ -64,7 +64,7 @@ class Llama3Model(nn.Module): self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"]) - # Reusuable utilities + # Reusable utilities cos, sin = compute_rope_params( head_dim=cfg["emb_dim"] // cfg["n_heads"], theta_base=cfg["rope_base"], diff --git a/pkg/llms_from_scratch/qwen3.py b/pkg/llms_from_scratch/qwen3.py index dbdc330..33cf047 100644 --- a/pkg/llms_from_scratch/qwen3.py +++ b/pkg/llms_from_scratch/qwen3.py @@ -5,6 +5,7 @@ import os import json +import re import urllib.request from pathlib import Path @@ -115,7 +116,7 @@ class Qwen3Model(nn.Module): self.final_norm = RMSNorm(cfg["emb_dim"]) self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"]) - # Reusuable utilities + # Reusable utilities if cfg["head_dim"] is None: head_dim = cfg["emb_dim"] // cfg["n_heads"] else: @@ -408,52 +409,77 @@ def load_weights_into_qwen(model, param_config, params): model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight") -class Qwen3Tokenizer(): - def __init__(self, tokenizer_file_path="tokenizer.json", - repo_id=None, apply_chat_template=True, - add_generation_prompt=False, add_thinking=False): +class Qwen3Tokenizer: + _SPECIALS = [ + "<|endoftext|>", + "<|im_start|>", "<|im_end|>", + "<|object_ref_start|>", "<|object_ref_end|>", + "<|box_start|>", "<|box_end|>", + "<|quad_start|>", "<|quad_end|>", + "<|vision_start|>", "<|vision_end|>", + "<|vision_pad|>", "<|image_pad|>", "<|video_pad|>", + ] + _SPLIT_RE = re.compile(r"(<\|[^>]+?\|>)") + + def __init__(self, tokenizer_file_path="tokenizer.json", repo_id=None, + apply_chat_template=True, add_generation_prompt=False, add_thinking=False): from tokenizers import Tokenizer - self.tokenizer_file_path = tokenizer_file_path + self.apply_chat_template = apply_chat_template self.add_generation_prompt = add_generation_prompt self.add_thinking = add_thinking - tokenizer_file_path_obj = 
Path(tokenizer_file_path)
-        if not tokenizer_file_path_obj.is_file() and repo_id is not None:
-            _ = download_from_huggingface(
+        tok_file = Path(tokenizer_file_path)
+        if not tok_file.is_file() and repo_id:
+            download_from_huggingface(
                 repo_id=repo_id,
-                filename=str(tokenizer_file_path_obj.name),
-                local_dir=str(tokenizer_file_path_obj.parent.name)
+                filename=tok_file.name,
+                local_dir=str(tok_file.parent),
             )
-        self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
+        self._tok = Tokenizer.from_file(str(tok_file))
+        self._special_to_id = {t: self._tok.token_to_id(t) for t in self._SPECIALS}
 
-    def encode(self, prompt):
-        if self.apply_chat_template:
-            messages = [{"role": "user", "content": prompt}]
-            formatted_prompt = self.format_qwen_chat(
-                messages,
-                add_generation_prompt=self.add_generation_prompt,
-                add_thinking=self.add_thinking
-            )
+        self.pad_token_id = self._special_to_id.get("<|endoftext|>")
+        self.eos_token_id = self.pad_token_id
+
+        if repo_id and "Base" not in repo_id:
+            eos_token = "<|im_end|>"
         else:
-            formatted_prompt = prompt
-        return self.tokenizer.encode(formatted_prompt).ids
+            eos_token = "<|endoftext|>"
+        if eos_token in self._special_to_id:
+            self.eos_token_id = self._special_to_id[eos_token]
 
-    def decode(self, token_ids):
-        return self.tokenizer.decode(token_ids, skip_special_tokens=False)
+    def encode(self, text, chat_wrapped=None):
+        if chat_wrapped is None:
+            chat_wrapped = self.apply_chat_template
 
-    @staticmethod
-    def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
-        prompt = ""
-        for msg in messages:
-            prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
-        if add_generation_prompt:
-            prompt += "<|im_start|>assistant"
-            if add_thinking:
-                prompt += "\n"  # no <think> tags
+        stripped = text.strip()
+        if stripped in self._special_to_id and "\n" not in stripped:
+            return [self._special_to_id[stripped]]
+
+        if chat_wrapped:
+            text = self._wrap_chat(text)
+
+        ids = []
+        for part in filter(None, self._SPLIT_RE.split(text)):
+            if part in self._special_to_id:
+                ids.append(self._special_to_id[part])
             else:
-                prompt += "\n<think>\n\n</think>\n\n"
-        return prompt
+                ids.extend(self._tok.encode(part).ids)
+        return ids
+
+    def decode(self, ids):
+        return self._tok.decode(ids, skip_special_tokens=False)
+
+    def _wrap_chat(self, user_msg):
+        s = f"<|im_start|>user\n{user_msg}<|im_end|>\n"
+        if self.add_generation_prompt:
+            s += "<|im_start|>assistant"
+            if self.add_thinking:
+                s += "\n"
+            else:
+                s += "\n<think>\n\n</think>\n\n"
+        return s
 
 
 def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
diff --git a/pkg/llms_from_scratch/tests/test_qwen3.py b/pkg/llms_from_scratch/tests/test_qwen3.py
index 8ca6f14..50f308f 100644
--- a/pkg/llms_from_scratch/tests/test_qwen3.py
+++ b/pkg/llms_from_scratch/tests/test_qwen3.py
@@ -15,6 +15,8 @@ from llms_from_scratch.qwen3 import (
 
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
 from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
+# from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
+# from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
 
 import importlib
 import pytest
@@ -113,7 +115,7 @@ def qwen3_weights_path(tmp_path_factory):
 
 
 @pytest.mark.parametrize("ModelClass", [Qwen3Model, Qwen3ModelKV])
-@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
+@pytest.mark.parametrize("generate_fn", [generate_text_simple])
 def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
 
     torch.manual_seed(123)
@@ -137,7 +139,7 @@ def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
     print("Encoded input text:", input_token_ids)
     print("encoded_tensor.shape:", input_token_ids.shape)
 
-    out = generate_text_simple(
+    out = generate_fn(
         model=model,
         idx=input_token_ids,
         max_new_tokens=5,
@@ -152,6 +154,47 @@ def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
     assert torch.equal(expect, out)
 
 
+def test_model_KV_noKV(qwen3_weights_path):
+
+    torch.manual_seed(123)
+    model_KV = Qwen3ModelKV(QWEN_CONFIG_06_B)
+    model_KV.load_state_dict(torch.load(qwen3_weights_path))
+    model_KV.eval()
+
+    tokenizer = Qwen3Tokenizer(
+        tokenizer_file_path="tokenizer-base.json",
+        repo_id="rasbt/qwen3-from-scratch",
+        add_generation_prompt=False,
+        add_thinking=False
+    )
+
+    prompt = "Give me a short introduction to large language models."
+    input_token_ids = tokenizer.encode(prompt)
+    input_token_ids = torch.tensor([input_token_ids])
+
+    out_KV = generate_text_simple_cached(
+        model=model_KV,
+        idx=input_token_ids,
+        max_new_tokens=5,
+        context_size=QWEN_CONFIG_06_B["context_length"]
+    )
+    del model_KV
+
+    torch.manual_seed(123)
+    model_noKV = Qwen3Model(QWEN_CONFIG_06_B)
+    model_noKV.load_state_dict(torch.load(qwen3_weights_path))
+    model_noKV.eval()
+
+    out_noKV = generate_text_simple(
+        model=model_noKV,
+        idx=input_token_ids,
+        max_new_tokens=5,
+        context_size=QWEN_CONFIG_06_B["context_length"]
+    )
+
+    assert torch.equal(out_KV, out_noKV)
+
+
 def test_rmsnorm_equivalence():
 
     torch.manual_seed(42)
@@ -177,13 +220,16 @@
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
 def test_tokenizer_equivalence():
     from transformers import AutoTokenizer
-    repo_id = "Qwen/Qwen3-0.6B"
-    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
     prompt = "Give me a short introduction to large language models."
     messages = [
         {"role": "user", "content": prompt},
     ]
 
+    # Reasoning model tokenizer
+    repo_id = "Qwen/Qwen3-0.6B"
+    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
     for states in ((True, True), (False, False)):
         tokenizer = Qwen3Tokenizer(
             tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
@@ -203,3 +249,33 @@
         output_text = tokenizer.decode(input_token_ids)
         out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
         assert output_text == out_text_ref, states
+
+    assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
+    assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id
+
+    # Base model tokenizer
+    repo_id = "Qwen/Qwen3-0.6B-Base"
+    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+
+    for states in ((True, True), (False, False)):
+        tokenizer = Qwen3Tokenizer(
+            tokenizer_file_path="Qwen3-0.6B-Base/tokenizer.json",
+            repo_id=repo_id,
+            add_generation_prompt=states[0],
+            add_thinking=states[1]
+        )
+        input_token_ids = tokenizer.encode(prompt)
+        input_token_ids_ref = tokenizer_ref.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=states[0],
+            enable_thinking=states[1],
+        )
+        assert input_token_ids == input_token_ids_ref, states
+
+        output_text = tokenizer.decode(input_token_ids)
+        out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
+        assert output_text == out_text_ref, states
+
+    assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
+    assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id
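
For reference, a minimal usage sketch of the Qwen3Tokenizer API introduced by this patch (not part of the diff). The tokenizer path and repo id below are illustrative and simply mirror the values used in the tests; any local path pointing at a Qwen3 tokenizer.json works the same way.

    from llms_from_scratch.qwen3 import Qwen3Tokenizer

    # If tokenizer.json is not found locally, it is fetched from the given
    # Hugging Face repo via download_from_huggingface (path/repo id are example values).
    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
        repo_id="Qwen/Qwen3-0.6B",
        apply_chat_template=True,     # wrap prompts in <|im_start|>user ... <|im_end|>
        add_generation_prompt=True,   # append the <|im_start|>assistant header
        add_thinking=False,           # controls the <think> block in the generation prompt
    )

    prompt = "Give me a short introduction to large language models."
    ids = tokenizer.encode(prompt)    # chat-wrapped token ids; special tokens map to single ids
    text = tokenizer.decode(ids)      # round-trips with special tokens preserved
    print(tokenizer.eos_token_id, tokenizer.pad_token_id)  # <|im_end|> / <|endoftext|> for non-Base repos

The tokenizer now also exposes eos_token_id and pad_token_id, which test_tokenizer_equivalence checks against the Hugging Face reference tokenizer for both the reasoning and the Base model variants.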