diff --git a/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb b/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb index d14472f..a750af4 100644 --- a/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb +++ b/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb @@ -39,7 +39,7 @@ "- The BPE algorithm was originally described in 1994: \"[A New Algorithm for Data Compression](http://www.pennelynn.com/Documents/CUJ/HTML/94HTML/19940045.HTM)\" by Philip Gage\n", "- Most projects, including Llama 3, nowadays use OpenAI's open-source [tiktoken library](https://github.com/openai/tiktoken) due to its computational performance; it allows loading pretrained GPT-2 and GPT-4 tokenizers, for example (the Llama 3 models were trained using the GPT-4 tokenizer as well)\n", "- The difference between the implementations above and my implementation in this notebook, besides it being is that it also includes a function for training the tokenizer (for educational purposes)\n", - "- There's also an implementation called [minBPE](https://github.com/karpathy/minbpe) with training support, which is maybe more performant (my implementation here is focused on educational purposes); in contrast to `minbpe` my implementation additionally allows loading the original OpenAI tokenizer vocabulary and merges" + "- There's also an implementation called [minBPE](https://github.com/karpathy/minbpe) with training support, which is maybe more performant (my implementation here is focused on educational purposes); in contrast to `minbpe` my implementation additionally allows loading the original OpenAI tokenizer vocabulary and BPE \"merges\" (additionally, Hugging Face tokenizers are also capable of training and loading various tokenizers; see [this GitHub discussion](https://github.com/rasbt/LLMs-from-scratch/discussions/485) by a reader who trained a BPE tokenizer on the Nepali language for more info)" ] }, { @@ -382,7 +382,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 4, "id": "3e4a15ec-2667-4f56-b7c1-34e8071b621d", "metadata": {}, "outputs": [], @@ -431,8 +431,8 @@ " unique_chars.extend(char for char in sorted(set(processed_text)) if char not in unique_chars)\n", "\n", " # Optionally, ensure 'Ġ' is included if it is relevant to your text processing\n", - " if 'Ġ' not in unique_chars:\n", - " unique_chars.append('Ġ')\n", + " if \"Ġ\" not in unique_chars:\n", + " unique_chars.append(\"Ġ\")\n", "\n", " # Now create the vocab and inverse vocab dictionaries\n", " self.vocab = {i: char for i, char in enumerate(unique_chars)}\n", @@ -474,9 +474,23 @@ " # Load vocabulary\n", " with open(vocab_path, \"r\", encoding=\"utf-8\") as file:\n", " loaded_vocab = json.load(file)\n", - " # loaded_vocab maps token_str to token_id\n", - " self.vocab = {int(v): k for k, v in loaded_vocab.items()} # token_id: token_str\n", - " self.inverse_vocab = {k: int(v) for k, v in loaded_vocab.items()} # token_str: token_id\n", + " # Convert loaded vocabulary to correct format\n", + " self.vocab = {int(v): k for k, v in loaded_vocab.items()}\n", + " self.inverse_vocab = {k: int(v) for k, v in loaded_vocab.items()}\n", + "\n", + " # Handle newline character without adding a new token\n", + " if \"\\n\" not in self.inverse_vocab:\n", + " # Use an existing token ID as a placeholder for '\\n'\n", + " # Preferentially use \"<|endoftext|>\" if available\n", + " fallback_token = next((token for token in [\"<|endoftext|>\", \"Ġ\", \"\"] if token in self.inverse_vocab), None)\n", + " if fallback_token is not None:\n", + " newline_token_id = self.inverse_vocab[fallback_token]\n", + " else:\n", + " # If no fallback token is available, raise an error\n", + " raise KeyError(\"No suitable token found in vocabulary to map '\\\\n'.\")\n", + "\n", + " self.inverse_vocab[\"\\n\"] = newline_token_id\n", + " self.vocab[newline_token_id] = \"\\n\"\n", "\n", " # Load BPE merges\n", " with open(bpe_merges_path, \"r\", encoding=\"utf-8\") as file:\n", @@ -487,17 +501,15 @@ "\n", " for rank, line in enumerate(lines):\n", " pair = tuple(line.strip().split())\n", - " if len(pair) != 2:\n", - " print(f\"Line {rank+1} has more than 2 entries: {line.strip()}\")\n", - " continue\n", - " token1, token2 = pair\n", - " if token1 in self.inverse_vocab and token2 in self.inverse_vocab:\n", - " token_id1 = self.inverse_vocab[token1]\n", - " token_id2 = self.inverse_vocab[token2]\n", - " merged_token = token1 + token2\n", - " if merged_token in self.inverse_vocab:\n", - " merged_token_id = self.inverse_vocab[merged_token]\n", - " self.bpe_merges[(token_id1, token_id2)] = merged_token_id\n", + " if len(pair) == 2:\n", + " token1, token2 = pair\n", + " if token1 in self.inverse_vocab and token2 in self.inverse_vocab:\n", + " token_id1 = self.inverse_vocab[token1]\n", + " token_id2 = self.inverse_vocab[token2]\n", + " merged_token = token1 + token2\n", + " if merged_token in self.inverse_vocab:\n", + " merged_token_id = self.inverse_vocab[merged_token]\n", + " self.bpe_merges[(token_id1, token_id2)] = merged_token_id\n", " # print(f\"Loaded merge: '{token1}' + '{token2}' -> '{merged_token}' (ID: {merged_token_id})\")\n", " else:\n", " print(f\"Merged token '{merged_token}' not found in vocab. Skipping.\")\n", @@ -515,21 +527,27 @@ " List[int]: The list of token IDs.\n", " \"\"\"\n", " tokens = []\n", - " # Split text into tokens, keeping newlines intact\n", - " words = text.replace(\"\\n\", \" \\n \").split() # Ensure '\\n' is treated as a separate token\n", - "\n", - " for i, word in enumerate(words):\n", - " if i > 0 and not word.startswith(\"\\n\"):\n", - " tokens.append(\"Ġ\" + word) # Add 'Ġ' to words that follow a space or newline\n", - " else:\n", - " tokens.append(word) # Handle first word or standalone '\\n'\n", + " # First split on newlines to preserve them\n", + " lines = text.split(\"\\n\")\n", + " for i, line in enumerate(lines):\n", + " if i > 0:\n", + " tokens.append(\"\\n\") # Add newline token separately\n", + " words = line.split()\n", + " for j, word in enumerate(words):\n", + " if j == 0:\n", + " if i > 0: # Start of a new line but not the first line\n", + " tokens.append(\"Ġ\" + word) # Ensure it's marked as a new segment\n", + " else:\n", + " tokens.append(word)\n", + " else:\n", + " # Prefix words in the middle of a line with 'Ġ'\n", + " tokens.append(\"Ġ\" + word)\n", "\n", " token_ids = []\n", " for token in tokens:\n", " if token in self.inverse_vocab:\n", " # token is contained in the vocabulary as is\n", - " token_id = self.inverse_vocab[token]\n", - " token_ids.append(token_id)\n", + " token_ids.append(self.inverse_vocab[token])\n", " else:\n", " # Attempt to handle subword tokenization via BPE\n", " sub_token_ids = self.tokenize_with_bpe(token)\n", @@ -587,12 +605,15 @@ " str: The decoded string.\n", " \"\"\"\n", " decoded_string = \"\"\n", - " for token_id in token_ids:\n", + " for i, token_id in enumerate(token_ids):\n", " if token_id not in self.vocab:\n", " raise ValueError(f\"Token ID {token_id} not found in vocab.\")\n", " token = self.vocab[token_id]\n", - " if token.startswith(\"Ġ\"):\n", - " # Replace 'Ġ' with a space\n", + " if token == \"\\n\":\n", + " if decoded_string and not decoded_string.endswith(\" \"):\n", + " decoded_string += \" \" # Add space if not present before a newline\n", + " decoded_string += token\n", + " elif token.startswith(\"Ġ\"):\n", " decoded_string += \" \" + token[1:]\n", " else:\n", " decoded_string += token\n", @@ -634,8 +655,8 @@ " with open(bpe_merges_path, \"r\", encoding=\"utf-8\") as file:\n", " merges_list = json.load(file)\n", " for merge in merges_list:\n", - " pair = tuple(merge['pair'])\n", - " new_id = merge['new_id']\n", + " pair = tuple(merge[\"pair\"])\n", + " new_id = merge[\"new_id\"]\n", " self.bpe_merges[pair] = new_id\n", "\n", " @lru_cache(maxsize=None)\n", @@ -714,7 +735,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 5, "id": "4d197cad-ed10-4a42-b01c-a763859781fb", "metadata": {}, "outputs": [], @@ -745,7 +766,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 6, "id": "027348fd-d52f-4396-93dd-38eed142df9b", "metadata": {}, "outputs": [], @@ -764,7 +785,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 7, "id": "f705a283-355e-4460-b940-06bbc2ae4e61", "metadata": {}, "outputs": [ @@ -791,7 +812,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 8, "id": "3da42d1c-f75c-4ba7-a6c5-4cb8543d4a44", "metadata": {}, "outputs": [ @@ -825,7 +846,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 9, "id": "e1db5cce-e015-412b-ad56-060b8b638078", "metadata": {}, "outputs": [ @@ -845,7 +866,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 10, "id": "1ed1b344-f7d4-4e9e-ac34-2a04b5c5b7a8", "metadata": {}, "outputs": [ @@ -881,7 +902,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 11, "id": "da0e1faf-1933-43d9-b681-916c282a8f86", "metadata": {}, "outputs": [ @@ -899,7 +920,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 12, "id": "8b690e83-5d6b-409a-804e-321c287c24a4", "metadata": {}, "outputs": [ @@ -925,7 +946,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 13, "id": "2b9e6289-92cb-4d88-b3c8-e836d7c8095f", "metadata": {}, "outputs": [ @@ -979,7 +1000,7 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 14, "id": "c7056cb1-a9a3-4cf6-8364-29fb493ae240", "metadata": {}, "outputs": [ @@ -989,13 +1010,38 @@ "'This is some text.'" ] }, - "execution_count": 87, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "tokenizer.decode(tokenizer.encode(\"This is some text.\"))" + "tokenizer.decode(\n", + " tokenizer.encode(\"This is some text.\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "37bc6753-8f35-4ec7-b23e-df4a12103cb4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'This is some text with \\n newline characters.'" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode(\n", + " tokenizer.encode(\"This is some text with \\n newline characters.\")\n", + ")" ] }, { @@ -1016,7 +1062,7 @@ }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 16, "id": "955181cb-0910-4c6a-9c22-d8292a3ec1fc", "metadata": {}, "outputs": [], @@ -1027,7 +1073,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 17, "id": "6e5ccfe7-ac67-42f3-b727-87886a8867f1", "metadata": {}, "outputs": [], @@ -1047,7 +1093,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 18, "id": "00d9bf8f-756f-48bf-81b8-b890e2c2ef13", "metadata": {}, "outputs": [ @@ -1063,6 +1109,29 @@ "print(tokenizer2.decode(token_ids))" ] }, + { + "cell_type": "code", + "execution_count": 19, + "id": "e7addb64-2892-4e1c-85dd-4f5152740099", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'This is some text with \\n newline characters.'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer2.decode(\n", + " tokenizer2.encode(\"This is some text with \\n newline characters.\")\n", + ")" + ] + }, { "cell_type": "markdown", "id": "b24d10b2-1ab8-44ee-b51a-14248e30d662", @@ -1082,7 +1151,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 20, "id": "b45b4366-2c2b-4309-9a14-febf3add8512", "metadata": {}, "outputs": [ @@ -1090,8 +1159,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "vocab.bpe already exists\n", - "encoder.json already exists\n" + "vocab.bpe already exists in ../02_bonus_bytepair-encoder/gpt2_model/vocab.bpe\n", + "encoder.json already exists in ../02_bonus_bytepair-encoder/gpt2_model/encoder.json\n" ] } ], @@ -1139,7 +1208,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 21, "id": "74306e6c-47d3-45a3-9e0f-93f7303ef601", "metadata": {}, "outputs": [], @@ -1160,7 +1229,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 22, "id": "2bb722b4-dbf5-4a0c-9120-efda3293f132", "metadata": {}, "outputs": [ @@ -1170,7 +1239,7 @@ "50257" ] }, - "execution_count": 93, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -1189,7 +1258,7 @@ }, { "cell_type": "code", - "execution_count": 97, + "execution_count": 23, "id": "e4866de7-fb32-4dd6-a878-469ec734641c", "metadata": {}, "outputs": [ @@ -1209,7 +1278,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 24, "id": "3da8d9b2-af55-4b09-95d7-fabd983e919e", "metadata": {}, "outputs": [ @@ -1225,30 +1294,6 @@ "print(tokenizer_gpt2.decode(token_ids))" ] }, - { - "cell_type": "code", - "execution_count": 99, - "id": "460deb85-8de7-40c7-ba18-3c17831fa8ab", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[1212, 318, 617, 2420]" - ] - }, - "execution_count": 99, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import tiktoken\n", - "\n", - "tik_tokenizer = tiktoken.get_encoding(\"gpt2\")\n", - "tik_tokenizer.encode(input_text)" - ] - }, { "cell_type": "markdown", "id": "b3b1e2dc-f69b-4533-87ef-549e6fb9b5a0", @@ -1303,7 +1348,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.6" } }, "nbformat": 4,