added pkg fixes (#676)

Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
This commit is contained in:
Daniel Kleine 2025-06-21 23:07:50 +02:00 committed by GitHub
parent bb57756444
commit 2a530b49fe
3 changed files with 6 additions and 17 deletions

View File

@@ -9,7 +9,7 @@ import psutil
import urllib
import torch
import tqdm
from tqdm import tqdm
from torch.utils.data import Dataset

View File

@@ -309,22 +309,11 @@ class Llama3Tokenizer:
special_tokens=self.special,
)
def encode(self, text, bos=False, eos=False, allowed_special=set()):
ids: list[int] = []
if bos:
ids.append(self.special_tokens["<|begin_of_text|>"])
# delegate to underlying tiktoken.Encoding.encode
ids.extend(
self.model.encode(
text,
allowed_special=allowed_special,
)
)
def encode(self, text, bos=False, eos=False):
ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
+ self.model.encode(text)
if eos:
ids.append(self.special_tokens["<|end_of_text|>"])
ids.append(self.special["<|end_of_text|>"])
return ids
def decode(self, ids):

View File

@@ -199,7 +199,7 @@ def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
torch.manual_seed(123)
model = ModelClass(LLAMA32_CONFIG_1B)
model.load_state_dict(torch.load(llama3_weights_path))
model.load_state_dict(torch.load(llama3_weights_path, weights_only=True))
model.eval()
start_context = "Llamas eat"