diff --git a/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb b/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb
index 2e8981d..18d3d91 100644
--- a/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb
+++ b/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb
@@ -513,43 +513,71 @@
     "                else:\n",
     "                    print(f\"Skipping pair {pair} as one token is not in the vocabulary.\")\n",
     "\n",
-    "    def encode(self, text):\n",
+    "    def encode(self, text, allowed_special=None):\n",
     "        \"\"\"\n",
-    "        Encode the input text into a list of token IDs.\n",
-    "\n",
+    "        Encode the input text into a list of token IDs, with tiktoken-style handling of special tokens.\n",
+    "        \n",
     "        Args:\n",
-    "            text (str): The text to encode.\n",
-    "\n",
+    "            text (str): The input text to encode.\n",
+    "            allowed_special (set or None): Special tokens to allow passthrough. If None, special handling is disabled.\n",
+    "        \n",
     "        Returns:\n",
-    "            List[int]: The list of token IDs.\n",
+    "            List of token IDs.\n",
     "        \"\"\"\n",
+    "        import re\n",
+    "        \n",
+    "        token_ids = []\n",
+    "        \n",
+    "        # If special token handling is enabled\n",
+    "        if allowed_special is not None and len(allowed_special) > 0:\n",
+    "            # Build regex to match allowed special tokens\n",
+    "            special_pattern = (\n",
+    "                \"(\" + \"|\".join(re.escape(tok) for tok in sorted(allowed_special, key=len, reverse=True)) + \")\"\n",
+    "            )\n",
+    "        \n",
+    "            last_index = 0\n",
+    "            for match in re.finditer(special_pattern, text):\n",
+    "                prefix = text[last_index:match.start()]\n",
+    "                token_ids.extend(self.encode(prefix, allowed_special=None))  # Encode prefix without special handling\n",
+    "        \n",
+    "                special_token = match.group(0)\n",
+    "                if special_token in self.inverse_vocab:\n",
+    "                    token_ids.append(self.inverse_vocab[special_token])\n",
+    "                else:\n",
+    "                    raise ValueError(f\"Special token {special_token} not found in vocabulary.\")\n",
+    "                last_index = match.end()\n",
+    "        \n",
+    "            text = text[last_index:]  # Remaining part to process normally\n",
+    "        \n",
+    "            # Check if any disallowed special tokens are in the remainder\n",
+    "            disallowed = [\n",
+    "                tok for tok in self.inverse_vocab\n",
+    "                if tok.startswith(\"<|\") and tok.endswith(\"|>\") and tok in text and tok not in allowed_special\n",
+    "            ]\n",
+    "            if disallowed:\n",
+    "                raise ValueError(f\"Disallowed special tokens encountered in text: {disallowed}\")\n",
+    "        \n",
+    "        # If no special tokens, or remaining text after special token split:\n",
     "        tokens = []\n",
-    "        # First split on newlines to preserve them\n",
     "        lines = text.split(\"\\n\")\n",
     "        for i, line in enumerate(lines):\n",
     "            if i > 0:\n",
-    "                tokens.append(\"\\n\")  # Add newline token separately\n",
+    "                tokens.append(\"\\n\")\n",
     "            words = line.split()\n",
     "            for j, word in enumerate(words):\n",
-    "                if j == 0:\n",
-    "                    if i > 0:  # Start of a new line but not the first line\n",
-    "                        tokens.append(\"Ġ\" + word)  # Ensure it's marked as a new segment\n",
-    "                    else:\n",
-    "                        tokens.append(word)\n",
-    "                else:\n",
-    "                    # Prefix words in the middle of a line with \"Ġ\"\n",
+    "                if j == 0 and i > 0:\n",
     "                    tokens.append(\"Ġ\" + word)\n",
-    "\n",
-    "        token_ids = []\n",
+    "                elif j == 0:\n",
+    "                    tokens.append(word)\n",
+    "                else:\n",
+    "                    tokens.append(\"Ġ\" + word)\n",
+    "        \n",
     "        for token in tokens:\n",
     "            if token in self.inverse_vocab:\n",
-    "                # token is contained in the vocabulary as is\n",
     "                token_ids.append(self.inverse_vocab[token])\n",
     "            else:\n",
-    "                # Attempt to handle subword tokenization via BPE\n",
-    "                sub_token_ids = self.tokenize_with_bpe(token)\n",
-    "                token_ids.extend(sub_token_ids)\n",
-    "\n",
+    "                token_ids.extend(self.tokenize_with_bpe(token))\n",
+    "        \n",
     "        return token_ids\n",
     "\n",
     "    def tokenize_with_bpe(self, token):\n",
