mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-02 02:41:00 +00:00
added pkg fixes (#676)
Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
This commit is contained in:
parent
bb57756444
commit
2a530b49fe
@ -9,7 +9,7 @@ import psutil
|
||||
import urllib
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
|
||||
@ -309,22 +309,11 @@ class Llama3Tokenizer:
|
||||
special_tokens=self.special,
|
||||
)
|
||||
|
||||
def encode(self, text, bos=False, eos=False, allowed_special=set()):
|
||||
ids: list[int] = []
|
||||
|
||||
if bos:
|
||||
ids.append(self.special_tokens["<|begin_of_text|>"])
|
||||
|
||||
# delegate to underlying tiktoken.Encoding.encode
|
||||
ids.extend(
|
||||
self.model.encode(
|
||||
text,
|
||||
allowed_special=allowed_special,
|
||||
)
|
||||
)
|
||||
def encode(self, text, bos=False, eos=False):
|
||||
ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
|
||||
+ self.model.encode(text)
|
||||
if eos:
|
||||
ids.append(self.special_tokens["<|end_of_text|>"])
|
||||
|
||||
ids.append(self.special["<|end_of_text|>"])
|
||||
return ids
|
||||
|
||||
def decode(self, ids):
|
||||
|
||||
@ -199,7 +199,7 @@ def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
|
||||
|
||||
torch.manual_seed(123)
|
||||
model = ModelClass(LLAMA32_CONFIG_1B)
|
||||
model.load_state_dict(torch.load(llama3_weights_path))
|
||||
model.load_state_dict(torch.load(llama3_weights_path, weights_only=True))
|
||||
model.eval()
|
||||
|
||||
start_context = "Llamas eat"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user