diff --git a/pkg/llms_from_scratch/qwen3.py b/pkg/llms_from_scratch/qwen3.py
index b6060ac..dbdc330 100644
--- a/pkg/llms_from_scratch/qwen3.py
+++ b/pkg/llms_from_scratch/qwen3.py
@@ -410,15 +410,11 @@ def load_weights_into_qwen(model, param_config, params):
 
 class Qwen3Tokenizer():
     def __init__(self, tokenizer_file_path="tokenizer.json",
-                 repo_id=None, add_generation_prompt=False, add_thinking=False):
+                 repo_id=None, apply_chat_template=True,
+                 add_generation_prompt=False, add_thinking=False):
         from tokenizers import Tokenizer
         self.tokenizer_file_path = tokenizer_file_path
-
-        if add_generation_prompt != add_thinking:
-            raise ValueError(
-                "Only add_generation_prompt==add_thinking settings are currently supported"
-            )
-
+        self.apply_chat_template = apply_chat_template
         self.add_generation_prompt = add_generation_prompt
         self.add_thinking = add_thinking
 
@@ -432,14 +428,15 @@ class Qwen3Tokenizer():
         self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
 
     def encode(self, prompt):
-        messages = [
-            {"role": "user", "content": prompt}
-        ]
-        formatted_prompt = self.format_qwen_chat(
-            messages,
-            add_generation_prompt=self.add_generation_prompt,
-            add_thinking=self.add_thinking
-        )
+        if self.apply_chat_template:
+            messages = [{"role": "user", "content": prompt}]
+            formatted_prompt = self.format_qwen_chat(
+                messages,
+                add_generation_prompt=self.add_generation_prompt,
+                add_thinking=self.add_thinking
+            )
+        else:
+            formatted_prompt = prompt
         return self.tokenizer.encode(formatted_prompt).ids
 
     def decode(self, token_ids):
@@ -452,10 +449,10 @@ class Qwen3Tokenizer():
             prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
         if add_generation_prompt:
             prompt += "<|im_start|>assistant"
-            if not add_thinking:
-                prompt += "<|think>\n\n<|/think>\n\n"
+            if add_thinking:
+                prompt += "\n"  # no <think> tags
             else:
-                prompt += "\n"
+                prompt += "\n<think>\n\n</think>\n\n"
         return prompt
 
 
diff --git a/pkg/llms_from_scratch/tests/test_qwen3.py b/pkg/llms_from_scratch/tests/test_qwen3.py
index e375db5..43ff4a6 100644
--- a/pkg/llms_from_scratch/tests/test_qwen3.py
+++ b/pkg/llms_from_scratch/tests/test_qwen3.py
@@ -117,12 +117,6 @@ def qwen3_weights_path(tmp_path_factory):
 @pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
 def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
 
-    # Skip incompatible combinations
-    if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
-        return
-    if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
-        return
-
     torch.manual_seed(123)
     model = ModelClass(QWEN_CONFIG_06_B)
     model.load_state_dict(torch.load(qwen3_weights_path))
diff --git a/pyproject.toml b/pyproject.toml
index 7f0a4d4..d7471f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "llms-from-scratch"
-version = "1.0.15"
+version = "1.0.16"
 description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
 readme = "README.md"
 requires-python = ">=3.10"
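
Usage sketch (not part of the patch): the snippet below shows how the reworked Qwen3Tokenizer options are meant to combine after this change. The tokenizer file path and repo id are illustrative assumptions, not values taken from the diff.

    from llms_from_scratch.qwen3 import Qwen3Tokenizer

    # Chat-template path (the default): the prompt is wrapped in
    # <|im_start|>/<|im_end|> markers; with add_generation_prompt=True and
    # add_thinking=False, the corrected "\n<think>\n\n</think>\n\n" suffix
    # is appended (the old code emitted malformed <|think>/<|/think> tags).
    tok = Qwen3Tokenizer(
        tokenizer_file_path="tokenizer.json",  # assumed local path
        repo_id="Qwen/Qwen3-0.6B",             # assumed repo id
        apply_chat_template=True,
        add_generation_prompt=True,
        add_thinking=False,
    )
    ids = tok.encode("Give me a short introduction to LLMs.")

    # Raw path: apply_chat_template=False bypasses formatting entirely and
    # tokenizes the prompt as-is. Mixed settings such as
    # add_generation_prompt=True with add_thinking=False also no longer
    # raise a ValueError, since the equality check was removed.
    raw_tok = Qwen3Tokenizer(tokenizer_file_path="tokenizer.json",
                             apply_chat_template=False)
    text = raw_tok.decode(raw_tok.encode("Hello"))  # round-trips the raw prompt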