# Mirror of https://github.com/rasbt/LLMs-from-scratch.git
# (synced 2025-08-10 09:43:05 +00:00)
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import os
import urllib.request

import pytest
import torch

from llms_from_scratch.ch02 import create_dataloader_v1
@pytest.mark.parametrize("file_name", ["the-verdict.txt"])
def test_dataloader(tmp_path, file_name):
    """End-to-end smoke test for ``create_dataloader_v1``.

    Downloads the sample text (into pytest's per-test tmp dir), builds the
    dataloader, and checks that one batch of token + positional embeddings
    combines into the expected ``(batch, seq, dim)`` shape.
    """
    # Use the tmp_path fixture and the parametrized file name instead of a
    # hard-coded path in the CWD, so the test never pollutes the repo tree.
    file_path = tmp_path / file_name
    if not file_path.exists():
        url = ("https://raw.githubusercontent.com/rasbt/"
               "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
               + file_name)
        urllib.request.urlretrieve(url, file_path)

    with open(file_path, "r", encoding="utf-8") as f:
        raw_text = f.read()

    vocab_size = 50257      # GPT-2 BPE vocabulary size
    output_dim = 256
    context_length = 1024

    token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
    pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

    batch_size = 8
    max_length = 4
    dataloader = create_dataloader_v1(
        raw_text,
        batch_size=batch_size,
        max_length=max_length,
        stride=max_length
    )

    # One batch suffices for a shape check; next() also raises (failing the
    # test) if the loader is unexpectedly empty, instead of leaving the
    # embeddings undefined as a bare for/break loop would.
    x, y = next(iter(dataloader))

    token_embeddings = token_embedding_layer(x)
    pos_embeddings = pos_embedding_layer(torch.arange(max_length))
    input_embeddings = token_embeddings + pos_embeddings

    # BUG FIX: the original ended with a bare comparison
    # (`input_embeddings.shape == torch.Size(...)`) whose result was
    # discarded, so the shape was never actually checked. Assert it, and
    # derive the expected shape from the config values above rather than
    # repeating the magic numbers 8/4/256.
    assert input_embeddings.shape == torch.Size(
        [batch_size, max_length, output_dim]
    )