mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-01-08 21:34:58 +00:00
Update Qwen3 tokenizer test (#727)
* Update Qwen3 tokenizer test * add tokenizers to dev dependencies * add tokenizers to dev dependencies
This commit is contained in:
parent
80c1bb2cf4
commit
b5bd8d2de2
1
.github/workflows/check-links.yml
vendored
1
.github/workflows/check-links.yml
vendored
@ -23,6 +23,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uv sync --dev
|
||||
uv add pytest-ruff pytest-check-links
|
||||
|
||||
- name: Check links
|
||||
|
||||
@ -18,7 +18,6 @@ from llms_from_scratch.kv_cache.generate import generate_text_simple as generate
|
||||
|
||||
import importlib
|
||||
import pytest
|
||||
import tiktoken
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -102,8 +101,8 @@ def test_rope():
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def qwen3_weights_path(tmp_path_factory):
|
||||
"""Creates and saves a deterministic Llama3 model for testing."""
|
||||
path = tmp_path_factory.mktemp("models") / "llama3_test_weights.pt"
|
||||
"""Creates and saves a deterministic model for testing."""
|
||||
path = tmp_path_factory.mktemp("models") / "qwen3_test_weights.pt"
|
||||
|
||||
if not path.exists():
|
||||
torch.manual_seed(123)
|
||||
@ -122,26 +121,33 @@ def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
|
||||
model.load_state_dict(torch.load(qwen3_weights_path))
|
||||
model.eval()
|
||||
|
||||
start_context = "Llamas eat"
|
||||
tokenizer = Qwen3Tokenizer(
|
||||
tokenizer_file_path="tokenizer-base.json",
|
||||
repo_id="rasbt/qwen3-from-scratch",
|
||||
add_generation_prompt=False,
|
||||
add_thinking=False
|
||||
)
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
encoded = tokenizer.encode(start_context)
|
||||
encoded_tensor = torch.tensor(encoded).unsqueeze(0)
|
||||
prompt = "Give me a short introduction to large language models."
|
||||
input_token_ids = tokenizer.encode(prompt)
|
||||
input_token_ids = torch.tensor([input_token_ids])
|
||||
|
||||
print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
|
||||
print("\nInput text:", start_context)
|
||||
print("Encoded input text:", encoded)
|
||||
print("encoded_tensor.shape:", encoded_tensor.shape)
|
||||
print("\nInput text:", prompt)
|
||||
print("Encoded input text:", input_token_ids)
|
||||
print("encoded_tensor.shape:", input_token_ids.shape)
|
||||
|
||||
out = generate_text_simple(
|
||||
model=model,
|
||||
idx=encoded_tensor,
|
||||
idx=input_token_ids,
|
||||
max_new_tokens=5,
|
||||
context_size=QWEN_CONFIG_06_B["context_length"]
|
||||
)
|
||||
print("Encoded output text:", out)
|
||||
expect = torch.tensor([
|
||||
[43, 2543, 292, 4483, 115206, 459, 43010, 104223, 55553]
|
||||
[151644, 872, 198, 35127, 752, 264, 2805, 16800, 311,
|
||||
3460, 4128, 4119, 13, 151645, 198, 112120, 83942, 60483,
|
||||
102652, 7414]
|
||||
])
|
||||
assert torch.equal(expect, out)
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@ dev = [
|
||||
"build>=1.2.2.post1",
|
||||
"llms-from-scratch",
|
||||
"twine>=6.1.0",
|
||||
"tokenizers>=0.21.1",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user