mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-19 06:02:38 +00:00
Add missing device transfer in gpt_generate.py (#436)
This commit is contained in:
parent
27a6a7e64a
commit
f61c008c5d
@ -270,7 +270,7 @@ def main(gpt_config, input_prompt, model_size):
|
||||
|
||||
token_ids = generate(
|
||||
model=gpt,
|
||||
idx=text_to_token_ids(input_prompt, tokenizer),
|
||||
idx=text_to_token_ids(input_prompt, tokenizer).to(device),
|
||||
max_new_tokens=25,
|
||||
context_size=gpt_config["context_length"],
|
||||
top_k=50,
|
||||
|
Loading…
x
Reference in New Issue
Block a user