mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-11-04 03:40:21 +00:00 
			
		
		
		
	Fix BPE bonus materials (#561)
* Fix BPE bonus materials * fix bpe implementation * update * Add 'Hello, world. Is this-- a test?' test case * update link to test file * update path handling * update path handling * fix pytest paths
This commit is contained in:
		
							parent
							
								
									96ca2fcb2f
								
							
						
					
					
						commit
						f63f04d8d5
					
				
							
								
								
									
										6
									
								
								.github/workflows/basic-tests-linux-uv.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/basic-tests-linux-uv.yml
									
									
									
									
										vendored
									
									
								
							@ -60,3 +60,9 @@ jobs:
 | 
			
		||||
          pytest --ruff --nbval ch02/01_main-chapter-code/dataloader.ipynb
 | 
			
		||||
          pytest --ruff --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
 | 
			
		||||
          pytest --ruff --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
 | 
			
		||||
 | 
			
		||||
      - name: Test Selected Bonus Materials
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          source .venv/bin/activate
 | 
			
		||||
          pytest ch02/05_bpe-from-scratch/tests/tests.py
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										8
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -1,3 +1,4 @@
 | 
			
		||||
 | 
			
		||||
# Configs and keys
 | 
			
		||||
ch05/07_gpt_to_llama/config.json
 | 
			
		||||
ch07/02_dataset-utilities/config.json
 | 
			
		||||
@ -63,6 +64,8 @@ ch07/01_main-chapter-code/Smalltestmodel-sft-standalone.pth
 | 
			
		||||
ch07/01_main-chapter-code/gpt2/
 | 
			
		||||
 | 
			
		||||
# Datasets
 | 
			
		||||
the-verdict.txt
 | 
			
		||||
 | 
			
		||||
appendix-E/01_main-chapter-code/sms_spam_collection.zip
 | 
			
		||||
appendix-E/01_main-chapter-code/sms_spam_collection
 | 
			
		||||
appendix-E/01_main-chapter-code/train.csv
 | 
			
		||||
@ -70,6 +73,7 @@ appendix-E/01_main-chapter-code/test.csv
 | 
			
		||||
appendix-E/01_main-chapter-code/validation.csv
 | 
			
		||||
 | 
			
		||||
ch02/01_main-chapter-code/number-data.txt
 | 
			
		||||
ch02/05_bpe-from-scratch/the-verdict.txt
 | 
			
		||||
 | 
			
		||||
ch05/03_bonus_pretraining_on_gutenberg/gutenberg
 | 
			
		||||
ch05/03_bonus_pretraining_on_gutenberg/gutenberg_preprocessed
 | 
			
		||||
@ -107,7 +111,9 @@ ch02/05_bpe-from-scratch/bpe_merges.txt
 | 
			
		||||
ch02/05_bpe-from-scratch/encoder.json
 | 
			
		||||
ch02/05_bpe-from-scratch/vocab.bpe
 | 
			
		||||
ch02/05_bpe-from-scratch/vocab.json
 | 
			
		||||
 | 
			
		||||
encoder.json
 | 
			
		||||
vocab.bpe
 | 
			
		||||
vocab.json
 | 
			
		||||
 | 
			
		||||
# Other
 | 
			
		||||
ch0?/0?_user_interface/.chainlit/
 | 
			
		||||
 | 
			
		||||
@ -67,7 +67,7 @@
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "tiktoken version: 0.7.0\n"
 | 
			
		||||
      "tiktoken version: 0.9.0\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -180,8 +180,8 @@
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Fetching encoder.json: 1.04Mit [00:00, 4.13Mit/s]                                                   \n",
 | 
			
		||||
      "Fetching vocab.bpe: 457kit [00:00, 2.56Mit/s]                                                       \n"
 | 
			
		||||
      "Fetching encoder.json: 1.04Mit [00:00, 3.69Mit/s]                                                   \n",
 | 
			
		||||
      "Fetching vocab.bpe: 457kit [00:00, 2.53Mit/s]                                                       \n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -256,10 +256,18 @@
 | 
			
		||||
   "id": "e9077bf4-f91f-42ad-ab76-f3d89128510e",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
 | 
			
		||||
      "  from .autonotebook import tqdm as notebook_tqdm\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "data": {
 | 
			
		||||
      "text/plain": [
 | 
			
		||||
       "'4.48.0'"
 | 
			
		||||
       "'4.49.0'"
 | 
			
		||||
      ]
 | 
			
		||||
     },
 | 
			
		||||
     "execution_count": 12,
 | 
			
		||||
