From 092b5b54290651aba93e7fb4bbae4ab3eadd727a Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Tue, 20 Aug 2024 20:44:15 -0500 Subject: [PATCH] topk comment --- ch05/01_main-chapter-code/ch05.ipynb | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/ch05/01_main-chapter-code/ch05.ipynb b/ch05/01_main-chapter-code/ch05.ipynb index bb624e6..18c8bd9 100644 --- a/ch05/01_main-chapter-code/ch05.ipynb +++ b/ch05/01_main-chapter-code/ch05.ipynb @@ -1844,13 +1844,32 @@ "source": [ "new_logits = torch.where(\n", " condition=next_token_logits < top_logits[-1],\n", - " input=torch.tensor(float('-inf')), \n", + " input=torch.tensor(float(\"-inf\")), \n", " other=next_token_logits\n", ")\n", "\n", "print(new_logits)" ] }, + { + "cell_type": "markdown", + "id": "dfa6fa49-6e99-459d-a517-d7d0f51c4f00", + "metadata": {}, + "source": [ + "> NOTE: \n", + ">\n", + "> An alternative, slightly more efficient implementation of the previous code cell is the following:\n", + ">\n", + "> ```python\n", + "> new_logits = torch.full_like( # create tensor containing -inf values\n", + "> next_token_logits, -torch.inf\n", + ">) \n", + "> new_logits[top_pos] = next_token_logits[top_pos] # copy top k values into the -inf tensor\n", + "> ```\n", + ">
\n", + "> For more details, see https://github.com/rasbt/LLMs-from-scratch/discussions/326\n" + ] + }, { "cell_type": "code", "execution_count": 39, @@ -1908,7 +1927,7 @@ " # Keep only top_k values\n", " top_logits, _ = torch.topk(logits, top_k)\n", " min_val = top_logits[:, -1]\n", - " logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n", + " logits = torch.where(logits < min_val, torch.tensor(float(\"-inf\")).to(logits.device), logits)\n", "\n", " # New: Apply temperature scaling\n", " if temperature > 0.0:\n", @@ -2485,7 +2504,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.12.2" } }, "nbformat": 4,