From 14c054d36ce649dbe2ab14eee7b159f30983fa2a Mon Sep 17 00:00:00 2001
From: Daniel Kleine <53251018+d-kleine@users.noreply.github.com>
Date: Sat, 21 Jun 2025 23:07:50 +0200
Subject: [PATCH] added pkg fixes (#676)

Co-authored-by: Sebastian Raschka
---
 pkg/llms_from_scratch/ch07.py              |  2 +-
 pkg/llms_from_scratch/llama3.py            | 19 ++++---------------
 pkg/llms_from_scratch/tests/test_llama3.py |  2 +-
 3 files changed, 6 insertions(+), 17 deletions(-)

diff --git a/pkg/llms_from_scratch/ch07.py b/pkg/llms_from_scratch/ch07.py
index 5a0946a..3d50572 100644
--- a/pkg/llms_from_scratch/ch07.py
+++ b/pkg/llms_from_scratch/ch07.py
@@ -9,7 +9,7 @@
 import psutil
 import urllib
 import torch
-import tqdm
+from tqdm import tqdm
 
 from torch.utils.data import Dataset
 
diff --git a/pkg/llms_from_scratch/llama3.py b/pkg/llms_from_scratch/llama3.py
index 785e8af..21a03e0 100644
--- a/pkg/llms_from_scratch/llama3.py
+++ b/pkg/llms_from_scratch/llama3.py
@@ -309,22 +309,11 @@ class Llama3Tokenizer:
             special_tokens=self.special,
         )
 
-    def encode(self, text, bos=False, eos=False, allowed_special=set()):
-        ids: list[int] = []
-
-        if bos:
-            ids.append(self.special_tokens["<|begin_of_text|>"])
-
-        # delegate to underlying tiktoken.Encoding.encode
-        ids.extend(
-            self.model.encode(
-                text,
-                allowed_special=allowed_special,
-            )
-        )
+    def encode(self, text, bos=False, eos=False):
+        ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
+            + self.model.encode(text)
         if eos:
-            ids.append(self.special_tokens["<|end_of_text|>"])
-
+            ids.append(self.special["<|end_of_text|>"])
         return ids
 
     def decode(self, ids):
diff --git a/pkg/llms_from_scratch/tests/test_llama3.py b/pkg/llms_from_scratch/tests/test_llama3.py
index 7625d77..c4aa834 100644
--- a/pkg/llms_from_scratch/tests/test_llama3.py
+++ b/pkg/llms_from_scratch/tests/test_llama3.py
@@ -199,7 +199,7 @@ def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
 
     torch.manual_seed(123)
     model = ModelClass(LLAMA32_CONFIG_1B)
-    model.load_state_dict(torch.load(llama3_weights_path))
+    model.load_state_dict(torch.load(llama3_weights_path, weights_only=True))
     model.eval()
 
     start_context = "Llamas eat"