From e700c66b7a0f78f8f94994b46336447aaf6c92d1 Mon Sep 17 00:00:00 2001
From: casinca <47400729+casinca@users.noreply.github.com>
Date: Tue, 17 Jun 2025 20:09:53 +0200
Subject: [PATCH] removed old args in GQA class (#674)

---
 ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
index 7766dca..8ee68bc 100644
--- a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
+++ b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
@@ -452,10 +452,8 @@
     "\n",
     "class GroupedQueryAttention(nn.Module):\n",
     "    def __init__(\n",
-    "        self, d_in, d_out, context_length, num_heads,\n",
+    "        self, d_in, d_out, num_heads,\n",
     "        num_kv_groups,       # NEW\n",
-    "        rope_base=10_000,    # NEW\n",
-    "        rope_config=None,    # NEW\n",
     "        dtype=None\n",
     "    ):\n",
     "        super().__init__()\n",
@@ -645,10 +643,8 @@
     "gqa = GroupedQueryAttention(\n",
     "    d_in=embed_dim,\n",
     "    d_out=embed_dim,\n",
-    "    context_length=max_context_len,\n",
     "    num_heads=num_heads,\n",
     "    num_kv_groups=8,\n",
-    "    rope_base=llama_3_theta_base\n",
     ")\n",
     "\n",
     "gqa(example_batch)\n",
@@ -753,11 +749,8 @@
     "        self.att = GroupedQueryAttention(  # MultiHeadAttention(\n",
     "            d_in=cfg[\"emb_dim\"],\n",
     "            d_out=cfg[\"emb_dim\"],\n",
-    "            context_length=cfg[\"context_length\"],\n",
     "            num_heads=cfg[\"n_heads\"],\n",
     "            num_kv_groups=cfg[\"n_kv_groups\"],  # NEW\n",
-    "            rope_base=cfg[\"rope_base\"],    # NEW\n",
-    "            rope_config=cfg[\"rope_freq\"],  # NEW\n",
     "            dtype=cfg[\"dtype\"]\n",
     "        )\n",
     "        self.ff = FeedForward(cfg)\n",
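
For readers looking at this patch outside the notebook, below is a minimal, self-contained sketch of a grouped-query attention module using the trimmed constructor signature (d_in, d_out, num_heads, num_kv_groups, dtype). It is not the notebook's exact implementation: the notebook's version still applies RoPE to the queries and keys (the RoPE settings are simply no longer constructor arguments), which this sketch omits, and the causal mask below is built on the fly since context_length is no longer stored in the module. The tensor sizes in the usage lines are invented for illustration.

import torch
import torch.nn as nn


class GroupedQueryAttention(nn.Module):
    """Sketch of grouped-query attention with the trimmed signature from this patch."""

    def __init__(self, d_in, d_out, num_heads,
                 num_kv_groups,  # NEW
                 dtype=None):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups  # query heads per KV head
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        # Keys and values are projected to fewer heads (one per KV group)
        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x):
        b, num_tokens, _ = x.shape

        # Shape: (b, heads, num_tokens, head_dim)
        queries = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.W_key(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = self.W_value(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        # The notebook applies RoPE to queries and keys at this point; omitted in this sketch.

        # Expand the KV heads so that group_size query heads share each KV head
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # Causal self-attention; the mask is built on the fly because
        # context_length is no longer a constructor argument
        attn_scores = queries @ keys.transpose(2, 3) / self.head_dim**0.5
        mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool, device=x.device), diagonal=1)
        attn_scores = attn_scores.masked_fill(mask, float("-inf"))
        attn_weights = torch.softmax(attn_scores, dim=-1)

        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(context)


# Usage mirroring the patched call site (sizes are made up for illustration)
embed_dim, num_heads = 128, 16
example_batch = torch.randn(2, 10, embed_dim)
gqa = GroupedQueryAttention(d_in=embed_dim, d_out=embed_dim, num_heads=num_heads, num_kv_groups=8)
print(gqa(example_batch).shape)  # torch.Size([2, 10, 128])

The point of the grouping is that num_kv_groups key/value heads are shared across num_heads query heads, which shrinks the K and V projections (and the KV cache at inference time) by a factor of num_heads / num_kv_groups compared to standard multi-head attention.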