mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-31 01:41:26 +00:00 
			
		
		
		
	
		
			
	
	
		
			55 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			55 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). | ||
|  | # Source for "Build a Large Language Model From Scratch" | ||
|  | #   - https://www.manning.com/books/build-a-large-language-model-from-scratch | ||
|  | # Code: https://github.com/rasbt/LLMs-from-scratch | ||
|  | 
 | ||
|  | from llms_from_scratch.ch02 import create_dataloader_v1 | ||
|  | 
 | ||
|  | import os | ||
|  | import urllib.request | ||
|  | 
 | ||
|  | import pytest | ||
|  | import torch | ||
|  | 
 | ||
|  | 
 | ||
@pytest.mark.parametrize("file_name", ["the-verdict.txt"])
def test_dataloader(tmp_path, file_name):
    """Smoke-test ``create_dataloader_v1`` end to end.

    Downloads the sample text (into pytest's ``tmp_path`` so the working
    directory stays clean), builds token + positional embedding layers, pulls
    one batch from the dataloader, and asserts the combined input embeddings
    have the expected (batch_size, max_length, output_dim) shape.

    Parameters
    ----------
    tmp_path : pathlib.Path
        Per-test temporary directory provided by pytest.
    file_name : str
        Name of the text file to download (parametrized).
    """
    # Original bug: the download target was a hard-coded path in the CWD and
    # the tmp_path / file_name fixtures were ignored. Use them instead.
    file_path = tmp_path / file_name
    if not file_path.exists():
        url = ("https://raw.githubusercontent.com/rasbt/"
               "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
               "the-verdict.txt")
        urllib.request.urlretrieve(url, file_path)

    with open(file_path, "r", encoding="utf-8") as f:
        raw_text = f.read()

    vocab_size = 50257       # GPT-2 BPE vocabulary size
    output_dim = 256         # embedding dimension
    context_length = 1024    # max positions for the positional embedding

    token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
    pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

    batch_size = 8
    max_length = 4
    dataloader = create_dataloader_v1(
        raw_text,
        batch_size=batch_size,
        max_length=max_length,
        stride=max_length
    )

    input_embeddings = None
    for x, y in dataloader:
        token_embeddings = token_embedding_layer(x)
        pos_embeddings = pos_embedding_layer(torch.arange(max_length))
        # Positional embeddings broadcast over the batch dimension.
        input_embeddings = token_embeddings + pos_embeddings
        break  # one batch is enough for a shape check

    # Fail clearly (not with a NameError) if the dataloader yielded nothing.
    assert input_embeddings is not None, "dataloader produced no batches"
    # Original bug: this comparison's result was discarded, so the test could
    # never fail. It must be an assert, tied to the configured sizes.
    assert input_embeddings.shape == torch.Size(
        [batch_size, max_length, output_dim]
    )