From 0074c989688cd61247c711aab1fc952130856671 Mon Sep 17 00:00:00 2001
From: rasbt
Date: Mon, 15 Jan 2024 17:07:55 -0600
Subject: [PATCH] add download utilities for vocab and encoder files

---
 .../bpe_openai_gpt2.py                        | 42 ++++++++++++-
 .../compare-bpe-tiktoken.ipynb                | 61 +++++++++++--------
 2 files changed, 77 insertions(+), 26 deletions(-)

diff --git a/ch02/02_bonus_bytepair-encoder/bpe_openai_gpt2.py b/ch02/02_bonus_bytepair-encoder/bpe_openai_gpt2.py
index 3b98cdd..0d85e95 100644
--- a/ch02/02_bonus_bytepair-encoder/bpe_openai_gpt2.py
+++ b/ch02/02_bonus_bytepair-encoder/bpe_openai_gpt2.py
@@ -3,6 +3,9 @@ Byte pair encoding utilities
 
 Code from https://github.com/openai/gpt-2/blob/master/src/encoder.py
 
+And modified code (download_vocab) from
+https://github.com/openai/gpt-2/blob/master/download_model.py
+
 Modified MIT License
 
 Software Copyright (c) 2019 OpenAI
@@ -34,6 +37,8 @@ OR OTHER DEALINGS IN THE SOFTWARE.
 import os
 import json
 import regex as re
+import requests
+from tqdm import tqdm
 from functools import lru_cache
 
 @lru_cache()
@@ -145,4 +150,39 @@ def get_encoder(model_name, models_dir):
     return Encoder(
         encoder=encoder,
         bpe_merges=bpe_merges,
-    )
\ No newline at end of file
+    )
+
+
+def download_vocab():
+    """Download the GPT-2 `encoder.json` and `vocab.bpe` files into ./gpt2_model.
+
+    Modified code from
+    https://github.com/openai/gpt-2/blob/master/download_model.py
+
+    The files are fetched from OpenAI's public blob storage (117M model) and
+    written to the `gpt2_model` subdirectory of the current working directory,
+    which is created if necessary.
+
+    Raises:
+        requests.HTTPError: if the server answers with an error status.
+    """
+    subdir = 'gpt2_model'
+    os.makedirs(subdir, exist_ok=True)
+    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models/117M"
+
+    for filename in ['encoder.json', 'vocab.bpe']:
+        # Check the status before opening the output file so that a failed
+        # request never leaves an error page or an empty file on disk
+        with requests.get(base_url + "/" + filename, stream=True, timeout=60) as r:
+            r.raise_for_status()
+            # content-length may be absent (e.g., chunked transfer); tqdm then shows no total
+            content_length = r.headers.get("content-length")
+            file_size = int(content_length) if content_length is not None else None
+            chunk_size = 1000
+            with open(os.path.join(subdir, filename), 'wb') as f:
+                with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
+                    # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
+                    for chunk in r.iter_content(chunk_size=chunk_size):
+                        f.write(chunk)
+                        # the last chunk is usually shorter than chunk_size
+                        pbar.update(len(chunk))
diff --git 
a/ch02/02_bonus_bytepair-encoder/compare-bpe-tiktoken.ipynb b/ch02/02_bonus_bytepair-encoder/compare-bpe-tiktoken.ipynb index 6dbe7e6..8fdbc21 100644 --- a/ch02/02_bonus_bytepair-encoder/compare-bpe-tiktoken.ipynb +++ b/ch02/02_bonus_bytepair-encoder/compare-bpe-tiktoken.ipynb @@ -125,22 +125,41 @@ "metadata": {}, "outputs": [], "source": [ - "from bpe_openai_gpt2 import get_encoder" + "from bpe_openai_gpt2 import get_encoder, download_vocab" ] }, { "cell_type": "code", "execution_count": 8, - "id": "1888a7a9-9c40-4fe0-99b4-ebd20aa1ffd0", + "id": "35dd8d7c-8c12-4b68-941a-0fd05882dd45", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fetching encoder.json: 1.04Mit [00:00, 3.03Mit/s] \n", + "Fetching vocab.bpe: 457kit [00:00, 2.36Mit/s] \n" + ] + } + ], "source": [ - "orig_tokenizer = get_encoder(model_name=\"gpt2\", models_dir=\".\")" + "download_vocab()" ] }, { "cell_type": "code", "execution_count": 9, + "id": "1888a7a9-9c40-4fe0-99b4-ebd20aa1ffd0", + "metadata": {}, + "outputs": [], + "source": [ + "orig_tokenizer = get_encoder(model_name=\"gpt2_model\", models_dir=\".\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, "id": "2740510c-a78a-4fba-ae18-2b156ba2dfef", "metadata": {}, "outputs": [ @@ -160,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "434d115e-990d-42ad-88dd-31323a96e10f", "metadata": {}, "outputs": [ @@ -188,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "5bfff386-f725-4137-9c50-e5da0c38bea0", "metadata": {}, "outputs": [], @@ -198,7 +217,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "e9077bf4-f91f-42ad-ab76-f3d89128510e", "metadata": {}, "outputs": [ @@ -216,7 +235,7 @@ "'4.33.2'" ] }, - "execution_count": 12, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -270,7 +289,7 @@ }, { "cell_type": "code", - 
"execution_count": 18, + "execution_count": 16, "id": "a61bb445-b151-4a2f-8180-d4004c503754", "metadata": {}, "outputs": [], @@ -281,7 +300,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 17, "id": "57f7c0a3-c1fd-4313-af34-68e78eb33653", "metadata": {}, "outputs": [ @@ -289,7 +308,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "4.17 ms ± 18.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "4.12 ms ± 41.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -299,7 +318,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 18, "id": "036dd628-3591-46c9-a5ce-b20b105a8062", "metadata": {}, "outputs": [ @@ -307,7 +326,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "1.68 ms ± 9.31 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" + "1.75 ms ± 8.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], @@ -317,7 +336,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 19, "id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90", "metadata": {}, "outputs": [ @@ -332,7 +351,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "8.81 ms ± 51.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "9.12 ms ± 856 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -342,7 +361,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 20, "id": "7117107f-22a6-46b4-a442-712d50b3ac7a", "metadata": {}, "outputs": [ @@ -350,21 +369,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "8.8 ms ± 74 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "8.63 ms ± 247 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" ] } ], "source": [ "%timeit hf_tokenizer(raw_text, max_length=5145, truncation=True)[\"input_ids\"]" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0bcbacd5-b64f-4186-ab12-949b9483f556", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {