mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-09-26 08:34:22 +00:00
add download utilities for vocab and encoder files
This commit is contained in:
parent
dfe2c3b46f
commit
0074c98968
@ -3,6 +3,9 @@ Byte pair encoding utilities
|
|||||||
|
|
||||||
Code from https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
Code from https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||||
|
|
||||||
|
And modified code (download_vocab) from
|
||||||
|
https://github.com/openai/gpt-2/blob/master/download_model.py
|
||||||
|
|
||||||
Modified MIT License
|
Modified MIT License
|
||||||
|
|
||||||
Software Copyright (c) 2019 OpenAI
|
Software Copyright (c) 2019 OpenAI
|
||||||
@ -34,6 +37,8 @@ OR OTHER DEALINGS IN THE SOFTWARE.
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import regex as re
|
import regex as re
|
||||||
|
import requests
|
||||||
|
from tqdm import tqdm
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
@ -146,3 +151,24 @@ def get_encoder(model_name, models_dir):
|
|||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
bpe_merges=bpe_merges,
|
bpe_merges=bpe_merges,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def download_vocab():
    """Download the GPT-2 ``encoder.json`` and ``vocab.bpe`` files.

    Both files are fetched from the public 117M model location and written
    into a local ``gpt2_model`` directory (created if it does not exist),
    with a tqdm progress bar shown per file.

    Raises:
        requests.HTTPError: if the server answers with an error status, so
            that an error page is never saved in place of a vocabulary file.
    """
    # Modified code from
    # https://github.com/openai/gpt-2/blob/master/download_model.py
    subdir = 'gpt2_model'
    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models/117M"

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(subdir, exist_ok=True)

    # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
    chunk_size = 1000

    for filename in ['encoder.json', 'vocab.bpe']:
        # The context manager releases the streamed connection when done
        with requests.get(base_url + "/" + filename, stream=True) as r:
            # Validate the response BEFORE opening (and truncating) the file
            r.raise_for_status()

            # The header may be absent; tqdm accepts total=None in that case
            content_length = r.headers.get("content-length")
            file_size = int(content_length) if content_length is not None else None

            with open(os.path.join(subdir, filename), 'wb') as f:
                with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                        # Advance by the bytes actually received; the last
                        # chunk is usually shorter than chunk_size
                        pbar.update(len(chunk))
|
@ -125,22 +125,41 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from bpe_openai_gpt2 import get_encoder"
|
"from bpe_openai_gpt2 import get_encoder, download_vocab"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 8,
|
||||||
"id": "1888a7a9-9c40-4fe0-99b4-ebd20aa1ffd0",
|
"id": "35dd8d7c-8c12-4b68-941a-0fd05882dd45",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Fetching encoder.json: 1.04Mit [00:00, 3.03Mit/s] \n",
|
||||||
|
"Fetching vocab.bpe: 457kit [00:00, 2.36Mit/s] \n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"orig_tokenizer = get_encoder(model_name=\"gpt2\", models_dir=\".\")"
|
"download_vocab()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 9,
|
||||||
|
"id": "1888a7a9-9c40-4fe0-99b4-ebd20aa1ffd0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"orig_tokenizer = get_encoder(model_name=\"gpt2_model\", models_dir=\".\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
"id": "2740510c-a78a-4fba-ae18-2b156ba2dfef",
|
"id": "2740510c-a78a-4fba-ae18-2b156ba2dfef",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -160,7 +179,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 11,
|
||||||
"id": "434d115e-990d-42ad-88dd-31323a96e10f",
|
"id": "434d115e-990d-42ad-88dd-31323a96e10f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -188,7 +207,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 12,
|
||||||
"id": "5bfff386-f725-4137-9c50-e5da0c38bea0",
|
"id": "5bfff386-f725-4137-9c50-e5da0c38bea0",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -198,7 +217,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": 13,
|
||||||
"id": "e9077bf4-f91f-42ad-ab76-f3d89128510e",
|
"id": "e9077bf4-f91f-42ad-ab76-f3d89128510e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -216,7 +235,7 @@
|
|||||||
"'4.33.2'"
|
"'4.33.2'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 12,
|
"execution_count": 13,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -270,7 +289,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 16,
|
||||||
"id": "a61bb445-b151-4a2f-8180-d4004c503754",
|
"id": "a61bb445-b151-4a2f-8180-d4004c503754",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -281,7 +300,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 17,
|
||||||
"id": "57f7c0a3-c1fd-4313-af34-68e78eb33653",
|
"id": "57f7c0a3-c1fd-4313-af34-68e78eb33653",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -289,7 +308,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"4.17 ms ± 18.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
"4.12 ms ± 41.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -299,7 +318,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 18,
|
||||||
"id": "036dd628-3591-46c9-a5ce-b20b105a8062",
|
"id": "036dd628-3591-46c9-a5ce-b20b105a8062",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -307,7 +326,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"1.68 ms ± 9.31 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
|
"1.75 ms ± 8.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -317,7 +336,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 22,
|
"execution_count": 19,
|
||||||
"id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90",
|
"id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -332,7 +351,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"8.81 ms ± 51.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
"9.12 ms ± 856 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -342,7 +361,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 26,
|
"execution_count": 20,
|
||||||
"id": "7117107f-22a6-46b4-a442-712d50b3ac7a",
|
"id": "7117107f-22a6-46b4-a442-712d50b3ac7a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -350,21 +369,13 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"8.8 ms ± 74 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
"8.63 ms ± 247 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"%timeit hf_tokenizer(raw_text, max_length=5145, truncation=True)[\"input_ids\"]"
|
"%timeit hf_tokenizer(raw_text, max_length=5145, truncation=True)[\"input_ids\"]"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "0bcbacd5-b64f-4186-ab12-949b9483f556",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user