add download utilities for vocab and encoder files

This commit is contained in:
rasbt 2024-01-15 17:07:55 -06:00
parent dfe2c3b46f
commit 0074c98968
2 changed files with 63 additions and 26 deletions

View File

@ -3,6 +3,9 @@ Byte pair encoding utilities
Code from https://github.com/openai/gpt-2/blob/master/src/encoder.py Code from https://github.com/openai/gpt-2/blob/master/src/encoder.py
And modified code (download_vocab) from
https://github.com/openai/gpt-2/blob/master/download_model.py
Modified MIT License Modified MIT License
Software Copyright (c) 2019 OpenAI Software Copyright (c) 2019 OpenAI
@ -34,6 +37,8 @@ OR OTHER DEALINGS IN THE SOFTWARE.
import os import os
import json import json
import regex as re import regex as re
import requests
from tqdm import tqdm
from functools import lru_cache from functools import lru_cache
@lru_cache() @lru_cache()
@ -146,3 +151,24 @@ def get_encoder(model_name, models_dir):
encoder=encoder, encoder=encoder,
bpe_merges=bpe_merges, bpe_merges=bpe_merges,
) )
def download_vocab():
    """Download the GPT-2 ``encoder.json`` and ``vocab.bpe`` files.

    Both files are fetched from the public OpenAI blob storage (117M model
    directory) and written into a local ``gpt2_model`` directory, which is
    created if it does not exist yet. A tqdm progress bar is shown per file.

    Raises:
        requests.HTTPError: If the server responds with an error status, so
            that an error page is never saved as a vocabulary file.
        requests.RequestException: On connection problems or timeouts.
    """
    # Modified code from
    # https://github.com/openai/gpt-2/blob/master/download_model.py
    subdir = 'gpt2_model'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(subdir, exist_ok=True)

    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models/117M"
    # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
    chunk_size = 1000

    for filename in ['encoder.json', 'vocab.bpe']:
        # The context manager releases the streamed connection even on errors;
        # the timeout keeps a stalled server from blocking forever.
        with requests.get(base_url + "/" + filename, stream=True, timeout=30) as r:
            # Fail before opening the target file so no bogus file is created
            r.raise_for_status()

            # The header may be absent (e.g. chunked transfer encoding);
            # total=None makes tqdm show a plain counter instead of crashing.
            content_length = r.headers.get("content-length")
            file_size = int(content_length) if content_length is not None else None

            with open(os.path.join(subdir, filename), 'wb') as f, \
                    tqdm(ncols=100, desc="Fetching " + filename, total=file_size,
                         unit="B", unit_scale=True) as pbar:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    # The last chunk is usually shorter than chunk_size, so
                    # count the bytes actually received to not overshoot total
                    pbar.update(len(chunk))

View File

@ -125,22 +125,41 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from bpe_openai_gpt2 import get_encoder" "from bpe_openai_gpt2 import get_encoder, download_vocab"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 8,
"id": "1888a7a9-9c40-4fe0-99b4-ebd20aa1ffd0", "id": "35dd8d7c-8c12-4b68-941a-0fd05882dd45",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fetching encoder.json: 1.04Mit [00:00, 3.03Mit/s] \n",
"Fetching vocab.bpe: 457kit [00:00, 2.36Mit/s] \n"
]
}
],
"source": [ "source": [
"orig_tokenizer = get_encoder(model_name=\"gpt2\", models_dir=\".\")" "download_vocab()"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 9,
"id": "1888a7a9-9c40-4fe0-99b4-ebd20aa1ffd0",
"metadata": {},
"outputs": [],
"source": [
"orig_tokenizer = get_encoder(model_name=\"gpt2_model\", models_dir=\".\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2740510c-a78a-4fba-ae18-2b156ba2dfef", "id": "2740510c-a78a-4fba-ae18-2b156ba2dfef",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -160,7 +179,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 11,
"id": "434d115e-990d-42ad-88dd-31323a96e10f", "id": "434d115e-990d-42ad-88dd-31323a96e10f",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -188,7 +207,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 12,
"id": "5bfff386-f725-4137-9c50-e5da0c38bea0", "id": "5bfff386-f725-4137-9c50-e5da0c38bea0",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -198,7 +217,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 13,
"id": "e9077bf4-f91f-42ad-ab76-f3d89128510e", "id": "e9077bf4-f91f-42ad-ab76-f3d89128510e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -216,7 +235,7 @@
"'4.33.2'" "'4.33.2'"
] ]
}, },
"execution_count": 12, "execution_count": 13,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -270,7 +289,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 16,
"id": "a61bb445-b151-4a2f-8180-d4004c503754", "id": "a61bb445-b151-4a2f-8180-d4004c503754",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -281,7 +300,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 17,
"id": "57f7c0a3-c1fd-4313-af34-68e78eb33653", "id": "57f7c0a3-c1fd-4313-af34-68e78eb33653",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -289,7 +308,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"4.17 ms ± 18.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" "4.12 ms ± 41.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
] ]
} }
], ],
@ -299,7 +318,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 18,
"id": "036dd628-3591-46c9-a5ce-b20b105a8062", "id": "036dd628-3591-46c9-a5ce-b20b105a8062",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -307,7 +326,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"1.68 ms ± 9.31 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" "1.75 ms ± 8.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
] ]
} }
], ],
@ -317,7 +336,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 19,
"id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90", "id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -332,7 +351,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"8.81 ms ± 51.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" "9.12 ms ± 856 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
] ]
} }
], ],
@ -342,7 +361,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 26, "execution_count": 20,
"id": "7117107f-22a6-46b4-a442-712d50b3ac7a", "id": "7117107f-22a6-46b4-a442-712d50b3ac7a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -350,21 +369,13 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"8.8 ms ± 74 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" "8.63 ms ± 247 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
] ]
} }
], ],
"source": [ "source": [
"%timeit hf_tokenizer(raw_text, max_length=5145, truncation=True)[\"input_ids\"]" "%timeit hf_tokenizer(raw_text, max_length=5145, truncation=True)[\"input_ids\"]"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0bcbacd5-b64f-4186-ab12-949b9483f556",
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {