removing unused RoPE parameters (#590)

* removing unused RoPE parameters

* remove redundant context_length in GQA

---------

Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
This commit is contained in:
casinca 2025-04-01 00:10:39 +02:00 committed by GitHub
parent 222803737d
commit 152a087a37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -233,10 +233,8 @@
"source": [ "source": [
"class GroupedQueryAttention(nn.Module):\n", "class GroupedQueryAttention(nn.Module):\n",
" def __init__(\n", " def __init__(\n",
" self, d_in, d_out, context_length, num_heads,\n", " self, d_in, d_out, num_heads,\n",
" num_kv_groups,\n", " num_kv_groups,\n",
" rope_base=10_000,\n",
" rope_config=None,\n",
" dtype=None\n", " dtype=None\n",
" ):\n", " ):\n",
" super().__init__()\n", " super().__init__()\n",
@ -322,11 +320,8 @@
" self.att = GroupedQueryAttention(\n", " self.att = GroupedQueryAttention(\n",
" d_in=cfg[\"emb_dim\"],\n", " d_in=cfg[\"emb_dim\"],\n",
" d_out=cfg[\"emb_dim\"],\n", " d_out=cfg[\"emb_dim\"],\n",
" context_length=cfg[\"context_length\"],\n",
" num_heads=cfg[\"n_heads\"],\n", " num_heads=cfg[\"n_heads\"],\n",
" num_kv_groups=cfg[\"n_kv_groups\"],\n", " num_kv_groups=cfg[\"n_kv_groups\"],\n",
" rope_base=cfg[\"rope_base\"],\n",
" rope_config=cfg[\"rope_freq\"],\n",
" dtype=cfg[\"dtype\"]\n", " dtype=cfg[\"dtype\"]\n",
" )\n", " )\n",
" self.ff = FeedForward(cfg)\n", " self.ff = FeedForward(cfg)\n",