mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 19:10:19 +00:00
topk comment
This commit is contained in:
parent
f4e45a3f40
commit
092b5b5429
@ -1844,13 +1844,32 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"new_logits = torch.where(\n",
|
"new_logits = torch.where(\n",
|
||||||
" condition=next_token_logits < top_logits[-1],\n",
|
" condition=next_token_logits < top_logits[-1],\n",
|
||||||
" input=torch.tensor(float('-inf')), \n",
|
" input=torch.tensor(float(\"-inf\")), \n",
|
||||||
" other=next_token_logits\n",
|
" other=next_token_logits\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(new_logits)"
|
"print(new_logits)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "dfa6fa49-6e99-459d-a517-d7d0f51c4f00",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"> NOTE: \n",
|
||||||
|
">\n",
|
||||||
|
"> An alternative, slightly more efficient implementation of the previous code cell is the following:\n",
|
||||||
|
">\n",
|
||||||
|
"> ```python\n",
|
||||||
|
"> new_logits = torch.full_like( # create tensor containing -inf values\n",
|
||||||
|
"> next_token_logits, -torch.inf\n",
|
||||||
|
">) \n",
|
||||||
|
"> new_logits[top_pos] = next_token_logits[top_pos] # copy top k values into the -inf tensor\n",
|
||||||
|
"> ```\n",
|
||||||
|
"> <br>\n",
|
||||||
|
"> For more details, see https://github.com/rasbt/LLMs-from-scratch/discussions/326\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 39,
|
"execution_count": 39,
|
||||||
@ -1908,7 +1927,7 @@
|
|||||||
" # Keep only top_k values\n",
|
" # Keep only top_k values\n",
|
||||||
" top_logits, _ = torch.topk(logits, top_k)\n",
|
" top_logits, _ = torch.topk(logits, top_k)\n",
|
||||||
" min_val = top_logits[:, -1]\n",
|
" min_val = top_logits[:, -1]\n",
|
||||||
" logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
|
" logits = torch.where(logits < min_val, torch.tensor(float(\"-inf\")).to(logits.device), logits)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # New: Apply temperature scaling\n",
|
" # New: Apply temperature scaling\n",
|
||||||
" if temperature > 0.0:\n",
|
" if temperature > 0.0:\n",
|
||||||
@ -2485,7 +2504,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.4"
|
"version": "3.12.2"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user