Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2025-08-28 02:20:13 +00:00
fixed num_workers (#229)
* fixed num_workers
* ch06 & ch07: added num_workers to create_dataloader_v1
This commit is contained in:
parent c935725a26
commit 73be1c592f
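For context, here is a minimal sketch of what the data-loading code looks like after this change. It is reconstructed from the hunks below plus the book's standard GPTDatasetV1; the dataset body is not part of this diff, so details such as the allowed_special argument are assumptions. The point of the fix is that num_workers is a parameter of create_dataloader_v1 forwarded to the DataLoader, not an attribute of the Dataset:

import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    # num_workers no longer appears here; it is a DataLoader setting, not a Dataset one
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize the full text, then slide a window of max_length tokens over it
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1:i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True, num_workers=0):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader; num_workers is forwarded instead of being hard-coded to 0
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

    return dataloader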
@@ -50,7 +50,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,

     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
@@ -1346,7 +1346,7 @@
 "        batch_size=batch_size,\n",
 "        shuffle=shuffle,\n",
 "        drop_last=drop_last,\n",
-"        num_workers=0\n",
+"        num_workers=num_workers\n",
 "    )\n",
 "\n",
 "    return dataloader"
@@ -82,7 +82,7 @@
 "\n",
 "    # Create dataloader\n",
 "    dataloader = DataLoader(\n",
-"        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)\n",
+"        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)\n",
 "\n",
 "    return dataloader\n",
 "\n",
@@ -128,7 +128,7 @@
 "        batch_size=batch_size,\n",
 "        shuffle=shuffle,\n",
 "        drop_last=drop_last,\n",
-"        num_workers=0\n",
+"        num_workers=num_workers\n",
 "    )\n",
 "\n",
 "    return dataloader"
@@ -13,7 +13,7 @@ from torch.utils.data import Dataset, DataLoader


 class GPTDatasetV1(Dataset):
-    def __init__(self, txt, tokenizer, max_length, stride, num_workers=0):
+    def __init__(self, txt, tokenizer, max_length, stride):
         self.input_ids = []
         self.target_ids = []
@@ -44,7 +44,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,

     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
@@ -41,7 +41,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,

     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
@@ -49,7 +49,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,

     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
@@ -49,7 +49,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,

     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
@@ -49,7 +49,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,

     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
@@ -44,7 +44,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
     tokenizer = tiktoken.get_encoding("gpt2")
     dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
@@ -49,7 +49,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,

     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
@@ -41,7 +41,7 @@ class GPTDatasetV1(Dataset):


 def create_dataloader_v1(txt, batch_size=4, max_length=256,
-                         stride=128, shuffle=True, drop_last=True):
+                         stride=128, shuffle=True, drop_last=True, num_workers=0):
     # Initialize the tokenizer
     tokenizer = tiktoken.get_encoding("gpt2")
@@ -50,7 +50,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,

     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
@@ -41,7 +41,7 @@ class GPTDatasetV1(Dataset):


 def create_dataloader_v1(txt, batch_size=4, max_length=256,
-                         stride=128, shuffle=True, drop_last=True):
+                         stride=128, shuffle=True, drop_last=True, num_workers=0):
     # Initialize the tokenizer
     tokenizer = tiktoken.get_encoding("gpt2")
@@ -50,7 +50,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,

     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
@@ -42,7 +42,7 @@ class GPTDatasetV1(Dataset):


 def create_dataloader_v1(txt, batch_size=4, max_length=256,
-                         stride=128, shuffle=True, drop_last=True):
+                         stride=128, shuffle=True, drop_last=True, num_workers=0):
     # Initialize the tokenizer
     tokenizer = tiktoken.get_encoding("gpt2")
@@ -51,7 +51,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,

     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
@@ -45,7 +45,7 @@ class GPTDatasetV1(Dataset):


 def create_dataloader_v1(txt, batch_size=4, max_length=256,
-                         stride=128, shuffle=True, drop_last=True):
+                         stride=128, shuffle=True, drop_last=True, num_workers=0):
     # Initialize the tokenizer
     tokenizer = tiktoken.get_encoding("gpt2")
@@ -54,7 +54,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,

     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

     return dataloader
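A brief usage sketch of the updated signature (the file name and hyperparameters here are placeholders, not taken from this commit): passing num_workers > 0 opts into multi-process batch loading, while the default of 0 keeps the previous single-process behaviour.

with open("the-verdict.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

# num_workers=2 spawns two worker processes for batch preparation;
# num_workers=0 (the default) prepares batches in the main process
dataloader = create_dataloader_v1(
    raw_text, batch_size=8, max_length=128, stride=128, shuffle=True, num_workers=2)

for input_batch, target_batch in dataloader:
    print(input_batch.shape, target_batch.shape)
    break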