mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-30 09:00:09 +00:00
Use inference_device
This commit is contained in:
parent
fecfdd16ff
commit
b8e12e1dd1
@ -1519,14 +1519,19 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.to(\"cpu\")\n",
|
||||
"# NEW: use CPU here as inference is cheap with \n",
|
||||
"# this model and to ensure readers get same results in the\n",
|
||||
"# remaining sections of this book\n",
|
||||
"inference_device = torch.device(\"cpu\")\n",
|
||||
"\n",
|
||||
"model.to(inference_device)\n",
|
||||
"model.eval()\n",
|
||||
"\n",
|
||||
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
||||
"\n",
|
||||
"token_ids = generate_text_simple(\n",
|
||||
" model=model,\n",
|
||||
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
|
||||
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(inference_device),\n",
|
||||
" max_new_tokens=25,\n",
|
||||
" context_size=GPT_CONFIG_124M[\"context_length\"]\n",
|
||||
")\n",
|
||||
@ -2030,7 +2035,7 @@
|
||||
"\n",
|
||||
"token_ids = generate(\n",
|
||||
" model=model,\n",
|
||||
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
|
||||
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(inference_device),\n",
|
||||
" max_new_tokens=15,\n",
|
||||
" context_size=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" top_k=25,\n",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user