mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 02:50:15 +00:00
update generate to match output in main chapter
This commit is contained in:
parent
549e015548
commit
0026e6206b
@ -266,13 +266,14 @@ def main(gpt_config, input_prompt, model_size):
|
|||||||
gpt.eval()
|
gpt.eval()
|
||||||
|
|
||||||
tokenizer = tiktoken.get_encoding("gpt2")
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
torch.manual_seed(123)
|
||||||
|
|
||||||
token_ids = generate(
|
token_ids = generate(
|
||||||
model=gpt,
|
model=gpt,
|
||||||
idx=text_to_token_ids(input_prompt, tokenizer),
|
idx=text_to_token_ids(input_prompt, tokenizer),
|
||||||
max_new_tokens=30,
|
max_new_tokens=25,
|
||||||
context_size=gpt_config["context_length"],
|
context_size=gpt_config["context_length"],
|
||||||
top_k=1,
|
top_k=50,
|
||||||
temperature=1.0
|
temperature=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -284,7 +285,7 @@ if __name__ == "__main__":
|
|||||||
torch.manual_seed(123)
|
torch.manual_seed(123)
|
||||||
|
|
||||||
CHOOSE_MODEL = "gpt2-small (124M)"
|
CHOOSE_MODEL = "gpt2-small (124M)"
|
||||||
INPUT_PROMPT = "Every effort moves"
|
INPUT_PROMPT = "Every effort moves you"
|
||||||
|
|
||||||
BASE_CONFIG = {
|
BASE_CONFIG = {
|
||||||
"vocab_size": 50257, # Vocabulary size
|
"vocab_size": 50257, # Vocabulary size
|
||||||
|
Loading…
x
Reference in New Issue
Block a user