added pkg fixes (#676)

Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
This commit is contained in:
Daniel Kleine 2025-06-21 23:07:50 +02:00 committed by GitHub
parent bb57756444
commit 2a530b49fe
3 changed files with 6 additions and 17 deletions

View File

@@ -9,7 +9,7 @@ import psutil
 import urllib
 import torch
-import tqdm
+from tqdm import tqdm
 from torch.utils.data import Dataset

View File

@@ -309,22 +309,11 @@ class Llama3Tokenizer:
             special_tokens=self.special,
         )

-    def encode(self, text, bos=False, eos=False, allowed_special=set()):
-        ids: list[int] = []
-        if bos:
-            ids.append(self.special_tokens["<|begin_of_text|>"])
-        # delegate to underlying tiktoken.Encoding.encode
-        ids.extend(
-            self.model.encode(
-                text,
-                allowed_special=allowed_special,
-            )
-        )
+    def encode(self, text, bos=False, eos=False):
+        ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
+            + self.model.encode(text)
         if eos:
-            ids.append(self.special_tokens["<|end_of_text|>"])
+            ids.append(self.special["<|end_of_text|>"])
         return ids

     def decode(self, ids):

View File

@@ -199,7 +199,7 @@ def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
     torch.manual_seed(123)
     model = ModelClass(LLAMA32_CONFIG_1B)
-    model.load_state_dict(torch.load(llama3_weights_path))
+    model.load_state_dict(torch.load(llama3_weights_path, weights_only=True))
     model.eval()
     start_context = "Llamas eat"