topk comment

This commit is contained in:
Sebastian Raschka 2024-08-20 20:44:15 -05:00
parent f4e45a3f40
commit 092b5b5429

View File

@ -1844,13 +1844,32 @@
"source": [
"new_logits = torch.where(\n",
" condition=next_token_logits < top_logits[-1],\n",
" input=torch.tensor(float('-inf')), \n",
" input=torch.tensor(float(\"-inf\")), \n",
" other=next_token_logits\n",
")\n",
"\n",
"print(new_logits)"
]
},
{
"cell_type": "markdown",
"id": "dfa6fa49-6e99-459d-a517-d7d0f51c4f00",
"metadata": {},
"source": [
"> NOTE: \n",
">\n",
"> An alternative, slightly more efficient implementation of the previous code cell is the following:\n",
">\n",
"> ```python\n",
"> new_logits = torch.full_like( # create tensor containing -inf values\n",
"> next_token_logits, -torch.inf\n",
">) \n",
"> new_logits[top_pos] = next_token_logits[top_pos] # copy top k values into the -inf tensor\n",
"> ```\n",
"> <br>\n",
"> For more details, see https://github.com/rasbt/LLMs-from-scratch/discussions/326\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
@ -1908,7 +1927,7 @@
" # Keep only top_k values\n",
" top_logits, _ = torch.topk(logits, top_k)\n",
" min_val = top_logits[:, -1]\n",
" logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
" logits = torch.where(logits < min_val, torch.tensor(float(\"-inf\")).to(logits.device), logits)\n",
"\n",
" # New: Apply temperature scaling\n",
" if temperature > 0.0:\n",
@ -2485,7 +2504,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.12.2"
}
},
"nbformat": 4,