Merge pull request #164 from rasbt/eos_id-token

Add eos_id option for ch07
Sebastian Raschka 2024-05-18 16:10:25 -04:00 committed by GitHub
commit e8212c3f7c
4 changed files with 18 additions and 6 deletions
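
In short: the chapter's generate function gains an optional eos_id parameter, so decoding can stop as soon as the model emits the end-of-sequence token instead of always producing max_new_tokens tokens. A minimal usage sketch under the assumption that gpt, text_to_token_ids, NEW_CONFIG, tokenizer, and device are the objects defined in the chapter code; 50256 is the <|endoftext|> ID in the GPT-2 tiktoken vocabulary:

    import tiktoken

    tokenizer = tiktoken.get_encoding("gpt2")

    # <|endoftext|> has ID 50256 in the GPT-2 vocabulary
    eos_id = tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})[0]

    token_ids = generate(
        model=gpt,
        idx=text_to_token_ids("Every effort moves you", tokenizer).to(device),
        max_new_tokens=25,
        context_size=NEW_CONFIG["context_length"],
        top_k=50,
        temperature=1.0,
        eos_id=eos_id,  # new: stop early once <|endoftext|> is sampled
    )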

View File

@@ -1852,7 +1852,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):\n",
+"def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):\n",
 "\n",
 "    # For-loop is the same as before: Get logits, and only focus on last time step\n",
 "    for _ in range(max_new_tokens):\n",
@@ -1882,6 +1882,9 @@
 "        else:\n",
 "            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)\n",
 "\n",
+"        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n",
+"            break\n",
+"\n",
 "        # Same as before: append sampled index to the running sequence\n",
 "        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)\n",
 "\n",
@@ -2372,7 +2375,7 @@
 "\n",
 "token_ids = generate(\n",
 "    model=gpt,\n",
-"    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
+"    idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(device),\n",
 "    max_new_tokens=25,\n",
 "    context_size=NEW_CONFIG[\"context_length\"],\n",
 "    top_k=50,\n",
@@ -2439,7 +2442,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.12"
+"version": "3.11.4"
 }
 },
 "nbformat": 4,

View File

@@ -215,7 +215,7 @@ def load_weights_into_gpt(gpt, params):
     gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
 
 
-def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
+def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):
 
     # For-loop is the same as before: Get logits, and only focus on last time step
     for _ in range(max_new_tokens):
@@ -245,6 +245,9 @@ def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
         else:
             idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
 
+        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
+            break
+
         # Same as before: append sampled index to the running sequence
         idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
 

View File

@@ -254,7 +254,7 @@ def token_ids_to_text(token_ids, tokenizer):
     return tokenizer.decode(flat.tolist())
 
 
-def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
+def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):
 
     # For-loop is the same as before: Get logits, and only focus on last time step
     for _ in range(max_new_tokens):
@@ -284,6 +284,9 @@ def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
         else:
             idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
 
+        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
+            break
+
         # Same as before: append sampled index to the running sequence
         idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
 

View File

@@ -310,7 +310,7 @@ def load_weights_into_gpt(gpt, params):
     gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
 
-def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
+def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):
 
     # For-loop is the same as before: Get logits, and only focus on last time step
     for _ in range(max_new_tokens):
         idx_cond = idx[:, -context_size:]
@@ -339,6 +339,9 @@ def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
         else:
             idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
 
+        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
+            break
+
         # Same as before: append sampled index to the running sequence
         idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
 