fix: preserve newline tokens in BPE encoder (#495)

* fix: preserve newline tokens in BPE encoder

* further fixes

* more fixes

---------

Co-authored-by: rasbt <mail@sebastianraschka.com>
Austin Welch 2025-01-21 13:47:15 -05:00 committed by GitHub
parent 60acb94894
commit 0f35e370ed


@ -39,7 +39,7 @@
"- The BPE algorithm was originally described in 1994: \"[A New Algorithm for Data Compression](http://www.pennelynn.com/Documents/CUJ/HTML/94HTML/19940045.HTM)\" by Philip Gage\n",
"- Most projects, including Llama 3, nowadays use OpenAI's open-source [tiktoken library](https://github.com/openai/tiktoken) due to its computational performance; it allows loading pretrained GPT-2 and GPT-4 tokenizers, for example (the Llama 3 models were trained using the GPT-4 tokenizer as well)\n",
"- The difference between the implementations above and my implementation in this notebook, besides it being is that it also includes a function for training the tokenizer (for educational purposes)\n",
"- There's also an implementation called [minBPE](https://github.com/karpathy/minbpe) with training support, which is maybe more performant (my implementation here is focused on educational purposes); in contrast to `minbpe` my implementation additionally allows loading the original OpenAI tokenizer vocabulary and merges"
"- There's also an implementation called [minBPE](https://github.com/karpathy/minbpe) with training support, which is maybe more performant (my implementation here is focused on educational purposes); in contrast to `minbpe` my implementation additionally allows loading the original OpenAI tokenizer vocabulary and BPE \"merges\" (additionally, Hugging Face tokenizers are also capable of training and loading various tokenizers; see [this GitHub discussion](https://github.com/rasbt/LLMs-from-scratch/discussions/485) by a reader who trained a BPE tokenizer on the Nepali language for more info)"
]
},
{
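
The markdown cell above points to OpenAI's tiktoken library for loading the pretrained GPT-2 tokenizer. As a quick point of reference, a minimal sketch of that usage, assuming tiktoken is installed (the sample string and the printed IDs are illustrative and not part of this commit):

import tiktoken

tik_tokenizer = tiktoken.get_encoding("gpt2")     # pretrained GPT-2 BPE tokenizer
ids = tik_tokenizer.encode("This is some text")   # e.g. [1212, 318, 617, 2420]
print(ids)
print(tik_tokenizer.decode(ids))                  # round-trips back to "This is some text"
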
@ -382,7 +382,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 4,
"id": "3e4a15ec-2667-4f56-b7c1-34e8071b621d",
"metadata": {},
"outputs": [],
@ -431,8 +431,8 @@
" unique_chars.extend(char for char in sorted(set(processed_text)) if char not in unique_chars)\n",
"\n",
" # Optionally, ensure 'Ġ' is included if it is relevant to your text processing\n",
" if 'Ġ' not in unique_chars:\n",
" unique_chars.append('Ġ')\n",
" if \"Ġ\" not in unique_chars:\n",
" unique_chars.append(\"Ġ\")\n",
"\n",
" # Now create the vocab and inverse vocab dictionaries\n",
" self.vocab = {i: char for i, char in enumerate(unique_chars)}\n",
@ -474,9 +474,23 @@
" # Load vocabulary\n",
" with open(vocab_path, \"r\", encoding=\"utf-8\") as file:\n",
" loaded_vocab = json.load(file)\n",
" # loaded_vocab maps token_str to token_id\n",
" self.vocab = {int(v): k for k, v in loaded_vocab.items()} # token_id: token_str\n",
" self.inverse_vocab = {k: int(v) for k, v in loaded_vocab.items()} # token_str: token_id\n",
" # Convert loaded vocabulary to correct format\n",
" self.vocab = {int(v): k for k, v in loaded_vocab.items()}\n",
" self.inverse_vocab = {k: int(v) for k, v in loaded_vocab.items()}\n",
"\n",
" # Handle newline character without adding a new token\n",
" if \"\\n\" not in self.inverse_vocab:\n",
" # Use an existing token ID as a placeholder for '\\n'\n",
" # Preferentially use \"<|endoftext|>\" if available\n",
" fallback_token = next((token for token in [\"<|endoftext|>\", \"Ġ\", \"\"] if token in self.inverse_vocab), None)\n",
" if fallback_token is not None:\n",
" newline_token_id = self.inverse_vocab[fallback_token]\n",
" else:\n",
" # If no fallback token is available, raise an error\n",
" raise KeyError(\"No suitable token found in vocabulary to map '\\\\n'.\")\n",
"\n",
" self.inverse_vocab[\"\\n\"] = newline_token_id\n",
" self.vocab[newline_token_id] = \"\\n\"\n",
"\n",
" # Load BPE merges\n",
" with open(bpe_merges_path, \"r\", encoding=\"utf-8\") as file:\n",
@ -487,17 +501,15 @@
"\n",
" for rank, line in enumerate(lines):\n",
" pair = tuple(line.strip().split())\n",
" if len(pair) != 2:\n",
" print(f\"Line {rank+1} has more than 2 entries: {line.strip()}\")\n",
" continue\n",
" token1, token2 = pair\n",
" if token1 in self.inverse_vocab and token2 in self.inverse_vocab:\n",
" token_id1 = self.inverse_vocab[token1]\n",
" token_id2 = self.inverse_vocab[token2]\n",
" merged_token = token1 + token2\n",
" if merged_token in self.inverse_vocab:\n",
" merged_token_id = self.inverse_vocab[merged_token]\n",
" self.bpe_merges[(token_id1, token_id2)] = merged_token_id\n",
" if len(pair) == 2:\n",
" token1, token2 = pair\n",
" if token1 in self.inverse_vocab and token2 in self.inverse_vocab:\n",
" token_id1 = self.inverse_vocab[token1]\n",
" token_id2 = self.inverse_vocab[token2]\n",
" merged_token = token1 + token2\n",
" if merged_token in self.inverse_vocab:\n",
" merged_token_id = self.inverse_vocab[merged_token]\n",
" self.bpe_merges[(token_id1, token_id2)] = merged_token_id\n",
" # print(f\"Loaded merge: '{token1}' + '{token2}' -> '{merged_token}' (ID: {merged_token_id})\")\n",
" else:\n",
" print(f\"Merged token '{merged_token}' not found in vocab. Skipping.\")\n",
@ -515,21 +527,27 @@
" List[int]: The list of token IDs.\n",
" \"\"\"\n",
" tokens = []\n",
" # Split text into tokens, keeping newlines intact\n",
" words = text.replace(\"\\n\", \" \\n \").split() # Ensure '\\n' is treated as a separate token\n",
"\n",
" for i, word in enumerate(words):\n",
" if i > 0 and not word.startswith(\"\\n\"):\n",
" tokens.append(\"Ġ\" + word) # Add 'Ġ' to words that follow a space or newline\n",
" else:\n",
" tokens.append(word) # Handle first word or standalone '\\n'\n",
" # First split on newlines to preserve them\n",
" lines = text.split(\"\\n\")\n",
" for i, line in enumerate(lines):\n",
" if i > 0:\n",
" tokens.append(\"\\n\") # Add newline token separately\n",
" words = line.split()\n",
" for j, word in enumerate(words):\n",
" if j == 0:\n",
" if i > 0: # Start of a new line but not the first line\n",
" tokens.append(\"Ġ\" + word) # Ensure it's marked as a new segment\n",
" else:\n",
" tokens.append(word)\n",
" else:\n",
" # Prefix words in the middle of a line with 'Ġ'\n",
" tokens.append(\"Ġ\" + word)\n",
"\n",
" token_ids = []\n",
" for token in tokens:\n",
" if token in self.inverse_vocab:\n",
" # token is contained in the vocabulary as is\n",
" token_id = self.inverse_vocab[token]\n",
" token_ids.append(token_id)\n",
" token_ids.append(self.inverse_vocab[token])\n",
" else:\n",
" # Attempt to handle subword tokenization via BPE\n",
" sub_token_ids = self.tokenize_with_bpe(token)\n",
@ -587,12 +605,15 @@
" str: The decoded string.\n",
" \"\"\"\n",
" decoded_string = \"\"\n",
" for token_id in token_ids:\n",
" for i, token_id in enumerate(token_ids):\n",
" if token_id not in self.vocab:\n",
" raise ValueError(f\"Token ID {token_id} not found in vocab.\")\n",
" token = self.vocab[token_id]\n",
" if token.startswith(\"Ġ\"):\n",
" # Replace 'Ġ' with a space\n",
" if token == \"\\n\":\n",
" if decoded_string and not decoded_string.endswith(\" \"):\n",
" decoded_string += \" \" # Add space if not present before a newline\n",
" decoded_string += token\n",
" elif token.startswith(\"Ġ\"):\n",
" decoded_string += \" \" + token[1:]\n",
" else:\n",
" decoded_string += token\n",
@ -634,8 +655,8 @@
" with open(bpe_merges_path, \"r\", encoding=\"utf-8\") as file:\n",
" merges_list = json.load(file)\n",
" for merge in merges_list:\n",
" pair = tuple(merge['pair'])\n",
" new_id = merge['new_id']\n",
" pair = tuple(merge[\"pair\"])\n",
" new_id = merge[\"new_id\"]\n",
" self.bpe_merges[pair] = new_id\n",
"\n",
" @lru_cache(maxsize=None)\n",
@ -714,7 +735,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 5,
"id": "4d197cad-ed10-4a42-b01c-a763859781fb",
"metadata": {},
"outputs": [],
@ -745,7 +766,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 6,
"id": "027348fd-d52f-4396-93dd-38eed142df9b",
"metadata": {},
"outputs": [],
@ -764,7 +785,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 7,
"id": "f705a283-355e-4460-b940-06bbc2ae4e61",
"metadata": {},
"outputs": [
@ -791,7 +812,7 @@
},
{
"cell_type": "code",
"execution_count": 81,
"execution_count": 8,
"id": "3da42d1c-f75c-4ba7-a6c5-4cb8543d4a44",
"metadata": {},
"outputs": [
@ -825,7 +846,7 @@
},
{
"cell_type": "code",
"execution_count": 82,
"execution_count": 9,
"id": "e1db5cce-e015-412b-ad56-060b8b638078",
"metadata": {},
"outputs": [
@ -845,7 +866,7 @@
},
{
"cell_type": "code",
"execution_count": 83,
"execution_count": 10,
"id": "1ed1b344-f7d4-4e9e-ac34-2a04b5c5b7a8",
"metadata": {},
"outputs": [
@ -881,7 +902,7 @@
},
{
"cell_type": "code",
"execution_count": 84,
"execution_count": 11,
"id": "da0e1faf-1933-43d9-b681-916c282a8f86",
"metadata": {},
"outputs": [
@ -899,7 +920,7 @@
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": 12,
"id": "8b690e83-5d6b-409a-804e-321c287c24a4",
"metadata": {},
"outputs": [
@ -925,7 +946,7 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 13,
"id": "2b9e6289-92cb-4d88-b3c8-e836d7c8095f",
"metadata": {},
"outputs": [
@ -979,7 +1000,7 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 14,
"id": "c7056cb1-a9a3-4cf6-8364-29fb493ae240",
"metadata": {},
"outputs": [
@ -989,13 +1010,38 @@
"'This is some text.'"
]
},
"execution_count": 87,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.decode(tokenizer.encode(\"This is some text.\"))"
"tokenizer.decode(\n",
" tokenizer.encode(\"This is some text.\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "37bc6753-8f35-4ec7-b23e-df4a12103cb4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'This is some text with \\n newline characters.'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.decode(\n",
" tokenizer.encode(\"This is some text with \\n newline characters.\")\n",
")"
]
},
{
@ -1016,7 +1062,7 @@
},
{
"cell_type": "code",
"execution_count": 88,
"execution_count": 16,
"id": "955181cb-0910-4c6a-9c22-d8292a3ec1fc",
"metadata": {},
"outputs": [],
@ -1027,7 +1073,7 @@
},
{
"cell_type": "code",
"execution_count": 89,
"execution_count": 17,
"id": "6e5ccfe7-ac67-42f3-b727-87886a8867f1",
"metadata": {},
"outputs": [],
@ -1047,7 +1093,7 @@
},
{
"cell_type": "code",
"execution_count": 90,
"execution_count": 18,
"id": "00d9bf8f-756f-48bf-81b8-b890e2c2ef13",
"metadata": {},
"outputs": [
@ -1063,6 +1109,29 @@
"print(tokenizer2.decode(token_ids))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "e7addb64-2892-4e1c-85dd-4f5152740099",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'This is some text with \\n newline characters.'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer2.decode(\n",
" tokenizer2.encode(\"This is some text with \\n newline characters.\")\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b24d10b2-1ab8-44ee-b51a-14248e30d662",
@ -1082,7 +1151,7 @@
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": 20,
"id": "b45b4366-2c2b-4309-9a14-febf3add8512",
"metadata": {},
"outputs": [
@ -1090,8 +1159,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"vocab.bpe already exists\n",
"encoder.json already exists\n"
"vocab.bpe already exists in ../02_bonus_bytepair-encoder/gpt2_model/vocab.bpe\n",
"encoder.json already exists in ../02_bonus_bytepair-encoder/gpt2_model/encoder.json\n"
]
}
],
@ -1139,7 +1208,7 @@
},
{
"cell_type": "code",
"execution_count": 92,
"execution_count": 21,
"id": "74306e6c-47d3-45a3-9e0f-93f7303ef601",
"metadata": {},
"outputs": [],
@ -1160,7 +1229,7 @@
},
{
"cell_type": "code",
"execution_count": 93,
"execution_count": 22,
"id": "2bb722b4-dbf5-4a0c-9120-efda3293f132",
"metadata": {},
"outputs": [
@ -1170,7 +1239,7 @@
"50257"
]
},
"execution_count": 93,
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@ -1189,7 +1258,7 @@
},
{
"cell_type": "code",
"execution_count": 97,
"execution_count": 23,
"id": "e4866de7-fb32-4dd6-a878-469ec734641c",
"metadata": {},
"outputs": [
@ -1209,7 +1278,7 @@
},
{
"cell_type": "code",
"execution_count": 98,
"execution_count": 24,
"id": "3da8d9b2-af55-4b09-95d7-fabd983e919e",
"metadata": {},
"outputs": [
@ -1225,30 +1294,6 @@
"print(tokenizer_gpt2.decode(token_ids))"
]
},
{
"cell_type": "code",
"execution_count": 99,
"id": "460deb85-8de7-40c7-ba18-3c17831fa8ab",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[1212, 318, 617, 2420]"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import tiktoken\n",
"\n",
"tik_tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"tik_tokenizer.encode(input_text)"
]
},
{
"cell_type": "markdown",
"id": "b3b1e2dc-f69b-4533-87ef-549e6fb9b5a0",
@ -1303,7 +1348,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,