@ -423,7 +431,7 @@
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "[1544, 18798, 11, 995, 13, 1148, 256, 5303, 82, 438, 257, 1332, 30]\n"
 | 
			
		||||
      "[15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -451,7 +459,7 @@
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "with open('../01_main-chapter-code/the-verdict.txt', 'r', encoding='utf-8') as f:\n",
 | 
			
		||||
    "with open(\"../01_main-chapter-code/the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
 | 
			
		||||
    "    raw_text = f.read()"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
@ -473,7 +481,7 @@
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "3.39 ms ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 | 
			
		||||
      "3.84 ms ± 9.83 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -499,7 +507,7 @@
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "1.08 ms ± 5.99 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
 | 
			
		||||
      "901 μs ± 6.27 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -532,7 +540,7 @@
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "10.2 ms ± 115 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 | 
			
		||||
      "11 ms ± 94.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -550,7 +558,7 @@
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "10 ms ± 36.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 | 
			
		||||
      "10.8 ms ± 180 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -575,7 +583,7 @@
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "3.79 ms ± 48.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 | 
			
		||||
      "3.66 ms ± 3.67 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -593,7 +601,7 @@
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "3.83 ms ± 58.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 | 
			
		||||
      "3.77 ms ± 49.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -619,7 +627,7 @@
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "1.59 ms ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
 | 
			
		||||
      "9.37 ms ± 50.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -644,7 +652,7 @@
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.11.4"
 | 
			
		||||
   "version": "3.10.16"
 | 
			
		||||
  }
 | 
			
		||||
 },
 | 
			
		||||
 "nbformat": 4,
 | 
			
		||||
 | 
			
		||||
@ -382,7 +382,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": null,
 | 
			
		||||
   "execution_count": 4,
 | 
			
		||||
   "id": "3e4a15ec-2667-4f56-b7c1-34e8071b621d",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
