mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-28 18:40:01 +00:00
topk comment
This commit is contained in:
parent
f4e45a3f40
commit
092b5b5429
@ -1844,13 +1844,32 @@
|
||||
"source": [
|
||||
"new_logits = torch.where(\n",
|
||||
" condition=next_token_logits < top_logits[-1],\n",
|
||||
" input=torch.tensor(float('-inf')), \n",
|
||||
" input=torch.tensor(float(\"-inf\")), \n",
|
||||
" other=next_token_logits\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(new_logits)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dfa6fa49-6e99-459d-a517-d7d0f51c4f00",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"> NOTE: \n",
|
||||
">\n",
|
||||
"> An alternative, slightly more efficient implementation of the previous code cell is the following:\n",
|
||||
">\n",
|
||||
"> ```python\n",
|
||||
"> new_logits = torch.full_like( # create tensor containing -inf values\n",
|
||||
"> next_token_logits, -torch.inf\n",
|
||||
">) \n",
|
||||
"> new_logits[top_pos] = next_token_logits[top_pos] # copy top k values into the -inf tensor\n",
|
||||
"> ```\n",
|
||||
"> <br>\n",
|
||||
"> For more details, see https://github.com/rasbt/LLMs-from-scratch/discussions/326\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
@ -1908,7 +1927,7 @@
|
||||
" # Keep only top_k values\n",
|
||||
" top_logits, _ = torch.topk(logits, top_k)\n",
|
||||
" min_val = top_logits[:, -1]\n",
|
||||
" logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
|
||||
" logits = torch.where(logits < min_val, torch.tensor(float(\"-inf\")).to(logits.device), logits)\n",
|
||||
"\n",
|
||||
" # New: Apply temperature scaling\n",
|
||||
" if temperature > 0.0:\n",
|
||||
@ -2485,7 +2504,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
Loading…
x
Reference in New Issue
Block a user