update generate to match output in main chapter

This commit is contained in:
rasbt 2024-06-22 12:01:51 -05:00
parent 549e015548
commit 0026e6206b

View File

@ -266,13 +266,14 @@ def main(gpt_config, input_prompt, model_size):
gpt.eval() gpt.eval()
tokenizer = tiktoken.get_encoding("gpt2") tokenizer = tiktoken.get_encoding("gpt2")
torch.manual_seed(123)
token_ids = generate( token_ids = generate(
model=gpt, model=gpt,
idx=text_to_token_ids(input_prompt, tokenizer), idx=text_to_token_ids(input_prompt, tokenizer),
max_new_tokens=30, max_new_tokens=25,
context_size=gpt_config["context_length"], context_size=gpt_config["context_length"],
top_k=1, top_k=50,
temperature=1.0 temperature=1.0
) )
@ -284,7 +285,7 @@ if __name__ == "__main__":
torch.manual_seed(123) torch.manual_seed(123)
CHOOSE_MODEL = "gpt2-small (124M)" CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves" INPUT_PROMPT = "Every effort moves you"
BASE_CONFIG = { BASE_CONFIG = {
"vocab_size": 50257, # Vocabulary size "vocab_size": 50257, # Vocabulary size