Mirror of https://github.com/rasbt/LLMs-from-scratch.git
Merge pull request #68 from taihaozesong/fix_ch03_impl_wrapper

Fix mha wrapper implementations in ch03 bonus

Commit 319e919062
@@ -37,9 +37,11 @@ class MultiHeadAttentionWrapper(nn.Module):
             [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias)
              for _ in range(num_heads)]
         )
+        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)
 
     def forward(self, x):
-        return torch.cat([head(x) for head in self.heads], dim=-1)
+        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)
+        return self.out_proj(context_vec)
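For reference, below is a minimal, self-contained sketch of what the wrapper class looks like once this hunk is applied. The CausalAttention class is a paraphrase of the chapter 3 single-head implementation, included only so the example runs; its details (and any names not shown in the hunk above) are assumptions, not copied from the repository.

import torch
import torch.nn as nn


class CausalAttention(nn.Module):
    # Single-head causal self-attention, paraphrased from chapter 3 (assumption).
    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Upper-triangular mask blocks attention to future tokens
        self.register_buffer(
            "mask", torch.triu(torch.ones(block_size, block_size), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], float("-inf")
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values  # (b, num_tokens, d_out)


class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # One independent causal-attention head per list entry
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias)
             for _ in range(num_heads)]
        )
        # Added by this PR: combines the concatenated head outputs
        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)

    def forward(self, x):
        # Concatenate per-head outputs along the feature dimension ...
        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)
        # ... then mix them through the output projection
        return self.out_proj(context_vec)

Before the change, the forward pass returned the raw concatenation of head outputs; the added out_proj mixes the concatenated features, which presumably brings the wrapper in line with the other MHA implementations compared in the ch03 bonus notebook.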
@@ -90,7 +90,7 @@
     "\n",
     "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
     "    d_in=embed_dim,\n",
-    "    d_out=embed_dim,\n",
+    "    d_out=embed_dim//12,\n",
     "    block_size=context_len,\n",
     "    dropout=0.0,\n",
     "    num_heads=12,\n",
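The second hunk adjusts the per-head output size to match the new wrapper. The dimension arithmetic can be checked by continuing the sketch above; the concrete values of batch_size, context_len, and embed_dim are illustrative assumptions, and Ch03_MHA_Wrapper is treated here simply as an alias for the sketched wrapper class (the aliasing itself is an assumption, not something shown in the diff).

# Illustrative values (assumptions, not taken from the diff above)
batch_size, context_len, embed_dim = 2, 1024, 768

torch.manual_seed(123)
x = torch.randn(batch_size, context_len, embed_dim)

# The notebook's Ch03_MHA_Wrapper is assumed to refer to the wrapper class
# sketched after the first hunk.
Ch03_MHA_Wrapper = MultiHeadAttentionWrapper

mha_ch03_wrapper = Ch03_MHA_Wrapper(
    d_in=embed_dim,
    d_out=embed_dim//12,      # per-head output size: 768 // 12 = 64
    block_size=context_len,
    dropout=0.0,
    num_heads=12,
)

out = mha_ch03_wrapper(x)
print(out.shape)  # torch.Size([2, 1024, 768])

With d_out=embed_dim//12 and 12 heads, the concatenated head outputs are 12 * 64 = 768 = embed_dim dimensions wide, and the new out_proj maps them back to embed_dim, so input and output shapes match. Keeping d_out=embed_dim, as before this change, would instead have produced a 12 * 768 = 9216-dimensional output.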