LLMs-from-scratch/ch04/07_moe/memory_estimator_moe.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
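
"""Estimate FFN vs. MoE parameter memory.

Given a model's embedding dimension and dense-FFN hidden size, this script
compares the weight-parameter count (and the resulting memory footprint for
a chosen dtype) of a dense feed-forward layer against a Mixture-of-Experts
layer built from the same type of expert blocks.
"""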

import argparse

DTYPE_BYTES = {
    "fp32": 4,
    "bf16": 2,
    "fp16": 2,
    "fp8": 1,
    "int8": 1,
}
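

# Format a raw byte count as a decimal-gigabyte string (1 GB = 1000**3 bytes)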
def bytes_convert(n):
    gb = n / (1000 ** 3)
    return f"{gb:,.2f} GB"
def get_num_param_matrices(ffn_type):
    if ffn_type == "gelu":
        return 2
    elif ffn_type == "swiglu":
        return 3
    else:
        raise ValueError("--ffn_type must be 'gelu' or 'swiglu'")
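

# Weight count of one dense FFN block (biases ignored):
# num_param_matrices * emb_dim * hidden_dim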
def ffn_params(emb_dim, hidden_dim, ffn_type):
    return get_num_param_matrices(ffn_type) * emb_dim * hidden_dim
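

# The router is modeled as a single linear layer: emb_dim -> num_experts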
def router_params(emb_dim, num_experts):
    return emb_dim * num_experts
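

# With match_dense=True, solve for the per-expert hidden size h such that
#   num_experts * M * emb_dim * h + R == P_dense,
# where M = num_param_matrices, R = router params, and P_dense = dense FFN
# params, i.e. h = (P_dense - R) / (num_experts * M * emb_dim), rounded to int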
def estimate_params_and_hidden(
    emb_dim, hidden_dim, ffn_type, num_experts, match_dense=False
):
    P_dense = ffn_params(emb_dim, hidden_dim, ffn_type)
    R = router_params(emb_dim, num_experts)

    if match_dense:
        num_param_matrices = get_num_param_matrices(ffn_type)
        num = P_dense - R
        den = num_experts * num_param_matrices * emb_dim
        if num <= 0:
            raise ValueError("Dense layer too small for requested num_experts.")
        moe_hidden_dim = int(round(num / float(den)))
    else:
        moe_hidden_dim = hidden_dim

    per_expert_params = ffn_params(emb_dim, moe_hidden_dim, ffn_type)
    moe_total = num_experts * per_expert_params + R
    return {
        "dense_params": P_dense,
        "router": R,
        "moe_hidden_dim": moe_hidden_dim,
        "per_expert_params": per_expert_params,
        "moe_total": moe_total,
    }
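

# CLI entry point: prints the configuration, the dense-vs-MoE parameter
# counts, and the parameters active per token (router + top_k experts)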
def main():
    p = argparse.ArgumentParser(
        description="Estimate FFN vs MoE parameter memory"
    )
    p.add_argument("--emb_dim", type=int, required=True,
                   help="Model embedding dimension.")
    p.add_argument("--hidden_dim", type=int, required=True,
                   help="Dense FFN intermediate size (hidden dimension).")
    p.add_argument("--ffn_type", choices=["gelu", "swiglu"], default="swiglu")
    p.add_argument("--num_experts", type=int, default=8)
    p.add_argument("--top_k", type=int, default=2)
    p.add_argument("--dtype", choices=DTYPE_BYTES.keys(), default="bf16")
    p.add_argument(
        "--match_dense",
        action="store_true",
        help=("Auto-set per-expert hidden so MoE total params ~= dense FFN params "
              "(router included)."),
    )
    args = p.parse_args()

    bytes_per_elem = DTYPE_BYTES[args.dtype]
    res = estimate_params_and_hidden(
        emb_dim=args.emb_dim,
        hidden_dim=args.hidden_dim,
        ffn_type=args.ffn_type,
        num_experts=args.num_experts,
        match_dense=args.match_dense,
    )
    moe_active_params_per_token = (
        res["router"] + args.top_k * res["per_expert_params"]
    )

    print("==== Config ====")
    print(f"{'emb_dim':23}: {args.emb_dim}")
    print(f"{'hidden_dim':23}: {args.hidden_dim}")
    print(f"{'ffn_type':23}: {args.ffn_type}")
    print(f"{'num_experts':23}: {args.num_experts}")
    print(f"{'top_k':23}: {args.top_k}")
    print(f"{'dtype':23}: {args.dtype} ({bytes_per_elem} Bytes/elem)")
    print(f"{'match_dense':23}: {args.match_dense}")
    print()

    print("==== Model weights (parameters) ====")
    print(f"{'Dense FFN params':23}: {res['dense_params']:,} "
          f"({bytes_convert(res['dense_params'] * bytes_per_elem)})")
    print(f"{'Per-expert params':23}: {res['per_expert_params']:,} "
          f"({bytes_convert(res['per_expert_params'] * bytes_per_elem)})")
    print(f"{'Router params':23}: {res['router']:,} "
          f"({bytes_convert(res['router'] * bytes_per_elem)})")
    print(f"{'MoE TOTAL params':23}: {res['moe_total']:,} "
          f"({bytes_convert(res['moe_total'] * bytes_per_elem)})")
    print(f"{'MoE ACTIVE/Token':23}: {moe_active_params_per_token:,} "
          f"({bytes_convert(moe_active_params_per_token * bytes_per_elem)})")
    print(f"{'moe_hidden_dim':23}: {res['moe_hidden_dim']}")
    print()


if __name__ == "__main__":
    main()
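
# Example usage (illustrative values; emb_dim/hidden_dim match GPT-2-small):
#   python memory_estimator_moe.py --emb_dim 768 --hidden_dim 3072 \
#       --ffn_type gelu --num_experts 8 --top_k 2 --dtype bf16 --match_dense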