diff --git a/pkg/llms_from_scratch/qwen3.py b/pkg/llms_from_scratch/qwen3.py index 3e9f726..11542a6 100644 --- a/pkg/llms_from_scratch/qwen3.py +++ b/pkg/llms_from_scratch/qwen3.py @@ -536,14 +536,14 @@ class Qwen3Tokenizer: self._special_to_id = {t: self._tok.token_to_id(t) for t in self._SPECIALS} self.pad_token_id = self._special_to_id.get("<|endoftext|>") - self.eos_token_id = self.pad_token_id - if repo_id and "Base" not in repo_id: - eos_token = "<|im_end|>" + # Match HF behavior: chat model → <|im_end|>, base model → <|endoftext|> + fname = tok_file.name.lower() + if "base" in fname and "reasoning" not in fname: + self.eos_token = "<|endoftext|>" else: - eos_token = "<|endoftext|>" - if eos_token in self._special_to_id: - self.eos_token_id = self._special_to_id[eos_token] + self.eos_token = "<|im_end|>" + self.eos_token_id = self._special_to_id.get(self.eos_token) def encode(self, text, chat_wrapped=None): if chat_wrapped is None: