add special token handling to bpe from scratch code (#616)

Sebastian Raschka 2025-04-13 12:38:22 -05:00 committed by GitHub
parent d5eaa36416
commit 48e98abc8e

@@ -513,43 +513,71 @@
     "                    else:\n",
     "                        print(f\"Skipping pair {pair} as one token is not in the vocabulary.\")\n",
     "\n",
-    "    def encode(self, text):\n",
+    "    def encode(self, text, allowed_special=None):\n",
     "        \"\"\"\n",
-    "        Encode the input text into a list of token IDs.\n",
-    "\n",
+    "        Encode the input text into a list of token IDs, with tiktoken-style handling of special tokens.\n",
+    "    \n",
     "        Args:\n",
-    "            text (str): The text to encode.\n",
-    "\n",
+    "            text (str): The input text to encode.\n",
+    "            allowed_special (set or None): Special tokens to allow passthrough. If None, special handling is disabled.\n",
+    "    \n",
     "        Returns:\n",
-    "            List[int]: The list of token IDs.\n",
+    "            List of token IDs.\n",
     "        \"\"\"\n",
+    "        import re\n",
+    "    \n",
+    "        token_ids = []\n",
+    "    \n",
+    "        # If special token handling is enabled\n",
+    "        if allowed_special is not None and len(allowed_special) > 0:\n",
+    "            # Build regex to match allowed special tokens\n",
+    "            special_pattern = (\n",
+    "                \"(\" + \"|\".join(re.escape(tok) for tok in sorted(allowed_special, key=len, reverse=True)) + \")\"\n",
+    "            )\n",
+    "    \n",
+    "            last_index = 0\n",
+    "            for match in re.finditer(special_pattern, text):\n",
+    "                prefix = text[last_index:match.start()]\n",
+    "                token_ids.extend(self.encode(prefix, allowed_special=None))  # Encode prefix without special handling\n",
+    "    \n",
+    "                special_token = match.group(0)\n",
+    "                if special_token in self.inverse_vocab:\n",
+    "                    token_ids.append(self.inverse_vocab[special_token])\n",
+    "                else:\n",
+    "                    raise ValueError(f\"Special token {special_token} not found in vocabulary.\")\n",
+    "                last_index = match.end()\n",
+    "    \n",
+    "            text = text[last_index:]  # Remaining part to process normally\n",
+    "    \n",
+    "            # Check if any disallowed special tokens are in the remainder\n",
+    "            disallowed = [\n",
+    "                tok for tok in self.inverse_vocab\n",
+    "                if tok.startswith(\"<|\") and tok.endswith(\"|>\") and tok in text and tok not in allowed_special\n",
+    "            ]\n",
+    "            if disallowed:\n",
+    "                raise ValueError(f\"Disallowed special tokens encountered in text: {disallowed}\")\n",
+    "    \n",
+    "        # If no special tokens, or remaining text after special token split:\n",
     "        tokens = []\n",
     "        # First split on newlines to preserve them\n",
     "        lines = text.split(\"\\n\")\n",
     "        for i, line in enumerate(lines):\n",
     "            if i > 0:\n",
-    "                tokens.append(\"\\n\")  # Add newline token separately\n",
+    "                tokens.append(\"\\n\")\n",
     "            words = line.split()\n",
     "            for j, word in enumerate(words):\n",
-    "                if j == 0:\n",
-    "                    if i > 0:  # Start of a new line but not the first line\n",
-    "                        tokens.append(\"Ġ\" + word)  # Ensure it's marked as a new segment\n",
-    "                    else:\n",
-    "                        tokens.append(word)\n",
-    "                else:\n",
-    "                    # Prefix words in the middle of a line with \"Ġ\"\n",
+    "                if j == 0 and i > 0:\n",
     "                    tokens.append(\"Ġ\" + word)\n",
-    "\n",
-    "        token_ids = []\n",
+    "                elif j == 0:\n",
+    "                    tokens.append(word)\n",
+    "                else:\n",
+    "                    tokens.append(\"Ġ\" + word)\n",
+    "    \n",
     "        for token in tokens:\n",
     "            if token in self.inverse_vocab:\n",
-    "                # token is contained in the vocabulary as is\n",
     "                token_ids.append(self.inverse_vocab[token])\n",
     "            else:\n",
-    "                # Attempt to handle subword tokenization via BPE\n",
-    "                sub_token_ids = self.tokenize_with_bpe(token)\n",
-    "                token_ids.extend(sub_token_ids)\n",
-    "\n",
+    "                token_ids.extend(self.tokenize_with_bpe(token))\n",
+    "    \n",
     "        return token_ids\n",
     "\n",
     "    def tokenize_with_bpe(self, token):\n",
@@ -781,7 +809,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 5,
    "id": "51872c08-e01b-40c3-a8a0-e8d6a773e3df",
    "metadata": {},
    "outputs": [
@@ -940,6 +968,46 @@
   {
    "cell_type": "code",
    "execution_count": 10,
+   "id": "78249752-38d7-47b9-b259-912bcc093dc4",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46, 256, 60, 124, 271, 683, 102, 116, 461, 116, 124, 62]\n"
+     ]
+    }
+   ],
+   "source": [
+    "input_text = \"Jack embraced beauty through art and life. <|endoftext|> \"\n",
+    "token_ids = tokenizer.encode(input_text)\n",
+    "print(token_ids)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "0331d37d-49a3-44f7-9aa9-9834e0938741",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46, 257]\n"
+     ]
+    }
+   ],
+   "source": [
+    "input_text = \"Jack embraced beauty through art and life. <|endoftext|> \"\n",
+    "token_ids = tokenizer.encode(input_text, allowed_special={\"<|endoftext|>\"})\n",
+    "print(token_ids)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
    "id": "1ed1b344-f7d4-4e9e-ac34-2a04b5c5b7a8",
    "metadata": {},
    "outputs": [
@@ -947,8 +1015,8 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Number of characters: 42\n",
-      "Number of token IDs: 20\n"
+      "Number of characters: 57\n",
+      "Number of token IDs: 21\n"
      ]
     }
    ],
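The updated counts follow directly from the longer input: the original 42-character sentence gains ' <|endoftext|> ' (15 characters, giving 57), and the encoding gains exactly one ID for the special token (21 instead of 20). A sketch reproducing them with the names used above:

    input_text = "Jack embraced beauty through art and life. <|endoftext|> "
    token_ids = tokenizer.encode(input_text, allowed_special={"<|endoftext|>"})
    print("Number of characters:", len(input_text))  # 57
    print("Number of token IDs:", len(token_ids))    # 21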
@@ -975,7 +1043,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 13,
    "id": "da0e1faf-1933-43d9-b681-916c282a8f86",
    "metadata": {},
    "outputs": [
@@ -983,7 +1051,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46]\n"
+      "[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46, 257]\n"
      ]
     }
    ],
@@ -993,7 +1061,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 14,
    "id": "8b690e83-5d6b-409a-804e-321c287c24a4",
    "metadata": {},
    "outputs": [
@@ -1001,7 +1069,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Jack embraced beauty through art and life.\n"
+      "Jack embraced beauty through art and life.<|endoftext|>\n"
      ]
     }
    ],
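Note that the decoded text ends in ...life.<|endoftext|> with no surrounding spaces, even though the input had them: encode() whitespace-splits the segments around special tokens, so those spaces are not round-tripped. A quick check with the names from the cells above:

    decoded = tokenizer.decode(token_ids)
    print(decoded)                # Jack embraced beauty through art and life.<|endoftext|>
    print(decoded == input_text)  # False: spaces around the special token are dropped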
@@ -1019,7 +1087,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 15,
    "id": "2b9e6289-92cb-4d88-b3c8-e836d7c8095f",
    "metadata": {},
    "outputs": [
@@ -1046,7 +1114,8 @@
      "256 -> \n",
      "326 -> li\n",
      "972 -> fe\n",
-      "46 -> .\n"
+      "46 -> .\n",
+      "257 -> <|endoftext|>\n"
      ]
     }
    ],
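The ID-by-ID listing (its tail shown above) now ends with 257 -> <|endoftext|>, confirming the special token survives as a single vocabulary entry. A sketch of a loop that produces such a listing, assuming decode() accepts a single-ID list:

    for token_id in token_ids:
        print(f"{token_id} -> {tokenizer.decode([token_id])}")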
@@ -1073,7 +1142,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 16,
    "id": "c7056cb1-a9a3-4cf6-8364-29fb493ae240",
    "metadata": {},
    "outputs": [
@@ -1083,7 +1152,7 @@
        "'This is some text.'"
       ]
      },
-     "execution_count": 14,
+     "execution_count": 16,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1096,7 +1165,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 17,
    "id": "37bc6753-8f35-4ec7-b23e-df4a12103cb4",
    "metadata": {},
    "outputs": [
@@ -1106,7 +1175,7 @@
        "'This is some text with \\n newline characters.'"
       ]
      },
-     "execution_count": 15,
+     "execution_count": 17,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1135,7 +1204,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 18,
    "id": "955181cb-0910-4c6a-9c22-d8292a3ec1fc",
    "metadata": {},
    "outputs": [],
@@ -1146,7 +1215,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 19,
    "id": "6e5ccfe7-ac67-42f3-b727-87886a8867f1",
    "metadata": {},
    "outputs": [],
@@ -1166,7 +1235,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 20,
    "id": "00d9bf8f-756f-48bf-81b8-b890e2c2ef13",
    "metadata": {},
    "outputs": [
@@ -1174,7 +1243,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Jack embraced beauty through art and life.\n"
+      "Jack embraced beauty through art and life.<|endoftext|>\n"
      ]
     }
    ],
@@ -1184,7 +1253,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 21,
    "id": "e7addb64-2892-4e1c-85dd-4f5152740099",
    "metadata": {},
    "outputs": [
@@ -1194,7 +1263,7 @@
        "'This is some text with \\n newline characters.'"
       ]
      },
-     "execution_count": 19,
+     "execution_count": 21,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1224,7 +1293,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 22,
    "id": "b45b4366-2c2b-4309-9a14-febf3add8512",
    "metadata": {},
    "outputs": [
@@ -1264,7 +1333,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 23,
    "id": "74306e6c-47d3-45a3-9e0f-93f7303ef601",
    "metadata": {},
    "outputs": [],
@@ -1285,7 +1354,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 24,
    "id": "2bb722b4-dbf5-4a0c-9120-efda3293f132",
    "metadata": {},
    "outputs": [
@@ -1295,7 +1364,7 @@
        "50257"
       ]
      },
-     "execution_count": 22,
+     "execution_count": 24,
      "metadata": {},
      "output_type": "execute_result"
     }
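For context, 50257 is GPT-2's full vocabulary size: 256 byte-level base tokens, 50,000 learned merges, and the <|endoftext|> special token, which takes the final ID 50256. A sanity check (tokenizer_gpt2 is a hypothetical name standing in for the tokenizer loaded from the OpenAI vocab files):

    print(256 + 50_000 + 1)                               # 50257, matching the output above
    print(tokenizer_gpt2.inverse_vocab["<|endoftext|>"])  # 50256 in GPT-2's vocab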
@@ -1314,7 +1383,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 25,
    "id": "e4866de7-fb32-4dd6-a878-469ec734641c",
    "metadata": {},
    "outputs": [
@@ -1334,7 +1403,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 26,
    "id": "3da8d9b2-af55-4b09-95d7-fabd983e919e",
    "metadata": {},
    "outputs": [