mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2025-11-03 19:30:26 +00:00

added pkg fixes (#676)
Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
parent bb57756444
commit 2a530b49fe
@@ -9,7 +9,7 @@ import psutil
 import urllib
 
 import torch
-import tqdm
+from tqdm import tqdm
 from torch.utils.data import Dataset
 
 
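The first hunk swaps the module-level `import tqdm` for `from tqdm import tqdm`, so the progress-bar wrapper can be called as `tqdm(...)` in the rest of the file instead of `tqdm.tqdm(...)`. A minimal sketch of the difference; the loop is illustrative and not taken from the repository:

import tqdm
for _ in tqdm.tqdm(range(3)):   # with a plain module import, the class must be qualified
    pass

from tqdm import tqdm
for _ in tqdm(range(3)):        # the form the diff switches to: the callable is bound directly
    pass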
@@ -309,22 +309,11 @@ class Llama3Tokenizer:
             special_tokens=self.special,
         )
 
-    def encode(self, text, bos=False, eos=False, allowed_special=set()):
-        ids: list[int] = []
-
-        if bos:
-            ids.append(self.special_tokens["<|begin_of_text|>"])
-
-        # delegate to underlying tiktoken.Encoding.encode
-        ids.extend(
-            self.model.encode(
-                text,
-                allowed_special=allowed_special,
-            )
-        )
+    def encode(self, text, bos=False, eos=False):
+        ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
+              + self.model.encode(text)
         if eos:
-            ids.append(self.special_tokens["<|end_of_text|>"])
-
+            ids.append(self.special["<|end_of_text|>"])
         return ids
 
     def decode(self, ids):
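The second hunk fixes Llama3Tokenizer.encode in the package: the removed body referenced self.special_tokens, but, as the context line special_tokens=self.special shows, the class stores its special-token map under self.special, so the old method would raise an AttributeError; the unused allowed_special parameter is dropped at the same time. A runnable sketch of the new logic with a stand-in encoder (the toy classes and token IDs below are illustrative, not part of the repository):

# Stand-in for tiktoken.Encoding: only the encode() call used above is mimicked.
class _ToyEncoder:
    def encode(self, text):
        return [ord(c) for c in text]

class _ToyTokenizer:
    def __init__(self):
        # special-token map stored under `special`, mirroring the diff
        self.special = {"<|begin_of_text|>": 128000, "<|end_of_text|>": 128001}
        self.model = _ToyEncoder()

    def encode(self, text, bos=False, eos=False):
        # same structure as the added lines in the hunk above
        ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
              + self.model.encode(text)
        if eos:
            ids.append(self.special["<|end_of_text|>"])
        return ids

print(_ToyTokenizer().encode("hi", bos=True, eos=True))
# -> [128000, 104, 105, 128001]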
@@ -199,7 +199,7 @@ def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
 
     torch.manual_seed(123)
     model = ModelClass(LLAMA32_CONFIG_1B)
-    model.load_state_dict(torch.load(llama3_weights_path))
+    model.load_state_dict(torch.load(llama3_weights_path, weights_only=True))
     model.eval()
 
     start_context = "Llamas eat"
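The third hunk passes weights_only=True to torch.load in the test, which restricts unpickling to tensors and other plain data types; this is the safer way to load a bare state_dict and avoids the FutureWarning recent PyTorch releases emit when the flag is omitted. A hedged round-trip sketch with a tiny stand-in module (the Linear layer and file name are illustrative, not from the tests):

import torch

model = torch.nn.Linear(4, 2)
torch.save(model.state_dict(), "weights.pt")          # plain tensor state_dict

restored = torch.nn.Linear(4, 2)
state = torch.load("weights.pt", weights_only=True)   # no arbitrary pickled objects
restored.load_state_dict(state)
restored.eval()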