update generate to match output in main chapter

This commit is contained in:
rasbt 2024-06-22 12:01:51 -05:00
parent 549e015548
commit 0026e6206b

View File

@ -266,13 +266,14 @@ def main(gpt_config, input_prompt, model_size):
gpt.eval()
tokenizer = tiktoken.get_encoding("gpt2")
torch.manual_seed(123)
token_ids = generate(
model=gpt,
idx=text_to_token_ids(input_prompt, tokenizer),
max_new_tokens=30,
max_new_tokens=25,
context_size=gpt_config["context_length"],
top_k=1,
top_k=50,
temperature=1.0
)
@ -284,7 +285,7 @@ if __name__ == "__main__":
torch.manual_seed(123)
CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves"
INPUT_PROMPT = "Every effort moves you"
BASE_CONFIG = {
"vocab_size": 50257, # Vocabulary size