@@ -781,7 +809,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 25,
+  "execution_count": 5,
   "id": "51872c08-e01b-40c3-a8a0-e8d6a773e3df",
   "metadata": {},
   "outputs": [
@@ -940,6 +968,46 @@
  {
   "cell_type": "code",
   "execution_count": 10,
+  "id": "78249752-38d7-47b9-b259-912bcc093dc4",
+  "metadata": {},
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46, 256, 60, 124, 271, 683, 102, 116, 461, 116, 124, 62]\n"
+    ]
+   }
+  ],
+  "source": [
+   "input_text = \"Jack embraced beauty through art and life. <|endoftext|> \"\n",
+   "token_ids = tokenizer.encode(input_text)\n",
+   "print(token_ids)"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 11,
+  "id": "0331d37d-49a3-44f7-9aa9-9834e0938741",
+  "metadata": {},
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46, 257]\n"
+    ]
+   }
+  ],
+  "source": [
+   "input_text = \"Jack embraced beauty through art and life. <|endoftext|> \"\n",
+   "token_ids = tokenizer.encode(input_text, allowed_special={\"<|endoftext|>\"})\n",
+   "print(token_ids)"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 12,
  "id": "1ed1b344-f7d4-4e9e-ac34-2a04b5c5b7a8",
  "metadata": {},
  "outputs": [
@@ -947,8 +1015,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "Number of characters: 42\n",
-    "Number of token IDs: 20\n"
+    "Number of characters: 57\n",
+    "Number of token IDs: 21\n"
    ]
   }
  ],
@@ -975,7 +1043,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 11,
+ "execution_count": 13,
  "id": "da0e1faf-1933-43d9-b681-916c282a8f86",
  "metadata": {},
  "outputs": [
@@ -983,7 +1051,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46]\n"
+    "[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46, 257]\n"
    ]
   }
  ],
@@ -993,7 +1061,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 12,
+ "execution_count": 14,
  "id": "8b690e83-5d6b-409a-804e-321c287c24a4",
  "metadata": {},
  "outputs": [
@@ -1001,7 +1069,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "Jack embraced beauty through art and life.\n"
+    "Jack embraced beauty through art and life.<|endoftext|>\n"
    ]
   }
  ],
@@ -1019,7 +1087,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 13,
+ "execution_count": 15,
  "id": "2b9e6289-92cb-4d88-b3c8-e836d7c8095f",
  "metadata": {},
  "outputs": [
@@ -1046,7 +1114,8 @@
    "256 -> \n",
    "326 -> li\n",
    "972 -> fe\n",
-    "46 -> .\n"
+    "46 -> .\n",
+    "257 -> <|endoftext|>\n"
    ]
   }
  ],
@@ -1073,7 +1142,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 14,
+ "execution_count": 16,
  "id": "c7056cb1-a9a3-4cf6-8364-29fb493ae240",
  "metadata": {},
  "outputs": [
@@ -1083,7 +1152,7 @@
   {
    "data": {
     "text/plain": [
      "'This is some text.'"
     ]
    },
-   "execution_count": 14,
+   "execution_count": 16,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -1096,7 +1165,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 15,
+ "execution_count": 17,
  "id": "37bc6753-8f35-4ec7-b23e-df4a12103cb4",
  "metadata": {},
  "outputs": [
@@ -1106,7 +1175,7 @@
   {
    "data": {
     "text/plain": [
      "'This is some text with \\n newline characters.'"
     ]
    },
-   "execution_count": 15,
+   "execution_count": 17,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -1135,7 +1204,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 16,
+ "execution_count": 18,
  "id": "955181cb-0910-4c6a-9c22-d8292a3ec1fc",
  "metadata": {},
  "outputs": [],
@@ -1146,7 +1215,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 17,
+ "execution_count": 19,
  "id": "6e5ccfe7-ac67-42f3-b727-87886a8867f1",
  "metadata": {},
  "outputs": [],
@@ -1166,7 +1235,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 18,
+ "execution_count": 20,
  "id": "00d9bf8f-756f-48bf-81b8-b890e2c2ef13",
  "metadata": {},
  "outputs": [
@@ -1174,7 +1243,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "Jack embraced beauty through art and life.\n"
+    "Jack embraced beauty through art and life.<|endoftext|>\n"
    ]
   }
  ],
@@ -1184,7 +1253,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 19,
+ "execution_count": 21,
  "id": "e7addb64-2892-4e1c-85dd-4f5152740099",
  "metadata": {},
  "outputs": [
@@ -1194,7 +1263,7 @@
   {
    "data": {
     "text/plain": [
      "'This is some text with \\n newline characters.'"
     ]
    },
-   "execution_count": 19,
+   "execution_count": 21,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -1224,7 +1293,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 20,
+ "execution_count": 22,
  "id": "b45b4366-2c2b-4309-9a14-febf3add8512",
  "metadata": {},
  "outputs": [
@@ -1264,7 +1333,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 21,
+ "execution_count": 23,
  "id": "74306e6c-47d3-45a3-9e0f-93f7303ef601",
  "metadata": {},
  "outputs": [],
@@ -1285,7 +1354,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 22,
+ "execution_count": 24,
  "id": "2bb722b4-dbf5-4a0c-9120-efda3293f132",
  "metadata": {},
  "outputs": [
@@ -1295,7 +1364,7 @@
   {
    "data": {
     "text/plain": [
      "50257"
     ]
    },
-   "execution_count": 22,
+   "execution_count": 24,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -1314,7 +1383,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 23,
+ "execution_count": 25,
  "id": "e4866de7-fb32-4dd6-a878-469ec734641c",
  "metadata": {},
  "outputs": [
@@ -1334,7 +1403,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 24,
+ "execution_count": 26,
  "id": "3da8d9b2-af55-4b09-95d7-fabd983e919e",
  "metadata": {},
  "outputs": [