Use inference_device

This commit is contained in:
rasbt 2025-10-09 10:59:17 -05:00
parent fecfdd16ff
commit b8e12e1dd1
No known key found for this signature in database

View File

@@ -1519,14 +1519,19 @@
}
],
"source": [
"model.to(\"cpu\")\n",
"# NEW: use CPU here as inference is cheap with\n",
"# this model and to ensure readers get the same results in the\n",
"# remaining sections of this book\n",
"inference_device = torch.device(\"cpu\")\n",
"\n",
"model.to(inference_device)\n",
"model.eval()\n",
"\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"\n",
"token_ids = generate_text_simple(\n",
" model=model,\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(inference_device),\n",
" max_new_tokens=25,\n",
" context_size=GPT_CONFIG_124M[\"context_length\"]\n",
")\n",
@@ -2030,7 +2035,7 @@
"\n",
"token_ids = generate(\n",
" model=model,\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(inference_device),\n",
" max_new_tokens=15,\n",
" context_size=GPT_CONFIG_124M[\"context_length\"],\n",
" top_k=25,\n",