Sebastian Raschka 7bd263144e
Switch from urllib to requests to improve reliability (#867)
* Switch from urllib to requests to improve reliability

* Keep ruff linter-specific

* update

* update

* update
2025-10-07 15:22:59 -05:00

61 lines
1.7 KiB
Python

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from llms_from_scratch.ch02 import create_dataloader_v1
import os
import requests
import pytest
import torch
@pytest.mark.parametrize("file_name", ["the-verdict.txt"])
def test_dataloader(tmp_path, file_name):
    """Check that embedding one dataloader batch yields the expected shape.

    The text file is read from the current working directory; when it is
    not there yet it is downloaded from the book's GitHub repository and
    stored under ``file_name`` so later runs can reuse it.

    Args:
        tmp_path: pytest temporary-directory fixture (currently unused;
            kept so the test signature stays unchanged).
        file_name: Name of the text file used to build the dataloader.

    Raises:
        requests.HTTPError: If the download returns an error status
            (via ``response.raise_for_status()``).
        AssertionError: If the combined embeddings do not have shape
            ``(batch_size, max_length, output_dim)``.
    """
    # Download only when no local copy exists in the working directory.
    if not os.path.exists(file_name):
        url = (
            "https://raw.githubusercontent.com/rasbt/"
            "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
            + file_name
        )
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        with open(file_name, "wb") as f:
            f.write(response.content)

    with open(file_name, "r", encoding="utf-8") as f:
        raw_text = f.read()

    vocab_size = 50257
    output_dim = 256
    context_length = 1024

    token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
    pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

    batch_size = 8
    max_length = 4
    dataloader = create_dataloader_v1(
        raw_text,
        batch_size=batch_size,
        max_length=max_length,
        stride=max_length,
    )

    # Only the first batch is needed; each batch unpacks into an (x, y) pair.
    x, _ = next(iter(dataloader))

    token_embeddings = token_embedding_layer(x)
    pos_embeddings = pos_embedding_layer(torch.arange(max_length))
    # The (max_length, output_dim) positional embeddings are added to the
    # token embeddings via broadcasting.
    input_embeddings = token_embeddings + pos_embeddings

    # The original ended with a bare comparison whose result was discarded,
    # so the test could never fail; it must be asserted.
    expected_shape = torch.Size([batch_size, max_length, output_dim])
    assert input_embeddings.shape == expected_shape