mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-10-06 13:36:18 +00:00
address suggestions to improve clarity
This commit is contained in:
parent
42eda8b70f
commit
58d5bd9e39
File diff suppressed because one or more lines are too long
@ -16,7 +16,7 @@ def download_and_load_gpt2(model_size, models_dir):
|
||||
model_dir = os.path.join(models_dir, model_size)
|
||||
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
|
||||
filenames = [
|
||||
"checkpoint", "encoder.json", "settings.json",
|
||||
"checkpoint", "encoder.json", "hparams.json",
|
||||
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
|
||||
"model.ckpt.meta", "vocab.bpe"
|
||||
]
|
||||
@ -30,7 +30,7 @@ def download_and_load_gpt2(model_size, models_dir):
|
||||
|
||||
# Load settings and params
|
||||
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
|
||||
settings = json.load(open(os.path.join(model_dir, "settings.json")))
|
||||
settings = json.load(open(os.path.join(model_dir, "hparams.json")))
|
||||
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
|
||||
|
||||
return settings, params
|
||||
|
@ -37,7 +37,7 @@ def download_and_load_gpt2(model_size, models_dir):
|
||||
model_dir = os.path.join(models_dir, model_size)
|
||||
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
|
||||
filenames = [
|
||||
"checkpoint", "encoder.json", "settings.json",
|
||||
"checkpoint", "encoder.json", "hparams.json",
|
||||
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
|
||||
"model.ckpt.meta", "vocab.bpe"
|
||||
]
|
||||
@ -51,7 +51,7 @@ def download_and_load_gpt2(model_size, models_dir):
|
||||
|
||||
# Load settings and params
|
||||
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
|
||||
settings = json.load(open(os.path.join(model_dir, "settings.json")))
|
||||
settings = json.load(open(os.path.join(model_dir, "hparams.json")))
|
||||
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
|
||||
|
||||
return settings, params
|
||||
|
Loading…
x
Reference in New Issue
Block a user