Mirror of https://github.com/rasbt/LLMs-from-scratch.git

Handle other Qwen3 tokenizer settings (#716)

Commit 0405b0c8e7 (parent 4e61dc4224)
In the package's qwen3.py, Qwen3Tokenizer gains an apply_chat_template flag, and the constructor no longer rejects mixed add_generation_prompt/add_thinking settings:

@@ -410,15 +410,11 @@ def load_weights_into_qwen(model, param_config, params):

 class Qwen3Tokenizer():
     def __init__(self, tokenizer_file_path="tokenizer.json",
-                 repo_id=None, add_generation_prompt=False, add_thinking=False):
+                 repo_id=None, apply_chat_template=True,
+                 add_generation_prompt=False, add_thinking=False):
         from tokenizers import Tokenizer
         self.tokenizer_file_path = tokenizer_file_path

-        if add_generation_prompt != add_thinking:
-            raise ValueError(
-                "Only add_generation_prompt==add_thinking settings are currently supported"
-            )
-
+        self.apply_chat_template = apply_chat_template
         self.add_generation_prompt = add_generation_prompt
         self.add_thinking = add_thinking
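For context, a minimal usage sketch of the new constructor (not part of the commit; the import path llms_from_scratch.qwen3 is assumed from the package name below, and a local tokenizer.json is assumed to exist):

from llms_from_scratch.qwen3 import Qwen3Tokenizer  # assumed module path

# Default: chat template applied, no generation prompt, no thinking block
tok = Qwen3Tokenizer(tokenizer_file_path="tokenizer.json")

# Bypass the chat template entirely (new in this commit)
raw_tok = Qwen3Tokenizer(tokenizer_file_path="tokenizer.json",
                         apply_chat_template=False)

# Mixed settings no longer raise ValueError now that the
# add_generation_prompt == add_thinking check is gone
mixed_tok = Qwen3Tokenizer(tokenizer_file_path="tokenizer.json",
                           add_generation_prompt=True,
                           add_thinking=False)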
encode() now applies the chat template only when apply_chat_template is set; otherwise the prompt is tokenized verbatim:

@@ -432,14 +428,15 @@ class Qwen3Tokenizer():
         self.tokenizer = Tokenizer.from_file(tokenizer_file_path)

     def encode(self, prompt):
-        messages = [
-            {"role": "user", "content": prompt}
-        ]
-        formatted_prompt = self.format_qwen_chat(
-            messages,
-            add_generation_prompt=self.add_generation_prompt,
-            add_thinking=self.add_thinking
-        )
+        if self.apply_chat_template:
+            messages = [{"role": "user", "content": prompt}]
+            formatted_prompt = self.format_qwen_chat(
+                messages,
+                add_generation_prompt=self.add_generation_prompt,
+                add_thinking=self.add_thinking
+            )
+        else:
+            formatted_prompt = prompt
         return self.tokenizer.encode(formatted_prompt).ids

     def decode(self, token_ids):
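Continuing the sketch above, the two encode() paths would behave roughly as follows (the exact token IDs depend on the tokenizer.json in use):

prompt = "Explain KV caches."

# apply_chat_template=True (default): the template is applied first, so
# the IDs cover "<|im_start|>user\nExplain KV caches.<|im_end|>\n"
ids_chat = tok.encode(prompt)

# apply_chat_template=False: the raw string is tokenized as-is
ids_raw = raw_tok.encode(prompt)

# decode() reverses the mapping; for plain text the raw path should
# round-trip to the original prompt
print(raw_tok.decode(ids_raw))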
format_qwen_chat() inverts the thinking branch and fixes the tag syntax: the old code emitted malformed <|think> tags, while the new code appends an empty <think>...</think> block when thinking is disabled:

@@ -452,10 +449,10 @@ class Qwen3Tokenizer():
             prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
         if add_generation_prompt:
             prompt += "<|im_start|>assistant"
-            if not add_thinking:
-                prompt += "<|think>\n\n<|/think>\n\n"
+            if add_thinking:
+                prompt += "\n"  # no <think> tags
             else:
-                prompt += "\n"
+                prompt += "\n<think>\n\n</think>\n\n"
         return prompt
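To make the new branch concrete, here is what the template produces for each setting (a sketch continuing from the constructor example above; the string literals follow the diff):

messages = [{"role": "user", "content": "Hi"}]

# add_thinking=True: end with a bare assistant header plus newline,
# leaving the model free to open its own <think> block
tok.format_qwen_chat(messages, add_generation_prompt=True, add_thinking=True)
# -> '<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n'

# add_thinking=False: pre-fill an empty <think></think> block so the
# model answers directly without a reasoning section
tok.format_qwen_chat(messages, add_generation_prompt=True, add_thinking=False)
# -> '<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n'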
In the package tests (test_qwen3.py), the manual skip of model/generate-function pairings is dropped, so every parametrized combination now runs:

@@ -117,12 +117,6 @@ def qwen3_weights_path(tmp_path_factory):
 @pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
 def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):

-    # Skip incompatible combinations
-    if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
-        return
-    if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
-        return
-
     torch.manual_seed(123)
     model = ModelClass(QWEN_CONFIG_06_B)
     model.load_state_dict(torch.load(qwen3_weights_path))
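As a reminder of why every pairing must now pass: stacked pytest parametrize decorators run the full cross-product of their values. A minimal, self-contained sketch with illustrative names (not from the test file):

import pytest

@pytest.mark.parametrize("model_kind", ["plain", "kv_cache"])    # hypothetical values
@pytest.mark.parametrize("generate_kind", ["simple", "cached"])  # hypothetical values
def test_grid(model_kind, generate_kind):
    # Four test cases are generated: each model_kind paired with each generate_kind
    assert model_kind in {"plain", "kv_cache"}
    assert generate_kind in {"simple", "cached"}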
Finally, pyproject.toml bumps the package version:

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "llms-from-scratch"
-version = "1.0.15"
+version = "1.0.16"
 description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
 readme = "README.md"
 requires-python = ">=3.10"
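Once 1.0.16 is published, consumers of the PyPI package can pick up the tokenizer changes with pip install --upgrade llms-from-scratch (a plausible consumer-side step, not part of the commit itself).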