diff --git a/ch03/01_main-chapter-code/ch03.ipynb b/ch03/01_main-chapter-code/ch03.ipynb index 06931e7..9fc9ffd 100644 --- a/ch03/01_main-chapter-code/ch03.ipynb +++ b/ch03/01_main-chapter-code/ch03.ipynb @@ -1628,6 +1628,10 @@ "\n", " def forward(self, x):\n", " b, num_tokens, d_in = x.shape # New batch dimension b\n", + " # For inputs where `num_tokens` exceeds `context_length`, this will result in errors\n", + " # in the mask creation further below.\n", + " # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs \n", + " # do not exceed `context_length` before reaching this forward method. \n", " keys = self.W_key(x)\n", " queries = self.W_query(x)\n", " values = self.W_value(x)\n", @@ -1837,6 +1841,10 @@ "\n", " def forward(self, x):\n", " b, num_tokens, d_in = x.shape\n", + " # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`, \n", + " # this will result in errors in the mask creation further below. \n", + " # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs \n", + " # do not exceed `context_length` before reaching this forward method.\n", "\n", " keys = self.W_key(x) # Shape: (b, num_tokens, d_out)\n", " queries = self.W_query(x)\n", @@ -2029,7 +2037,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.16" } }, "nbformat": 4,