From 2e3205f747c5dfe38708e71ee4dc67a1ad135f53 Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Fri, 1 Aug 2025 19:58:18 -0500
Subject: [PATCH] MoE Nb readability improvements (#761)

---
 .../standalone-qwen3-moe-plus-kvcache.ipynb | 35 +++++++++++++------
 ch05/11_qwen3/standalone-qwen3-moe.ipynb    | 35 +++++++++++++------
 2 files changed, 48 insertions(+), 22 deletions(-)

diff --git a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb
index aaaa302..fc11c0f 100644
--- a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb
+++ b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb
@@ -152,13 +152,28 @@
     "        self.num_experts = cfg[\"num_experts\"]\n",
     "        self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
     "\n",
-    "        meta_device = torch.device(\"meta\") # to reduce memory pressure and only load them when used (trades compute for memory)\n",
-    "        self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
-    "        self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
-    "        self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
+    "        # meta device to reduce memory pressure when initializing the model before loading weights\n",
+    "        meta_device = torch.device(\"meta\")\n",
+    "        self.fc1 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
+    "        self.fc2 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+    "            )\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
+    "        self.fc3 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+    "            )\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
     "\n",
     "    def forward(self, x):\n",
     "        b, seq_len, embed_dim = x.shape\n",
@@ -194,20 +209,18 @@
     "        # topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
     "        # topk_probs = torch.softmax(topk_scores, dim=-1)\n",
     "        # y = torch.zeros_like(x)\n",
-    "\n",
+    "        #\n",
     "        # for i in range(self.num_experts_per_tok):\n",
     "        #     # expert_indices is (b, seq_len) with values in [0, num_experts)\n",
     "        #     expert_indices = topk_indices[..., i]\n",
     "        #     prob = topk_probs[..., i].unsqueeze(-1) # (b, seq_len, 1)\n",
-    "\n",
+    "        #\n",
     "        #     # For each expert, process only the tokens assigned to it\n",
     "        #     for e in range(self.num_experts):\n",
     "        #         mask = (expert_indices == e) # (b, seq_len) boolean mask\n",
    "        #         if mask.any():\n",
     "        #             selected = x[mask] # (num_tokens_e, emb_dim)\n",
-    "        #             # Compute FF for expert e\n",
     "        #             out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n",
-    "        #             # Scale by gating prob and scatter back\n",
     "        #             y[mask] += prob[mask] * out\n",
     "        # return y"
    ]
diff --git a/ch05/11_qwen3/standalone-qwen3-moe.ipynb b/ch05/11_qwen3/standalone-qwen3-moe.ipynb
index a9cee82..76a66fc 100644
--- a/ch05/11_qwen3/standalone-qwen3-moe.ipynb
+++ b/ch05/11_qwen3/standalone-qwen3-moe.ipynb
@@ -152,13 +152,28 @@
     "        self.num_experts = cfg[\"num_experts\"]\n",
     "        self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
     "\n",
-    "        meta_device = torch.device(\"meta\") # to reduce memory pressure and only load them when used (trades compute for memory)\n",
-    "        self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
-    "        self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
-    "        self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
+    "        # meta device to reduce memory pressure when initializing the model before loading weights\n",
+    "        meta_device = torch.device(\"meta\")\n",
+    "        self.fc1 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
+    "        self.fc2 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+    "            )\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
+    "        self.fc3 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+    "            )\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
     "\n",
     "    def forward(self, x):\n",
     "        b, seq_len, embed_dim = x.shape\n",
@@ -194,20 +209,18 @@
     "        # topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
     "        # topk_probs = torch.softmax(topk_scores, dim=-1)\n",
     "        # y = torch.zeros_like(x)\n",
-    "\n",
+    "        #\n",
     "        # for i in range(self.num_experts_per_tok):\n",
     "        #     # expert_indices is (b, seq_len) with values in [0, num_experts)\n",
     "        #     expert_indices = topk_indices[..., i]\n",
     "        #     prob = topk_probs[..., i].unsqueeze(-1) # (b, seq_len, 1)\n",
-    "\n",
+    "        #\n",
     "        #     # For each expert, process only the tokens assigned to it\n",
     "        #     for e in range(self.num_experts):\n",
     "        #         mask = (expert_indices == e) # (b, seq_len) boolean mask\n",
     "        #         if mask.any():\n",
     "        #             selected = x[mask] # (num_tokens_e, emb_dim)\n",
-    "        #             # Compute FF for expert e\n",
     "        #             out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n",
-    "        #             # Scale by gating prob and scatter back\n",
     "        #             y[mask] += prob[mask] * out\n",
     "        # return y"
    ]
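
Note on the meta-device pattern used above: parameters created on torch.device("meta") carry
shape and dtype metadata only, with no storage behind them, so they must be replaced by real
tensors before the first forward pass. Below is a minimal sketch of how such expert layers are
typically materialized, assuming PyTorch >= 2.1 (for load_state_dict(assign=True)); the config
values and the randomly generated "checkpoint" tensors are illustrative stand-ins, not the
notebook's actual Qwen3 settings or weights.

import torch
import torch.nn as nn

# Illustrative config; the notebook's real Qwen3 values differ.
cfg = {"emb_dim": 64, "moe_intermediate_size": 128,
       "num_experts": 4, "dtype": torch.float32}

# Same construction as in the patch: expert layers on the meta device,
# i.e., shape/dtype metadata only, no memory allocated yet.
meta_device = torch.device("meta")
fc1 = nn.ModuleList([
    nn.Linear(
        cfg["emb_dim"], cfg["moe_intermediate_size"],
        bias=False, dtype=cfg["dtype"], device=meta_device)
    for _ in range(cfg["num_experts"])]
)
assert fc1[0].weight.is_meta  # no storage behind this parameter yet

# Stand-in for checkpoint tensors (randn here instead of loaded weights).
# nn.Linear stores weight as (out_features, in_features).
state = {
    f"{i}.weight": torch.randn(
        cfg["moe_intermediate_size"], cfg["emb_dim"], dtype=cfg["dtype"])
    for i in range(cfg["num_experts"])
}

# assign=True swaps the meta parameters for the loaded tensors instead of
# copying into (nonexistent) storage; available since PyTorch 2.1.
fc1.load_state_dict(state, assign=True)
assert not fc1[0].weight.is_meta  # now backed by real memory

With assign=True, the checkpoint tensors replace the meta parameters outright rather than
being copied into pre-allocated buffers, which is what lets the deferred-allocation trick in
this patch avoid holding two copies of the expert weights during loading.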