Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-09-26 16:52:04 +00:00)
Merge pull request #164 from rasbt/eos_id-token
Add eos_id option for ch07
Commit e8212c3f7c
@@ -1852,7 +1852,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):\n",
+    "def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):\n",
     "\n",
     "    # For-loop is the same as before: Get logits, and only focus on last time step\n",
     "    for _ in range(max_new_tokens):\n",
@@ -1882,6 +1882,9 @@
     "        else:\n",
     "            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)\n",
     "\n",
+    "        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n",
+    "            break\n",
+    "\n",
     "        # Same as before: append sampled index to the running sequence\n",
     "        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)\n",
     "\n",
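For readers skimming the diff, here is a minimal sketch of how the chapter's generate function reads once the new eos_id parameter is in place. Only the signature, the eos_id check, and the surrounding lines are taken from the hunks above; the top-k filtering and temperature-scaling branches in between are reconstructed from the earlier chapter code and should be treated as assumptions rather than the exact committed version.

import torch

def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):

    # For-loop is the same as before: Get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # Assumed from the earlier chapter version: optional top-k filtering
        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)

        # Assumed from the earlier chapter version: temperature scaling and sampling
        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx

Note that idx_next == eos_id compares a single-element tensor against an integer, so the early stop is aimed at the batch-size-1 generation used in the chapter; eos_id=None (the default) keeps the previous behavior.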
@@ -2372,7 +2375,7 @@
     "\n",
     "token_ids = generate(\n",
     "    model=gpt,\n",
-    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
+    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(device),\n",
     "    max_new_tokens=25,\n",
     "    context_size=NEW_CONFIG[\"context_length\"],\n",
     "    top_k=50,\n",
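The change in this hunk moves the prompt tensor onto the model's device before generation. A hedged sketch of how the same call could also make use of the new eos_id parameter follows; it assumes the notebook's earlier definitions of gpt, text_to_token_ids, device, and NEW_CONFIG, and the temperature value is an assumption not shown in this hunk.

import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")

# In the GPT-2 BPE vocabulary, "<|endoftext|>" has token id 50256.
eos_id = tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})[0]

token_ids = generate(
    model=gpt,  # model defined earlier in the notebook
    idx=text_to_token_ids("Every effort moves you", tokenizer).to(device),
    max_new_tokens=25,
    context_size=NEW_CONFIG["context_length"],
    top_k=50,
    temperature=1.5,  # assumed value; not part of this hunk
    eos_id=eos_id,    # new: stop early if <|endoftext|> is sampled
)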
@@ -2439,7 +2442,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.12"
+   "version": "3.11.4"
   }
  },
  "nbformat": 4,
@@ -215,7 +215,7 @@ def load_weights_into_gpt(gpt, params):
     gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
 
 
-def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
+def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):
 
     # For-loop is the same as before: Get logits, and only focus on last time step
     for _ in range(max_new_tokens):
@@ -245,6 +245,9 @@ def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
         else:
             idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
 
+        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
+            break
+
         # Same as before: append sampled index to the running sequence
         idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
 
@@ -254,7 +254,7 @@ def token_ids_to_text(token_ids, tokenizer):
     return tokenizer.decode(flat.tolist())
 
 
-def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
+def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):
 
     # For-loop is the same as before: Get logits, and only focus on last time step
     for _ in range(max_new_tokens):
@@ -284,6 +284,9 @@ def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
         else:
             idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
 
+        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
+            break
+
         # Same as before: append sampled index to the running sequence
         idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
 
@@ -310,7 +310,7 @@ def load_weights_into_gpt(gpt, params):
     gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
 
 
-def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
+def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):
     # For-loop is the same as before: Get logits, and only focus on last time step
     for _ in range(max_new_tokens):
         idx_cond = idx[:, -context_size:]
@@ -339,6 +339,9 @@ def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
         else:
             idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
 
+        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
+            break
+
         # Same as before: append sampled index to the running sequence
         idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
 
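As a quick, self-contained sanity check (not part of the commit), the early stop can be exercised with a stub model that always assigns the highest logit to one fixed token id, using the generate sketch shown earlier:

import torch

class AlwaysSameToken(torch.nn.Module):
    """Stub language model whose highest logit is always on token id `fixed_id`."""
    def __init__(self, vocab_size=10, fixed_id=3):
        super().__init__()
        self.vocab_size = vocab_size
        self.fixed_id = fixed_id

    def forward(self, idx):
        batch_size, num_tokens = idx.shape
        logits = torch.zeros(batch_size, num_tokens, self.vocab_size)
        logits[:, :, self.fixed_id] = 1.0  # fixed_id is always the argmax
        return logits

model = AlwaysSameToken()
prompt = torch.tensor([[1, 2]])

# Without eos_id, all five new tokens are appended as before.
print(generate(model, prompt, max_new_tokens=5, context_size=4, temperature=0.0))
# -> tensor([[1, 2, 3, 3, 3, 3, 3]])

# With eos_id=3, generation stops before the end-of-sequence token is appended.
print(generate(model, prompt, max_new_tokens=5, context_size=4, temperature=0.0, eos_id=3))
# -> tensor([[1, 2]])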