mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-09-26 00:24:36 +00:00
Improve BPE vocabulary saving and pair frequency handling (#539)
This commit is contained in:
parent
58aabe7dd8
commit
af4b73ca7b
@ -629,7 +629,7 @@
|
|||||||
" \"\"\"\n",
|
" \"\"\"\n",
|
||||||
" # Save vocabulary\n",
|
" # Save vocabulary\n",
|
||||||
" with open(vocab_path, \"w\", encoding=\"utf-8\") as file:\n",
|
" with open(vocab_path, \"w\", encoding=\"utf-8\") as file:\n",
|
||||||
" json.dump({k: v for k, v in self.vocab.items()}, file, ensure_ascii=False, indent=2)\n",
|
" json.dump(self.vocab, file, ensure_ascii=False, indent=2)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Save BPE merges as a list of dictionaries\n",
|
" # Save BPE merges as a list of dictionaries\n",
|
||||||
" with open(bpe_merges_path, \"w\", encoding=\"utf-8\") as file:\n",
|
" with open(bpe_merges_path, \"w\", encoding=\"utf-8\") as file:\n",
|
||||||
@ -667,6 +667,9 @@
|
|||||||
" def find_freq_pair(token_ids, mode=\"most\"):\n",
|
" def find_freq_pair(token_ids, mode=\"most\"):\n",
|
||||||
" pairs = Counter(zip(token_ids, token_ids[1:]))\n",
|
" pairs = Counter(zip(token_ids, token_ids[1:]))\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" if not pairs:\n",
|
||||||
|
" return None\n",
|
||||||
|
"\n",
|
||||||
" if mode == \"most\":\n",
|
" if mode == \"most\":\n",
|
||||||
" return max(pairs.items(), key=lambda x: x[1])[0]\n",
|
" return max(pairs.items(), key=lambda x: x[1])[0]\n",
|
||||||
" elif mode == \"least\":\n",
|
" elif mode == \"least\":\n",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user