Fix mha wrapper implementations in ch03 bonus

taihaozesong 2024-03-13 18:02:26 +08:00
parent 00b121a5af
commit f1fa9df15c
2 changed files with 4 additions and 2 deletions


@@ -37,9 +37,11 @@ class MultiHeadAttentionWrapper(nn.Module):
             [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias)
              for _ in range(num_heads)]
         )
+        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)
 
     def forward(self, x):
-        return torch.cat([head(x) for head in self.heads], dim=-1)
+        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)
+        return self.out_proj(context_vec)

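For reference, the patched wrapper can be exercised on its own. The sketch below is a minimal, self-contained reconstruction rather than the repository code: the CausalAttention class here is a simplified stand-in (assumed single-head scaled dot-product attention with a causal mask), and the dimensions in the demo are toy values.

import torch
import torch.nn as nn


class CausalAttention(nn.Module):
    """Simplified stand-in (assumption) for the ch03 CausalAttention class:
    single-head scaled dot-product attention with a causal mask."""
    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(block_size, block_size), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores = attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], float("-inf")
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values  # (b, num_tokens, d_out)


class MultiHeadAttentionWrapper(nn.Module):
    """The wrapper as patched above: independent causal-attention heads,
    concatenated along the feature dimension and mixed by out_proj."""
    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias)
             for _ in range(num_heads)]
        )
        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)

    def forward(self, x):
        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.out_proj(context_vec)


if __name__ == "__main__":
    x = torch.randn(2, 8, 64)  # (batch, num_tokens, d_in) -- toy dimensions
    mha = MultiHeadAttentionWrapper(
        d_in=64, d_out=16, block_size=8, dropout=0.0, num_heads=4
    )
    print(mha(x).shape)  # torch.Size([2, 8, 64]), i.e. d_out*num_heads

The added out_proj mixes the concatenated head outputs without changing their width, which is why it maps d_out*num_heads to d_out*num_heads.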

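In the notebook hunk below, the per-head output size drops from embed_dim to embed_dim//12, presumably so that concatenating the 12 heads (and passing them through out_proj) yields an output width of embed_dim again, matching the other multi-head attention implementations compared in the notebook. A quick sketch of that arithmetic, assuming embed_dim=768 (an illustrative value, not taken from this diff):

# Dimensionality check behind the d_out=embed_dim//12 change.
embed_dim = 768                  # illustrative assumption; defined elsewhere in the notebook
num_heads = 12

d_out = embed_dim // 12          # per-head output size, as in the patched cell
concat_dim = d_out * num_heads   # width after torch.cat over the heads
assert concat_dim == embed_dim   # out_proj then maps embed_dim -> embed_dim
print(d_out, concat_dim)         # 64 768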
@@ -90,7 +90,7 @@
 "\n",
 "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
 "    d_in=embed_dim,\n",
-"    d_out=embed_dim,\n",
+"    d_out=embed_dim//12,\n",
 "    block_size=context_len,\n",
 "    dropout=0.0,\n",
 "    num_heads=12,\n",