diff --git a/ch05/01_main-chapter-code/ch05.ipynb b/ch05/01_main-chapter-code/ch05.ipynb
index cd90777..7a0a3f2 100644
--- a/ch05/01_main-chapter-code/ch05.ipynb
+++ b/ch05/01_main-chapter-code/ch05.ipynb
@@ -1852,7 +1852,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):\n",
+    "def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):\n",
     "\n",
     "    # For-loop is the same as before: Get logits, and only focus on last time step\n",
     "    for _ in range(max_new_tokens):\n",
@@ -1882,6 +1882,9 @@
     "        else:\n",
     "            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)\n",
     "\n",
+    "        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n",
+    "            break\n",
+    "\n",
     "        # Same as before: append sampled index to the running sequence\n",
     "        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)\n",
     "\n",
@@ -2372,7 +2375,7 @@
     "\n",
     "token_ids = generate(\n",
     "    model=gpt,\n",
-    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
+    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(device),\n",
     "    max_new_tokens=25,\n",
     "    context_size=NEW_CONFIG[\"context_length\"],\n",
     "    top_k=50,\n",
@@ -2439,7 +2442,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.12"
+   "version": "3.11.4"
   }
  },
  "nbformat": 4,
diff --git a/ch05/01_main-chapter-code/gpt_generate.py b/ch05/01_main-chapter-code/gpt_generate.py
index d7abaf0..d302719 100644
--- a/ch05/01_main-chapter-code/gpt_generate.py
+++ b/ch05/01_main-chapter-code/gpt_generate.py
@@ -215,7 +215,7 @@ def load_weights_into_gpt(gpt, params):
     gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])


-def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
+def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):

     # For-loop is the same as before: Get logits, and only focus on last time step
     for _ in range(max_new_tokens):
@@ -245,6 +245,9 @@ def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
         else:
             idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

+        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
+            break
+
         # Same as before: append sampled index to the running sequence
         idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
diff --git a/ch05/02_alternative_weight_loading/previous_chapters.py b/ch05/02_alternative_weight_loading/previous_chapters.py
index 19cc2c2..0c792ba 100644
--- a/ch05/02_alternative_weight_loading/previous_chapters.py
+++ b/ch05/02_alternative_weight_loading/previous_chapters.py
@@ -254,7 +254,7 @@ def token_ids_to_text(token_ids, tokenizer):
     return tokenizer.decode(flat.tolist())


-def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
+def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):

     # For-loop is the same as before: Get logits, and only focus on last time step
     for _ in range(max_new_tokens):
@@ -284,6 +284,9 @@ def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
         else:
             idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

+        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
+            break
+
         # Same as before: append sampled index to the running sequence
         idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
diff --git a/ch06/02_bonus_additional-experiments/previous_chapters.py b/ch06/02_bonus_additional-experiments/previous_chapters.py
index 8d6f827..7644d68 100644
--- a/ch06/02_bonus_additional-experiments/previous_chapters.py
+++ b/ch06/02_bonus_additional-experiments/previous_chapters.py
@@ -310,7 +310,7 @@ def load_weights_into_gpt(gpt, params):
     gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])


-def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
+def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):

     # For-loop is the same as before: Get logits, and only focus on last time step
     for _ in range(max_new_tokens):
         idx_cond = idx[:, -context_size:]
@@ -339,6 +339,9 @@ def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
         else:
             idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

+        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
+            break
+
         # Same as before: append sampled index to the running sequence
         idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
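
Usage sketch (not part of the patch): the snippet below shows how the new eos_id argument would be exercised, mirroring the generate() call already in ch05.ipynb. It assumes the notebook's gpt model, NEW_CONFIG, device, text_to_token_ids, and token_ids_to_text are defined as in the chapter; 50256 is tiktoken's GPT-2 id for the <|endoftext|> token.

    import tiktoken
    import torch

    tokenizer = tiktoken.get_encoding("gpt2")
    torch.manual_seed(123)

    token_ids = generate(
        model=gpt,
        idx=text_to_token_ids("Every effort moves you", tokenizer).to(device),  # prompt must be on the model's device
        max_new_tokens=25,
        context_size=NEW_CONFIG["context_length"],
        top_k=50,
        temperature=1.0,
        eos_id=50256,  # stop early if <|endoftext|> is sampled
    )
    print(token_ids_to_text(token_ids, tokenizer))

Note on the design: the early-stop check compares the (batch_size, 1) tensor idx_next with a plain int, so it implicitly assumes a batch size of 1; for larger batches the truth value of the comparison would be ambiguous.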