mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-19 14:12:40 +00:00
Add missing device transfer in gpt_generate.py (#436)
This commit is contained in:
parent
27a6a7e64a
commit
f61c008c5d
@ -270,7 +270,7 @@ def main(gpt_config, input_prompt, model_size):
|
|||||||
|
|
||||||
token_ids = generate(
|
token_ids = generate(
|
||||||
model=gpt,
|
model=gpt,
|
||||||
idx=text_to_token_ids(input_prompt, tokenizer),
|
idx=text_to_token_ids(input_prompt, tokenizer).to(device),
|
||||||
max_new_tokens=25,
|
max_new_tokens=25,
|
||||||
context_size=gpt_config["context_length"],
|
context_size=gpt_config["context_length"],
|
||||||
top_k=50,
|
top_k=50,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user