Handle other Qwen3 tokenizer settings (#716)

Sebastian Raschka 2025-06-30 17:49:51 -05:00 committed by GitHub
parent 4e61dc4224
commit 0405b0c8e7
3 changed files with 16 additions and 25 deletions


@@ -410,15 +410,11 @@ def load_weights_into_qwen(model, param_config, params):
 class Qwen3Tokenizer():
     def __init__(self, tokenizer_file_path="tokenizer.json",
-                 repo_id=None, add_generation_prompt=False, add_thinking=False):
+                 repo_id=None, apply_chat_template=True,
+                 add_generation_prompt=False, add_thinking=False):
         from tokenizers import Tokenizer
         self.tokenizer_file_path = tokenizer_file_path
-        if add_generation_prompt != add_thinking:
-            raise ValueError(
-                "Only add_generation_prompt==add_thinking settings are currently supported"
-            )
+        self.apply_chat_template = apply_chat_template
         self.add_generation_prompt = add_generation_prompt
         self.add_thinking = add_thinking
@@ -432,14 +428,15 @@ class Qwen3Tokenizer():
         self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
 
     def encode(self, prompt):
-        messages = [
-            {"role": "user", "content": prompt}
-        ]
-        formatted_prompt = self.format_qwen_chat(
-            messages,
-            add_generation_prompt=self.add_generation_prompt,
-            add_thinking=self.add_thinking
-        )
+        if self.apply_chat_template:
+            messages = [{"role": "user", "content": prompt}]
+            formatted_prompt = self.format_qwen_chat(
+                messages,
+                add_generation_prompt=self.add_generation_prompt,
+                add_thinking=self.add_thinking
+            )
+        else:
+            formatted_prompt = prompt
         return self.tokenizer.encode(formatted_prompt).ids
 
     def decode(self, token_ids):
@@ -452,10 +449,10 @@ class Qwen3Tokenizer():
             prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
         if add_generation_prompt:
             prompt += "<|im_start|>assistant"
-            if not add_thinking:
-                prompt += "<|think>\n\n<|/think>\n\n"
+            if add_thinking:
+                prompt += "\n"  # no <think> tags
             else:
-                prompt += "\n"
+                prompt += "\n<think>\n\n</think>\n\n"
         return prompt
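
A minimal usage sketch of the changed behavior, assuming the class is importable as shown (the import path and prompt text are assumptions; the constructor arguments mirror the diff above). With apply_chat_template=False the prompt is now encoded verbatim, and add_generation_prompt/add_thinking no longer have to match, since the old ValueError guard is gone:

from llms_from_scratch.qwen3 import Qwen3Tokenizer  # assumed module path

# Encode the raw prompt without any chat-template wrapping
plain_tok = Qwen3Tokenizer("tokenizer.json", apply_chat_template=False)
plain_ids = plain_tok.encode("Give me a short introduction to large language models.")

# Wrap the prompt in the Qwen chat template; add_thinking=False appends
# an empty <think></think> block after the assistant tag
chat_tok = Qwen3Tokenizer(
    "tokenizer.json",
    apply_chat_template=True,
    add_generation_prompt=True,
    add_thinking=False,
)
chat_ids = chat_tok.encode("Give me a short introduction to large language models.")
print(chat_tok.decode(chat_ids))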


@@ -117,12 +117,6 @@ def qwen3_weights_path(tmp_path_factory):
 @pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
 def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
-    # Skip incompatible combinations
-    if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
-        return
-    if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
-        return
-
     torch.manual_seed(123)
     model = ModelClass(QWEN_CONFIG_06_B)
     model.load_state_dict(torch.load(qwen3_weights_path))
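
With the skip guard removed, the stacked parametrize decorators run every (ModelClass, generate_fn) combination. A self-contained sketch of that cross-product behavior, with placeholder names rather than the project's real classes and generation functions:

import pytest

class ModelA: ...
class ModelB: ...

def gen_simple(model):
    return "simple"

def gen_cached(model):
    return "cached"

# Stacked parametrize decorators produce the full cross product:
# 2 models x 2 generation functions = 4 generated test cases.
@pytest.mark.parametrize("ModelClass", [ModelA, ModelB])
@pytest.mark.parametrize("generate_fn", [gen_simple, gen_cached])
def test_grid(ModelClass, generate_fn):
    assert generate_fn(ModelClass()) in {"simple", "cached"}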


@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "llms-from-scratch"
-version = "1.0.15"
+version = "1.0.16"
 description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
 readme = "README.md"
 requires-python = ">=3.10"