casinca 42c130623b
Qwen3Tokenizer fix for Qwen3 Base models and generation mismatch with HF (#828)
* prevent `self.apply_chat_template` being applied for base Qwen models

* - added no chat template comparison in `test_chat_wrap_and_equivalence`
- removed duplicate comparison

* Revert "- added no chat template comparison in `test_chat_wrap_and_equivalence`"

This reverts commit 3a5ee8cfa19aa7e4874cd5f35171098be760b05f.

* Revert "prevent `self.apply_chat_template` being applied for base Qwen models"

This reverts commit df504397a8957886c6d6d808615545e37ceffcad.

* copied `download_file` into `utils` from https://github.com/rasbt/reasoning-from-scratch/blob/main/reasoning_from_scratch/utils.py

* added copy of test `def test_tokenizer_equivalence()` from `reasoning-from-scratch` in `test_qwen3.py`

* removed duplicate code fragment in `test_chat_wrap_and_equivalence`

* use apply_chat_template

* add toggle for instruct model

* Update tokenizer usage

---------

Co-authored-by: rasbt <mail@sebastianraschka.com>
2025-09-17 08:14:11 -05:00
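
A minimal sketch of what the toggle means in practice, assuming the `Qwen3Tokenizer` in `llms_from_scratch.qwen3` accepts an `apply_chat_template` keyword (the flag name is taken from the bullets above; the exact constructor signature and file paths may differ):

    from llms_from_scratch.qwen3 import Qwen3Tokenizer

    # Base model: leave the prompt untouched (no chat template applied)
    base_tok = Qwen3Tokenizer("Qwen3-0.6B-Base/tokenizer.json", apply_chat_template=False)

    # Instruct model: wrap the prompt with the chat template before encoding
    chat_tok = Qwen3Tokenizer("Qwen3-0.6B/tokenizer.json", apply_chat_template=True)
    token_ids = chat_tok.encode("Give me a short introduction to large language models.")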

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
# Internal utility functions (not intended for public use)
import ast
import re
import types
from pathlib import Path
import urllib.request
import urllib.parse

import nbformat


def _extract_imports(src: str):
    """Collect top-level import statements from a source string."""
    out = []
    try:
        tree = ast.parse(src)
    except SyntaxError:
        return out
    for node in tree.body:
        if isinstance(node, ast.Import):
            parts = []
            for n in node.names:
                parts.append(f"{n.name} as {n.asname}" if n.asname else n.name)
            out.append("import " + ", ".join(parts))
        elif isinstance(node, ast.ImportFrom):
            module = node.module or ""
            parts = []
            for n in node.names:
                parts.append(f"{n.name} as {n.asname}" if n.asname else n.name)
            level = "." * node.level if getattr(node, "level", 0) else ""
            out.append(f"from {level}{module} import " + ", ".join(parts))
    return out


def _extract_defs_and_classes_from_code(src):
    """Keep only def/class blocks (plus their decorators) from a code cell."""
    lines = src.splitlines()
    kept = []
    i = 0
    while i < len(lines):
        line = lines[i]
        stripped = line.lstrip()
        if stripped.startswith("@"):
            # Keep a decorator only if the next non-blank line starts a def/class
            j = i + 1
            while j < len(lines) and not lines[j].strip():
                j += 1
            if j < len(lines) and lines[j].lstrip().startswith(("def ", "class ")):
                kept.append(line)
                i += 1
                continue
        if stripped.startswith("def ") or stripped.startswith("class "):
            kept.append(line)
            base_indent = len(line) - len(stripped)
            i += 1
            # Copy the body until indentation falls back to the def/class level
            while i < len(lines):
                nxt = lines[i]
                if nxt.strip() == "":
                    kept.append(nxt)
                    i += 1
                    continue
                indent = len(nxt) - len(nxt.lstrip())
                if indent <= base_indent and not nxt.lstrip().startswith(("#", "@")):
                    break
                kept.append(nxt)
                i += 1
            continue
        i += 1

    code = "\n".join(kept)

    # General rule:
    # replace functions defined like `def load_weights_into_xxx(ClassName, ...`
    # with `def load_weights_into_xxx(model, ...`
    code = re.sub(
        r"(def\s+load_weights_into_\w+\s*\()\s*\w+\s*,",
        r"\1model,",
        code
    )
    return code


def import_definitions_from_notebook(nb_dir_or_path, notebook_name=None, *, extra_globals=None):
    """Execute a notebook's imports, functions, and classes inside a fresh module."""
    nb_path = Path(nb_dir_or_path)
    if notebook_name is not None:
        nb_file = nb_path / notebook_name if nb_path.is_dir() else nb_path
    else:
        nb_file = nb_path

    if not nb_file.exists():
        raise FileNotFoundError(f"Notebook not found: {nb_file}")

    nb = nbformat.read(nb_file, as_version=4)

    # Collect unique import statements from all code cells
    import_lines = []
    seen = set()
    for cell in nb.cells:
        if cell.cell_type == "code":
            for line in _extract_imports(cell.source):
                if line not in seen:
                    import_lines.append(line)
                    seen.add(line)

    # Make sure torch is always available to the extracted definitions
    for required in ("import torch", "import torch.nn as nn"):
        if required not in seen:
            import_lines.append(required)
            seen.add(required)

    pieces = []
    for cell in nb.cells:
        if cell.cell_type == "code":
            pieces.append(_extract_defs_and_classes_from_code(cell.source))

    src = "\n\n".join(import_lines + pieces)

    mod_name = nb_file.stem.replace("-", "_").replace(" ", "_") or "notebook_defs"
    mod = types.ModuleType(mod_name)
    if extra_globals:
        mod.__dict__.update(extra_globals)
    exec(src, mod.__dict__)
    return mod


def download_file(url, out_dir="."):
    """Simple file download utility for tests."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    filename = Path(urllib.parse.urlparse(url).path).name
    dest = out_dir / filename
    if dest.exists():
        return dest

    try:
        with urllib.request.urlopen(url) as response:
            with open(dest, "wb") as f:
                f.write(response.read())
        return dest
    except Exception as e:
        raise RuntimeError(f"Failed to download {url}: {e}") from e
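
For context, a hedged sketch of how a test such as `test_qwen3.py` might use these helpers, assuming this module is importable as `utils` (the notebook path, URL, and attribute names below are placeholders, not taken from the repository):

    from utils import download_file, import_definitions_from_notebook

    # Materialize the notebook's imports, defs, and classes as a throwaway module
    mod = import_definitions_from_notebook("../standalone-qwen3.ipynb")
    model = mod.Qwen3Model(mod.QWEN_CONFIG_06_B)  # hypothetical names defined in the notebook

    # Download a test fixture once; later calls return the cached local path
    tokenizer_path = download_file("https://example.com/qwen3/tokenizer.json", out_dir="qwen3-files")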