address suggestions to improve clarity

This commit is contained in:
rasbt 2024-04-07 08:41:09 -05:00
parent 42eda8b70f
commit 58d5bd9e39
3 changed files with 93 additions and 56 deletions

File diff suppressed because one or more lines are too long

View File

@@ -16,7 +16,7 @@ def download_and_load_gpt2(model_size, models_dir):
model_dir = os.path.join(models_dir, model_size)
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
filenames = [
-"checkpoint", "encoder.json", "settings.json",
+"checkpoint", "encoder.json", "hparams.json",
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
"model.ckpt.meta", "vocab.bpe"
]
@@ -30,7 +30,7 @@ def download_and_load_gpt2(model_size, models_dir):
# Load settings and params
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
-settings = json.load(open(os.path.join(model_dir, "settings.json")))
+settings = json.load(open(os.path.join(model_dir, "hparams.json")))
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
return settings, params

View File

@@ -37,7 +37,7 @@ def download_and_load_gpt2(model_size, models_dir):
model_dir = os.path.join(models_dir, model_size)
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
filenames = [
-"checkpoint", "encoder.json", "settings.json",
+"checkpoint", "encoder.json", "hparams.json",
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
"model.ckpt.meta", "vocab.bpe"
]
@@ -51,7 +51,7 @@ def download_and_load_gpt2(model_size, models_dir):
# Load settings and params
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
-settings = json.load(open(os.path.join(model_dir, "settings.json")))
+settings = json.load(open(os.path.join(model_dir, "hparams.json")))
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
return settings, params