mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-10-06 13:36:18 +00:00
address suggestions to improve clarity
This commit is contained in:
parent
42eda8b70f
commit
58d5bd9e39
File diff suppressed because one or more lines are too long
@@ -16,7 +16,7 @@ def download_and_load_gpt2(model_size, models_dir):
|
|||||||
model_dir = os.path.join(models_dir, model_size)
|
model_dir = os.path.join(models_dir, model_size)
|
||||||
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
|
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
|
||||||
filenames = [
|
filenames = [
|
||||||
"checkpoint", "encoder.json", "settings.json",
|
"checkpoint", "encoder.json", "hparams.json",
|
||||||
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
|
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
|
||||||
"model.ckpt.meta", "vocab.bpe"
|
"model.ckpt.meta", "vocab.bpe"
|
||||||
]
|
]
|
||||||
@@ -30,7 +30,7 @@ def download_and_load_gpt2(model_size, models_dir):
|
|||||||
|
|
||||||
# Load settings and params
|
# Load settings and params
|
||||||
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
|
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
|
||||||
settings = json.load(open(os.path.join(model_dir, "settings.json")))
|
settings = json.load(open(os.path.join(model_dir, "hparams.json")))
|
||||||
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
|
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
|
||||||
|
|
||||||
return settings, params
|
return settings, params
|
||||||
|
@@ -37,7 +37,7 @@ def download_and_load_gpt2(model_size, models_dir):
|
|||||||
model_dir = os.path.join(models_dir, model_size)
|
model_dir = os.path.join(models_dir, model_size)
|
||||||
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
|
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
|
||||||
filenames = [
|
filenames = [
|
||||||
"checkpoint", "encoder.json", "settings.json",
|
"checkpoint", "encoder.json", "hparams.json",
|
||||||
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
|
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
|
||||||
"model.ckpt.meta", "vocab.bpe"
|
"model.ckpt.meta", "vocab.bpe"
|
||||||
]
|
]
|
||||||
@@ -51,7 +51,7 @@ def download_and_load_gpt2(model_size, models_dir):
|
|||||||
|
|
||||||
# Load settings and params
|
# Load settings and params
|
||||||
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
|
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
|
||||||
settings = json.load(open(os.path.join(model_dir, "settings.json")))
|
settings = json.load(open(os.path.join(model_dir, "hparams.json")))
|
||||||
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
|
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
|
||||||
|
|
||||||
return settings, params
|
return settings, params
|
||||||
|
Loading…
x
Reference in New Issue
Block a user