Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-08-17 21:22:09 +00:00)

Commit ab1e56a323: reorg files and make standalone download file
Parent commit: 3ad442ee90
@@ -34,7 +34,7 @@
"matplotlib version: 3.8.2\n",
"numpy version: 1.26.0\n",
"tiktoken version: 0.5.1\n",
"torch version: 2.2.1\n"
"torch version: 2.2.2\n"
]
}
],
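For context, version strings like the ones above are typically printed with importlib.metadata; a minimal sketch of such a cell (the actual cell contents are not part of this diff):

from importlib.metadata import version

# Print the versions of the key packages used in this chapter
for pkg in ["matplotlib", "numpy", "tiktoken", "torch"]:
    print(f"{pkg} version: {version(pkg)}")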
@@ -1348,9 +1348,8 @@
" plt.savefig(\"loss-plot.pdf\")\n",
" plt.show()\n",
"\n",
"\n",
"epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n",
"plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)\n"
"plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)"
]
},
{
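For reference, plot_losses is defined earlier in the notebook and is not part of this diff. A rough sketch of a compatible helper, with the signature inferred from the call above and the plotting details assumed:

import matplotlib.pyplot as plt

def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    # Training and validation loss against the epoch axis
    fig, ax1 = plt.subplots()
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Second x-axis showing the number of tokens seen (assumed detail)
    ax2 = ax1.twiny()
    ax2.plot(tokens_seen, train_losses, alpha=0)  # invisible plot, only aligns the axis
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()
    plt.savefig("loss-plot.pdf")
    plt.show()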
@@ -1413,7 +1412,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 42,
"id": "2734cee0-f6f9-42d5-b71c-fa7e0ef28b6d",
"metadata": {},
"outputs": [
@@ -1484,7 +1483,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 43,
"id": "01a5ce39-3dc8-4c35-96bc-6410a1e42412",
"metadata": {},
"outputs": [
@@ -1526,7 +1525,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 44,
"id": "6400572f-b3c8-49e2-95bc-433e55c5b3a1",
"metadata": {},
"outputs": [
@@ -1546,7 +1545,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 45,
"id": "b23b863e-252a-403c-b5b1-62bc0a42319f",
"metadata": {},
"outputs": [
@@ -1598,7 +1597,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 46,
"id": "0759e4c8-5362-467c-bec6-b0a19d1ba43d",
"metadata": {},
"outputs": [],
@@ -1616,7 +1615,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 47,
"id": "2e66e613-4aca-4296-a984-ddd0d80c6578",
"metadata": {},
"outputs": [
@@ -1660,7 +1659,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 48,
"id": "e4600713-c51e-4f53-bf58-040a6eb362b8",
"metadata": {},
"outputs": [
@@ -1693,7 +1692,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 49,
"id": "9dfb48f0-bc3f-46a5-9844-33b6c9b0f4df",
"metadata": {},
"outputs": [
@@ -1759,7 +1758,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 50,
"id": "2a7f908a-e9ec-446a-b407-fb6dbf05c806",
"metadata": {},
"outputs": [
@@ -1782,7 +1781,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 51,
"id": "753865ed-79c5-48b1-b9f2-ccb132ff1d2f",
"metadata": {},
"outputs": [
@@ -1790,8 +1789,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([4.5100, -inf, -inf, 6.7500, -inf, -inf, -inf, 6.2800, -inf])\n",
"torch.Size([9])\n"
"tensor([4.5100, -inf, -inf, 6.7500, -inf, -inf, -inf, 6.2800, -inf])\n"
]
}
],
@@ -1802,13 +1800,12 @@
" other=next_token_logits\n",
")\n",
"\n",
"print(new_logits)\n",
"print(new_logits.shape)"
"print(new_logits)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 52,
"id": "4844f000-c329-4e7e-aa89-16a2c4ebee43",
"metadata": {},
"outputs": [
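The -inf entries in the output above come from top-k filtering: every logit below the k-th largest value is masked out so that softmax assigns it zero probability. A minimal sketch of the idea; only the three surviving logits (4.51, 6.75, 6.28) are confirmed by the output shown, the masked example values below are assumed:

import torch

next_token_logits = torch.tensor(
    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]
)

top_k = 3
top_logits, _ = torch.topk(next_token_logits, top_k)

# Replace everything below the smallest of the top-k logits with -inf
new_logits = torch.where(
    condition=next_token_logits < top_logits[-1],
    input=torch.tensor(float("-inf")),
    other=next_token_logits,
)

print(new_logits)
print(torch.softmax(new_logits, dim=0))  # only three non-zero probabilities remain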
@@ -1844,7 +1841,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 68,
"id": "8e318891-bcc0-4d71-b147-33ce55febfa3",
"metadata": {},
"outputs": [],
@@ -1885,18 +1882,10 @@
" return idx"
]
},
{
"cell_type": "markdown",
"id": "8817c673-6d27-417c-b2c1-3cff394a340d",
"metadata": {},
"source": [
"- **Exercise:** What are the settings for `generate` to force deterministic behavior?"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "3adb4df3-1150-44e2-93c7-532d205901f9",
"execution_count": 69,
"id": "aa2a0d7d-0457-42d1-ab9d-bd67683e7ed8",
"metadata": {},
"outputs": [
{
@@ -1904,10 +1893,9 @@
"output_type": "stream",
"text": [
"Output text:\n",
" Every effort moves you?\"\n",
" Every effort moves you know terrace _not brush.\"\n",
"\n",
"\"Yes--quite insensible to me, a single one in the house.\"\n",
"\n"
"\"Never a little wild in a and Mrs. G\n"
]
}
],
@@ -1919,13 +1907,21 @@
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
" max_new_tokens=20,\n",
" context_size=GPT_CONFIG_124M[\"ctx_len\"],\n",
" top_k=2,\n",
" temperature=1.25\n",
" top_k=10,\n",
" temperature=1.5\n",
")\n",
"\n",
"print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
]
},
{
"cell_type": "markdown",
"id": "8817c673-6d27-417c-b2c1-3cff394a340d",
"metadata": {},
"source": [
"- **Exercise:** What are the settings for `generate` to force deterministic behavior?"
]
},
{
"cell_type": "markdown",
"id": "4e2002ca-f4c1-48af-9e0a-88bfc163ba0b",
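On the exercise above: generation becomes deterministic once no randomness is left in the token selection step. A sketch, assuming the chapter's generate function skips top-k filtering when top_k is None and falls back to greedy argmax selection when the temperature is not positive:

token_ids = generate(
    model=model,
    idx=text_to_token_ids("Every effort moves you", tokenizer),
    max_new_tokens=20,
    context_size=GPT_CONFIG_124M["ctx_len"],
    top_k=None,       # no top-k filtering
    temperature=0.0   # no temperature scaling / multinomial sampling, i.e. greedy decoding
)

Setting a manual seed via torch.manual_seed would only make repeated sampling runs reproducible; top_k=None together with temperature=0.0 makes the output independent of the random state altogether.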
@@ -1954,7 +1950,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 70,
"id": "3d67d869-ac04-4382-bcfb-c96d1ca80d47",
"metadata": {},
"outputs": [],
@@ -1972,7 +1968,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 71,
"id": "9d57d914-60a3-47f1-b499-5352f4c457cb",
"metadata": {},
"outputs": [],
@@ -1993,7 +1989,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 72,
"id": "bbd175bb-edf4-450e-a6de-d3e8913c6532",
"metadata": {},
"outputs": [],
@@ -2008,13 +2004,17 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 73,
"id": "8a0c7295-c822-43bf-9286-c45abc542868",
"metadata": {},
"outputs": [],
"source": [
"checkpoint = torch.load(\"model_and_optimizer.pth\")\n",
"\n",
"model = GPTModel(GPT_CONFIG_124M)\n",
"model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)\n",
"optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n",
"model.train();"
]
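The cell above restores both the model and the optimizer state. For reference, the matching save step follows the standard PyTorch pattern of bundling both state dicts into a single checkpoint file; a minimal sketch:

import torch

torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "model_and_optimizer.pth",
)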
@@ -2057,7 +2057,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 74,
"id": "fb9fdf02-972a-444e-bf65-8ffcaaf30ce8",
"metadata": {},
"outputs": [],
@@ -2067,7 +2067,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 75,
"id": "a0747edc-559c-44ef-a93f-079d60227e3f",
"metadata": {},
"outputs": [
@@ -2087,105 +2087,12 @@
},
{
"cell_type": "code",
"execution_count": 48,
"execution_count": 84,
"id": "c5bc89eb-4d39-4287-9b0c-e459ebe7f5ed",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import requests\n",
"import json\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"\n",
"def download_and_load_gpt2(model_size, models_dir):\n",
" # Validate model size\n",
" allowed_sizes = (\"124M\", \"355M\", \"774M\", \"1558M\")\n",
" if model_size not in allowed_sizes:\n",
" raise ValueError(f\"Model size not in {allowed_sizes}\")\n",
"\n",
" # Define paths\n",
" model_dir = os.path.join(models_dir, model_size)\n",
" base_url = \"https://openaipublic.blob.core.windows.net/gpt-2/models\"\n",
" filenames = [\n",
" \"checkpoint\", \"encoder.json\", \"hparams.json\",\n",
" \"model.ckpt.data-00000-of-00001\", \"model.ckpt.index\",\n",
" \"model.ckpt.meta\", \"vocab.bpe\"\n",
" ]\n",
"\n",
" # Download files\n",
" os.makedirs(model_dir, exist_ok=True)\n",
" for filename in filenames:\n",
" file_url = os.path.join(base_url, model_size, filename)\n",
" file_path = os.path.join(model_dir, filename)\n",
" download_file(file_url, file_path)\n",
"\n",
" # Load hparams and params\n",
" tf_ckpt_path = tf.train.latest_checkpoint(model_dir)\n",
" hparams = json.load(open(os.path.join(model_dir, \"hparams.json\")))\n",
" params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)\n",
"\n",
" return hparams, params\n",
"\n",
"\n",
"def download_file(url, destination):\n",
" # Send a GET request to download the file in streaming mode\n",
" response = requests.get(url, stream=True)\n",
"\n",
" # Get the total file size from headers, defaulting to 0 if not present\n",
" file_size = int(response.headers.get(\"content-length\", 0))\n",
"\n",
" # Check if file exists and has the same size\n",
" if os.path.exists(destination):\n",
" file_size_local = os.path.getsize(destination)\n",
" if file_size == file_size_local:\n",
" print(f\"File already exists and is up-to-date: {destination}\")\n",
" return\n",
"\n",
" # Define the block size for reading the file\n",
" block_size = 1024 # 1 Kilobyte\n",
"\n",
" # Initialize the progress bar with total file size\n",
" progress_bar_description = url.split(\"/\")[-1] # Extract filename from URL\n",
" with tqdm(total=file_size, unit=\"iB\", unit_scale=True, desc=progress_bar_description) as progress_bar:\n",
" # Open the destination file in binary write mode\n",
" with open(destination, \"wb\") as file:\n",
" # Iterate over the file data in chunks\n",
" for chunk in response.iter_content(block_size):\n",
" progress_bar.update(len(chunk)) # Update progress bar\n",
" file.write(chunk) # Write the chunk to the file\n",
"\n",
"\n",
"def load_gpt2_params_from_tf_ckpt(ckpt_path, hparams):\n",
" # Initialize parameters dictionary with empty blocks for each layer\n",
" params = {\"blocks\": [{} for _ in range(hparams[\"n_layer\"])]}\n",
"\n",
" # Iterate over each variable in the checkpoint\n",
" for name, _ in tf.train.list_variables(ckpt_path):\n",
" # Load the variable and remove singleton dimensions\n",
" variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))\n",
"\n",
" # Process the variable name to extract relevant parts\n",
" variable_name_parts = name.split(\"/\")[1:] # Skip the 'model/' prefix\n",
"\n",
" # Identify the target dictionary for the variable\n",
" target_dict = params\n",
" if variable_name_parts[0].startswith(\"h\"):\n",
" layer_number = int(variable_name_parts[0][1:])\n",
" target_dict = params[\"blocks\"][layer_number]\n",
"\n",
" # Recursively access or create nested dictionaries\n",
" for key in variable_name_parts[1:-1]:\n",
" target_dict = target_dict.setdefault(key, {})\n",
"\n",
" # Assign the variable array to the last key\n",
" last_key = variable_name_parts[-1]\n",
" target_dict[last_key] = variable_array\n",
"\n",
" return params"
"from gpt_download import download_and_load_gpt2"
]
},
{
@@ -2198,21 +2105,21 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 85,
"id": "76271dd7-108d-4f5b-9c01-6ae0aac4b395",
"metadata": {},
"outputs": [
{
"name": "stdout",
"name": "stderr",
"output_type": "stream",
"text": [
"File already exists and is up-to-date: gpt2/124M/checkpoint\n",
"File already exists and is up-to-date: gpt2/124M/encoder.json\n",
"File already exists and is up-to-date: gpt2/124M/hparams.json\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.index\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.meta\n",
"File already exists and is up-to-date: gpt2/124M/vocab.bpe\n"
"checkpoint: 100%|████████████████████████████| 77.0/77.0 [00:00<00:00, 132kiB/s]\n",
"encoder.json: 100%|███████████████████████| 1.04M/1.04M [00:00<00:00, 3.54MiB/s]\n",
"hparams.json: 100%|█████████████████████████| 90.0/90.0 [00:00<00:00, 52.9kiB/s]\n",
"model.ckpt.data-00000-of-00001: 100%|███████| 498M/498M [01:02<00:00, 7.93MiB/s]\n",
"model.ckpt.index: 100%|███████████████████| 5.21k/5.21k [00:00<00:00, 1.48MiB/s]\n",
"model.ckpt.meta: 100%|██████████████████████| 471k/471k [00:00<00:00, 1.93MiB/s]\n",
"vocab.bpe: 100%|████████████████████████████| 456k/456k [00:00<00:00, 2.30MiB/s]\n"
]
}
],
@@ -2222,7 +2129,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 86,
"id": "b1a31951-d971-4a6e-9c43-11ee1168ec6a",
"metadata": {},
"outputs": [
@@ -2240,7 +2147,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": 87,
"id": "857c8331-130e-46ba-921d-fa35d7a73cfe",
"metadata": {},
"outputs": [
@@ -2286,7 +2193,7 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": 88,
"id": "9fef90dd-0654-4667-844f-08e28339ef7d",
"metadata": {},
"outputs": [],
@@ -2319,7 +2226,7 @@
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 89,
"id": "f9a92229-c002-49a6-8cfb-248297ad8296",
"metadata": {},
"outputs": [],
@@ -2332,7 +2239,7 @@
},
{
"cell_type": "code",
"execution_count": 54,
"execution_count": 90,
"id": "f22d5d95-ca5a-425c-a9ec-fc432a12d4e9",
"metadata": {},
"outputs": [],
@@ -2385,7 +2292,7 @@
},
{
"cell_type": "code",
"execution_count": 55,
"execution_count": 91,
"id": "1f690253-f845-4347-b7b6-43fabbd2affa",
"metadata": {},
"outputs": [
@@ -2446,10 +2353,18 @@
"id": "fc7ed189-a633-458c-bf12-4f70b42684b8",
"metadata": {},
"source": [
"- See the [gpt_train.py](train_gpt.py) script containing a self-contained training script\n",
"- The [gpt_generate.py](generate_gpt.py) script loads pretrained weights from OpenAI and generates text based on a prompt\n",
"- See the [gpt_train.py](gpt_train.py) script containing a self-contained training script\n",
"- The [gpt_generate.py](gpt_generate.py) script loads pretrained weights from OpenAI and generates text based on a prompt\n",
"- You can find the exercise solutions in [exercise-solutions.ipynb](exercise-solutions.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ef9585d-ea2e-4c04-9dd5-71d003a9dd07",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
ch05/01_main-chapter-code/gpt_download.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import os
import requests
import json
import numpy as np
import tensorflow as tf
from tqdm import tqdm


def download_and_load_gpt2(model_size, models_dir):
    # Validate model size
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")

    # Define paths
    model_dir = os.path.join(models_dir, model_size)
    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    filenames = [
        "checkpoint", "encoder.json", "hparams.json",
        "model.ckpt.data-00000-of-00001", "model.ckpt.index",
        "model.ckpt.meta", "vocab.bpe"
    ]

    # Download files
    os.makedirs(model_dir, exist_ok=True)
    for filename in filenames:
        file_url = os.path.join(base_url, model_size, filename)
        file_path = os.path.join(model_dir, filename)
        download_file(file_url, file_path)

    # Load hparams and params
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    hparams = json.load(open(os.path.join(model_dir, "hparams.json")))
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)

    return hparams, params


def download_file(url, destination):
    # Send a GET request to download the file in streaming mode
    response = requests.get(url, stream=True)

    # Get the total file size from headers, defaulting to 0 if not present
    file_size = int(response.headers.get("content-length", 0))

    # Check if file exists and has the same size
    if os.path.exists(destination):
        file_size_local = os.path.getsize(destination)
        if file_size == file_size_local:
            print(f"File already exists and is up-to-date: {destination}")
            return

    # Define the block size for reading the file
    block_size = 1024  # 1 Kilobyte

    # Initialize the progress bar with total file size
    progress_bar_description = url.split("/")[-1]  # Extract filename from URL
    with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
        # Open the destination file in binary write mode
        with open(destination, "wb") as file:
            # Iterate over the file data in chunks
            for chunk in response.iter_content(block_size):
                progress_bar.update(len(chunk))  # Update progress bar
                file.write(chunk)  # Write the chunk to the file


def load_gpt2_params_from_tf_ckpt(ckpt_path, hparams):
    # Initialize parameters dictionary with empty blocks for each layer
    params = {"blocks": [{} for _ in range(hparams["n_layer"])]}

    # Iterate over each variable in the checkpoint
    for name, _ in tf.train.list_variables(ckpt_path):
        # Load the variable and remove singleton dimensions
        variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))

        # Process the variable name to extract relevant parts
        variable_name_parts = name.split("/")[1:]  # Skip the 'model/' prefix

        # Identify the target dictionary for the variable
        target_dict = params
        if variable_name_parts[0].startswith("h"):
            layer_number = int(variable_name_parts[0][1:])
            target_dict = params["blocks"][layer_number]

        # Recursively access or create nested dictionaries
        for key in variable_name_parts[1:-1]:
            target_dict = target_dict.setdefault(key, {})

        # Assign the variable array to the last key
        last_key = variable_name_parts[-1]
        target_dict[last_key] = variable_array

    return params
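For reference, a minimal usage sketch of the new standalone module, mirroring the import added to the notebook above (the printed keys are typical for the GPT-2 checkpoints but are an assumption here):

from gpt_download import download_and_load_gpt2

# Downloads the seven checkpoint files into gpt2/124M/ (or reuses them if present)
# and returns the hyperparameter dict plus a nested dict of NumPy weight arrays
hparams, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")

print("hparams:", hparams)            # e.g. n_vocab, n_ctx, n_embd, n_head, n_layer
print("params keys:", params.keys())  # e.g. blocks, b, g, wpe, wte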