@ -401,6 +401,10 @@
 | 
			
		||||
    "        # Dictionary of BPE merges: {(token_id1, token_id2): merged_token_id}\n",
 | 
			
		||||
    "        self.bpe_merges = {}\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # For the official OpenAI GPT-2 merges, use a rank dict:\n",
 | 
			
		||||
    "        #  of form {(string_A, string_B): rank}, where lower rank = higher priority\n",
 | 
			
		||||
    "        self.bpe_ranks = {}\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "    def train(self, text, vocab_size, allowed_special={\"<|endoftext|>\"}):\n",
 | 
			
		||||
    "        \"\"\"\n",
 | 
			
		||||
    "        Train the BPE tokenizer from scratch.\n",
 | 
			
		||||
@ -411,7 +415,7 @@
 | 
			
		||||
    "            allowed_special (set): A set of special tokens to include.\n",
 | 
			
		||||
    "        \"\"\"\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # Preprocess: Replace spaces with 'Ġ'\n",
 | 
			
		||||
    "        # Preprocess: Replace spaces with \"Ġ\"\n",
 | 
			
		||||
    "        # Note that Ġ is a particularity of the GPT-2 BPE implementation\n",
 | 
			
		||||
    "        # E.g., \"Hello world\" might be tokenized as [\"Hello\", \"Ġworld\"]\n",
 | 
			
		||||
    "        # (GPT-4 BPE would tokenize it as [\"Hello\", \" world\"])\n",
 | 
			
		||||
@ -423,18 +427,16 @@
 | 
			
		||||
    "                processed_text.append(char)\n",
 | 
			
		||||
    "        processed_text = \"\".join(processed_text)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # Initialize vocab with unique characters, including 'Ġ' if present\n",
 | 
			
		||||
    "        # Initialize vocab with unique characters, including \"Ġ\" if present\n",
 | 
			
		||||
    "        # Start with the first 256 ASCII characters\n",
 | 
			
		||||
    "        unique_chars = [chr(i) for i in range(256)]\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # Extend unique_chars with characters from processed_text that are not already included\n",
 | 
			
		||||
    "        unique_chars.extend(char for char in sorted(set(processed_text)) if char not in unique_chars)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # Optionally, ensure 'Ġ' is included if it is relevant to your text processing\n",
 | 
			
		||||
    "        unique_chars.extend(\n",
 | 
			
		||||
    "            char for char in sorted(set(processed_text))\n",
 | 
			
		||||
    "            if char not in unique_chars\n",
 | 
			
		||||
    "        )\n",
 | 
			
		||||
    "        if \"Ġ\" not in unique_chars:\n",
 | 
			
		||||
    "            unique_chars.append(\"Ġ\")\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # Now create the vocab and inverse vocab dictionaries\n",
 | 
			
		||||
    "        self.vocab = {i: char for i, char in enumerate(unique_chars)}\n",
 | 
			
		||||
    "        self.inverse_vocab = {char: i for i, char in self.vocab.items()}\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
@ -452,7 +454,7 @@
 | 
			
		||||
    "        # BPE steps 1-3: Repeatedly find and replace frequent pairs\n",
 | 
			
		||||
    "        for new_id in range(len(self.vocab), vocab_size):\n",
 | 
			
		||||
    "            pair_id = self.find_freq_pair(token_ids, mode=\"most\")\n",
 | 
			
		||||
    "            if pair_id is None:  # No more pairs to merge. Stopping training.\n",
 | 
			
		||||
    "            if pair_id is None:\n",
 | 
			
		||||
    "                break\n",
 | 
			
		||||
    "            token_ids = self.replace_pair(token_ids, pair_id, new_id)\n",
 | 
			
		||||
    "            self.bpe_merges[pair_id] = new_id\n",
 | 
			
		||||
@ -492,29 +494,24 @@
 | 
			
		||||
    "            self.inverse_vocab[\"\\n\"] = newline_token_id\n",
 | 
			
		||||
    "            self.vocab[newline_token_id] = \"\\n\"\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # Load BPE merges\n",
 | 
			
		||||
    "        # Load GPT-2 merges and store them with an assigned \"rank\"\n",
 | 
			
		||||
    "        self.bpe_ranks = {}  # reset ranks\n",
 | 
			
		||||
    "        with open(bpe_merges_path, \"r\", encoding=\"utf-8\") as file:\n",
 | 
			
		||||
    "            lines = file.readlines()\n",
 | 
			
		||||
    "            # Skip header line if present\n",
 | 
			
		||||
    "            if lines and lines[0].startswith(\"#\"):\n",
 | 
			
		||||
    "                lines = lines[1:]\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "            rank = 0\n",
 | 
			
		||||
    "            for line in lines:\n",
 | 
			
		||||
    "                pair = tuple(line.strip().split())\n",
 | 
			
		||||
    "                if len(pair) == 2:\n",
 | 
			
		||||
    "                    token1, token2 = pair\n",
 | 
			
		||||
    "                    # If token1 or token2 not in vocab, skip\n",
 | 
			
		||||
    "                    if token1 in self.inverse_vocab and token2 in self.inverse_vocab:\n",
 | 
			
		||||
    "                        token_id1 = self.inverse_vocab[token1]\n",
 | 
			
		||||
    "                        token_id2 = self.inverse_vocab[token2]\n",
 | 
			
		||||
    "                        merged_token = token1 + token2\n",
 | 
			
		||||
    "                        if merged_token in self.inverse_vocab:\n",
 | 
			
		||||
    "                            merged_token_id = self.inverse_vocab[merged_token]\n",
 | 
			
		||||
    "                            self.bpe_merges[(token_id1, token_id2)] = merged_token_id\n",
 | 
			
		||||
    "                        # print(f\"Loaded merge: '{token1}' + '{token2}' -> '{merged_token}' (ID: {merged_token_id})\")\n",
 | 
			
		||||
    "                        self.bpe_ranks[(token1, token2)] = rank\n",
 | 
			
		||||
    "                        rank += 1\n",
 | 
			
		||||
    "                    else:\n",
 | 
			
		||||
    "                            print(f\"Merged token '{merged_token}' not found in vocab. Skipping.\")\n",
 | 
			
		||||
    "                    else:\n",
 | 
			
		||||
    "                        print(f\"Skipping pair {pair} as one of the tokens is not in the vocabulary.\")\n",
 | 
			
		||||
    "                        print(f\"Skipping pair {pair} as one token is not in the vocabulary.\")\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "    def encode(self, text):\n",
 | 
			
		||||
    "        \"\"\"\n",
 | 
			
		||||
@ -540,7 +537,7 @@
 | 
			
		||||
    "                    else:\n",
 | 
			
		||||
    "                        tokens.append(word)\n",
 | 
			
		||||
    "                else:\n",
 | 
			
		||||
    "                    # Prefix words in the middle of a line with 'Ġ'\n",
 | 
			
		||||
    "                    # Prefix words in the middle of a line with \"Ġ\"\n",
 | 
			
		||||
    "                    tokens.append(\"Ġ\" + word)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        token_ids = []\n",
 | 
			
		||||
@ -571,6 +568,8 @@
 | 
			
		||||
    "            missing_chars = [char for char, tid in zip(token, token_ids) if tid is None]\n",
 | 
			
		||||
    "            raise ValueError(f\"Characters not found in vocab: {missing_chars}\")\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # If we haven't loaded OpenAI's GPT-2 merges, use my approach\n",
 | 
			
		||||
    "        if not self.bpe_ranks:\n",
 | 
			
		||||
    "            can_merge = True\n",
 | 
			
		||||
    "            while can_merge and len(token_ids) > 1:\n",
 | 
			
		||||
    "                can_merge = False\n",
 | 
			
		||||
@ -591,9 +590,53 @@
 | 
			
		||||
    "                if i < len(token_ids):\n",
 | 
			
		||||
    "                    new_tokens.append(token_ids[i])\n",
 | 
			
		||||
    "                token_ids = new_tokens\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "            return token_ids\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # Otherwise, do GPT-2-style merging with the ranks:\n",
 | 
			
		||||
    "        # 1) Convert token_ids back to string \"symbols\" for each ID\n",
 | 
			
		||||
    "        symbols = [self.vocab[id_num] for id_num in token_ids]\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # Repeatedly merge all occurrences of the lowest-rank pair\n",
 | 
			
		||||
    "        while True:\n",
 | 
			
		||||
    "            # Collect all adjacent pairs\n",
 | 
			
		||||
    "            pairs = set(zip(symbols, symbols[1:]))\n",
 | 
			
		||||
    "            if not pairs:\n",
 | 
			
		||||
    "                break\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "            # Find the pair with the best (lowest) rank\n",
 | 
			
		||||
    "            min_rank = 1_000_000_000\n",
 | 
			
		||||
    "            bigram = None\n",
 | 
			
		||||
    "            for p in pairs:\n",
 | 
			
		||||
    "                r = self.bpe_ranks.get(p, 1_000_000_000)\n",
 | 
			
		||||
    "                if r < min_rank:\n",
 | 
			
		||||
    "                    min_rank = r\n",
 | 
			
		||||
    "                    bigram = p\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "            # If no valid ranked pair is present, we're done\n",
 | 
			
		||||
    "            if bigram is None or bigram not in self.bpe_ranks:\n",
 | 
			
		||||
    "                break\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "            # Merge all occurrences of that pair\n",
 | 
			
		||||
    "            first, second = bigram\n",
 | 
			
		||||
    "            new_symbols = []\n",
 | 
			
		||||
    "            i = 0\n",
 | 
			
		||||
    "            while i < len(symbols):\n",
 | 
			
		||||
    "                # If we see (first, second) at position i, merge them\n",
 | 
			
		||||
    "                if i < len(symbols) - 1 and symbols[i] == first and symbols[i+1] == second:\n",
 | 
			
		||||
    "                    new_symbols.append(first + second)  # merged symbol\n",
 | 
			
		||||
    "                    i += 2\n",
 | 
			
		||||
    "                else:\n",
 | 
			
		||||
    "                    new_symbols.append(symbols[i])\n",
 | 
			
		||||
    "                    i += 1\n",
 | 
			
		||||
    "            symbols = new_symbols\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "            if len(symbols) == 1:\n",
 | 
			
		||||
    "                break\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "        # Finally, convert merged symbols back to IDs\n",
 | 
			
		||||
    "        merged_ids = [self.inverse_vocab[sym] for sym in symbols]\n",
 | 
			
		||||
    "        return merged_ids\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "    def decode(self, token_ids):\n",
 | 
			
		||||
    "        \"\"\"\n",
 | 
			
		||||
    "        Decode a list of token IDs back into a string.\n",
 | 
			
		||||
@ -738,22 +781,49 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 5,
 | 
			
		||||
   "id": "4d197cad-ed10-4a42-b01c-a763859781fb",
 | 
			
		||||
   "execution_count": 25,
 | 
			
		||||
   "id": "51872c08-e01b-40c3-a8a0-e8d6a773e3df",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "the-verdict.txt already exists in ./the-verdict.txt\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "import os\n",
 | 
			
		||||
    "import urllib.request\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "if not os.path.exists(\"../01_main-chapter-code/the-verdict.txt\"):\n",
 | 
			
		||||
    "    url = (\"https://raw.githubusercontent.com/rasbt/\"\n",
 | 
			
		||||
    "           \"LLMs-from-scratch/main/ch02/01_main-chapter-code/\"\n",
 | 
			
		||||
    "           \"the-verdict.txt\")\n",
 | 
			
		||||
    "    file_path = \"../01_main-chapter-code/the-verdict.txt\"\n",
 | 
			
		||||
    "    urllib.request.urlretrieve(url, file_path)\n",
 | 
			
		||||
    "def download_file_if_absent(url, filename, search_dirs):\n",
 | 
			
		||||
    "    for directory in search_dirs:\n",
 | 
			
		||||
    "        file_path = os.path.join(directory, filename)\n",
 | 
			
		||||
    "        if os.path.exists(file_path):\n",
 | 
			
		||||
    "            print(f\"{filename} already exists in {file_path}\")\n",
 | 
			
		||||
    "            return file_path\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "with open(\"../01_main-chapter-code/the-verdict.txt\", \"r\", encoding=\"utf-8\") as f: # added ../01_main-chapter-code/\n",
 | 
			
		||||
    "    target_path = os.path.join(search_dirs[0], filename)\n",
 | 
			
		||||
    "    try:\n",
 | 
			
		||||
    "        with urllib.request.urlopen(url) as response, open(target_path, \"wb\") as out_file:\n",
 | 
			
		||||
    "            out_file.write(response.read())\n",
 | 
			
		||||
    "        print(f\"Downloaded {filename} to {target_path}\")\n",
 | 
			
		||||
    "    except Exception as e:\n",
 | 
			
		||||
    "        print(f\"Failed to download {filename}. Error: {e}\")\n",
 | 
			
		||||
    "    return target_path\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "verdict_path = download_file_if_absent(\n",
 | 
			
		||||
    "    url=(\n",
 | 
			
		||||
    "         \"https://raw.githubusercontent.com/rasbt/\"\n",
 | 
			
		||||
    "         \"LLMs-from-scratch/main/ch02/01_main-chapter-code/\"\n",
 | 
			
		||||
    "         \"the-verdict.txt\"\n",
 | 
			
		||||
    "    ),\n",
 | 
			
		||||
    "    filename=\"the-verdict.txt\",\n",
 | 
			
		||||
    "    search_dirs=\".\"\n",
 | 
			
		||||
    ")\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "with open(verdict_path, \"r\", encoding=\"utf-8\") as f: # added ../01_main-chapter-code/\n",
 | 
			
		||||
    "    text = f.read()"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
@ -1168,24 +1238,7 @@
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "import os\n",
 | 
			
		||||
    "import urllib.request\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "def download_file_if_absent(url, filename, search_dirs):\n",
 | 
			
		||||
    "    for directory in search_dirs:\n",
 | 
			
		||||
    "        file_path = os.path.join(directory, filename)\n",
 | 
			
		||||
    "        if os.path.exists(file_path):\n",
 | 
			
		||||
    "            print(f\"{filename} already exists in {file_path}\")\n",
 | 
			
		||||
    "            return file_path\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "    target_path = os.path.join(search_dirs[0], filename)\n",
 | 
			
		||||
    "    try:\n",
 | 
			
		||||
    "        with urllib.request.urlopen(url) as response, open(target_path, \"wb\") as out_file:\n",
 | 
			
		||||
    "            out_file.write(response.read())\n",
 | 
			
		||||
    "        print(f\"Downloaded {filename} to {target_path}\")\n",
 | 
			
		||||
    "    except Exception as e:\n",
 | 
			
		||||
    "        print(f\"Failed to download {filename}. Error: {e}\")\n",
 | 
			
		||||
    "    return target_path\n",
 | 
			
		||||
    "# Download files if not already present in this directory\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "# Define the directories to search and the files to download\n",
 | 
			
		||||
    "search_directories = [\".\", \"../02_bonus_bytepair-encoder/gpt2_model/\"]\n",
 | 
			
		||||
@ -1351,7 +1404,7 @@
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.10.6"
 | 
			
		||||
   "version": "3.10.16"
 | 
			
		||||
  }
 | 
			
		||||
 },
 | 
			
		||||
 "nbformat": 4,
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										147
									
								
								ch02/05_bpe-from-scratch/tests/tests.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										147
									
								
								ch02/05_bpe-from-scratch/tests/tests.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,147 @@
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
import io
 | 
			
		||||
import nbformat
 | 
			
		||||
import types
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
import tiktoken
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def import_definitions_from_notebook(fullname, names):
 | 
			
		||||
    """Loads function definitions from a Jupyter notebook file into a module."""
 | 
			
		||||
    path = os.path.join(os.path.dirname(__file__), "..", fullname + ".ipynb")
 | 
			
		||||
    path = os.path.normpath(path)
 | 
			
		||||
 | 
			
		||||
    if not os.path.exists(path):
 | 
			
		||||
        raise FileNotFoundError(f"Notebook file not found at: {path}")
 | 
			
		||||
 | 
			
		||||
    with io.open(path, "r", encoding="utf-8") as f:
 | 
			
		||||
        nb = nbformat.read(f, as_version=4)
 | 
			
		||||
 | 
			
		||||
    mod = types.ModuleType(fullname)
 | 
			
		||||
    sys.modules[fullname] = mod
 | 
			
		||||
 | 
			
		||||
    # Execute all code cells to capture dependencies
 | 
			
		||||
    for cell in nb.cells:
 | 
			
		||||
        if cell.cell_type == "code":
 | 
			
		||||
            exec(cell.source, mod.__dict__)
 | 
			
		||||
 | 
			
		||||
    # Ensure required names are in module
 | 
			
		||||
    missing_names = [name for name in names if name not in mod.__dict__]
 | 
			
		||||
    if missing_names:
 | 
			
		||||
        raise ImportError(f"Missing definitions in notebook: {missing_names}")
 | 
			
		||||
 | 
			
		||||
    return mod
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture(scope="module")
 | 
			
		||||
def imported_module():
 | 
			
		||||
    fullname = "bpe-from-scratch"
 | 
			
		||||
    names = ["BPETokenizerSimple", "download_file_if_absent"]
 | 
			
		||||
    return import_definitions_from_notebook(fullname, names)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture(scope="module")
 | 
			
		||||
def gpt2_files(imported_module):
 | 
			
		||||
    """Fixture to handle downloading GPT-2 files."""
 | 
			
		||||
    download_file_if_absent = getattr(imported_module, "download_file_if_absent", None)
 | 
			
		||||
 | 
			
		||||
    search_directories = [".", "../02_bonus_bytepair-encoder/gpt2_model/"]
 | 
			
		||||
    files_to_download = {
 | 
			
		||||
        "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe": "vocab.bpe",
 | 
			
		||||
        "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json": "encoder.json"
 | 
			
		||||
    }
 | 
			
		||||
    paths = {filename: download_file_if_absent(url, filename, search_directories)
 | 
			
		||||
             for url, filename in files_to_download.items()}
 | 
			
		||||
 | 
			
		||||
    return paths
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_tokenizer_training(imported_module, gpt2_files):
 | 
			
		||||
    BPETokenizerSimple = getattr(imported_module, "BPETokenizerSimple", None)
 | 
			
		||||
    download_file_if_absent = getattr(imported_module, "download_file_if_absent", None)
 | 
			
		||||
 | 
			
		||||
    tokenizer = BPETokenizerSimple()
 | 
			
		||||
    verdict_path = download_file_if_absent(
 | 
			
		||||
        url=(
 | 
			
		||||
            "https://raw.githubusercontent.com/rasbt/"
 | 
			
		||||
            "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
 | 
			
		||||
            "the-verdict.txt"
 | 
			
		||||
        ),
 | 
			
		||||
        filename="the-verdict.txt",
 | 
			
		||||
        search_dirs="."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    with open(verdict_path, "r", encoding="utf-8") as f: # added ../01_main-chapter-code/
 | 
			
		||||
        text = f.read()
 | 
			
		||||
 | 
			
		||||
    tokenizer.train(text, vocab_size=1000, allowed_special={"<|endoftext|>"})
 | 
			
		||||
    assert len(tokenizer.vocab) == 1000, "Tokenizer vocabulary size mismatch."
 | 
			
		||||
    assert len(tokenizer.bpe_merges) == 742, "Tokenizer BPE merges count mismatch."
 | 
			
		||||
 | 
			
		||||
    input_text = "Jack embraced beauty through art and life."
 | 
			
		||||
    token_ids = tokenizer.encode(input_text)
 | 
			
		||||
    assert token_ids == [424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46], "Token IDs do not match expected output."
 | 
			
		||||
 | 
			
		||||
    assert tokenizer.decode(token_ids) == input_text, "Decoded text does not match the original input."
 | 
			
		||||
 | 
			
		||||
    tokenizer.save_vocab_and_merges(vocab_path="vocab.json", bpe_merges_path="bpe_merges.txt")
 | 
			
		||||
    tokenizer2 = BPETokenizerSimple()
 | 
			
		||||
    tokenizer2.load_vocab_and_merges(vocab_path="vocab.json", bpe_merges_path="bpe_merges.txt")
 | 
			
		||||
    assert tokenizer2.decode(token_ids) == input_text, "Decoded text mismatch after reloading tokenizer."
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_gpt2_tokenizer_openai_simple(imported_module, gpt2_files):
 | 
			
		||||
    BPETokenizerSimple = getattr(imported_module, "BPETokenizerSimple", None)
 | 
			
		||||
 | 
			
		||||
    tokenizer_gpt2 = BPETokenizerSimple()
 | 
			
		||||
    tokenizer_gpt2.load_vocab_and_merges_from_openai(
 | 
			
		||||
        vocab_path=gpt2_files["encoder.json"], bpe_merges_path=gpt2_files["vocab.bpe"]
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert len(tokenizer_gpt2.vocab) == 50257, "GPT-2 tokenizer vocabulary size mismatch."
 | 
			
		||||
 | 
			
		||||
    input_text = "This is some text"
 | 
			
		||||
    token_ids = tokenizer_gpt2.encode(input_text)
 | 
			
		||||
    assert token_ids == [1212, 318, 617, 2420], "Tokenized output does not match expected GPT-2 encoding."
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_gpt2_tokenizer_openai_edgecases(imported_module, gpt2_files):
 | 
			
		||||
    BPETokenizerSimple = getattr(imported_module, "BPETokenizerSimple", None)
 | 
			
		||||
 | 
			
		||||
    tokenizer_gpt2 = BPETokenizerSimple()
 | 
			
		||||
    tokenizer_gpt2.load_vocab_and_merges_from_openai(
 | 
			
		||||
        vocab_path=gpt2_files["encoder.json"], bpe_merges_path=gpt2_files["vocab.bpe"]
 | 
			
		||||
    )
 | 
			
		||||
    tik_tokenizer = tiktoken.get_encoding("gpt2")
 | 
			
		||||
 | 
			
		||||
    test_cases = [
 | 
			
		||||
        ("Hello,", [15496, 11]),
 | 
			
		||||
        ("Implementations", [3546, 26908, 602]),
 | 
			
		||||
        ("asdf asdfasdf a!!, @aba 9asdf90asdfk", [292, 7568, 355, 7568, 292, 7568, 257, 3228, 11, 2488, 15498, 860, 292, 7568, 3829, 292, 7568, 74]),
 | 
			
		||||
        ("Hello, world. Is this-- a test?", [15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30])
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    errors = []
 | 
			
		||||
 | 
			
		||||
    for input_text, expected_tokens in test_cases:
 | 
			
		||||
        tik_tokens = tik_tokenizer.encode(input_text)
 | 
			
		||||
        gpt2_tokens = tokenizer_gpt2.encode(input_text)
 | 
			
		||||
 | 
			
		||||
        print(f"Text: {input_text}")
 | 
			
		||||
        print(f"Expected Tokens: {expected_tokens}")
 | 
			
		||||
        print(f"tiktoken Output: {tik_tokens}")
 | 
			
		||||
        print(f"BPETokenizerSimple Output: {gpt2_tokens}")
 | 
			
		||||
        print("-" * 40)
 | 
			
		||||
 | 
			
		||||
        if tik_tokens != expected_tokens:
 | 
			
		||||
            errors.append(f"Tiktokenized output does not match expected GPT-2 encoding for '{input_text}'.\n"
 | 
			
		||||
                          f"Expected: {expected_tokens}, Got: {tik_tokens}")
 | 
			
		||||
 | 
			
		||||
        if gpt2_tokens != expected_tokens:
 | 
			
		||||
            errors.append(f"Tokenized output does not match expected GPT-2 encoding for '{input_text}'.\n"
 | 
			
		||||
                          f"Expected: {expected_tokens}, Got: {gpt2_tokens}")
 | 
			
		||||
 | 
			
		||||
    if errors:
 | 
			
		||||
        pytest.fail("\n".join(errors))
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user