diff --git a/ch03/02_bonus_efficient-multihead-attention/ch03.py b/ch03/02_bonus_efficient-multihead-attention/ch03.py
index 3be1cdb..46e4bb2 100644
--- a/ch03/02_bonus_efficient-multihead-attention/ch03.py
+++ b/ch03/02_bonus_efficient-multihead-attention/ch03.py
@@ -37,9 +37,11 @@ class MultiHeadAttentionWrapper(nn.Module):
             [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias)
              for _ in range(num_heads)]
         )
+        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)
 
     def forward(self, x):
-        return torch.cat([head(x) for head in self.heads], dim=-1)
+        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)
+        return self.out_proj(context_vec)
 
 
 
diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
index 9692f32..b7b27df 100644
--- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
+++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
@@ -90,7 +90,7 @@
     "\n",
     "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
     "    d_in=embed_dim,\n",
-    "    d_out=embed_dim,\n",
+    "    d_out=embed_dim//12,\n",
     "    block_size=context_len,\n",
     "    dropout=0.0,\n",
     "    num_heads=12,\n",
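
The sketch below (not part of the patch) illustrates the combined effect of the two changes: the wrapper now routes the concatenated head outputs through an output projection, and choosing d_out=embed_dim//12 with num_heads=12 keeps the final output width equal to embed_dim. Only the MultiHeadAttentionWrapper mirrors the patched code; the CausalAttention class here is a simplified stand-in, and the example values embed_dim=768 and context_len=1024 are assumptions for illustration.

import torch
import torch.nn as nn


class CausalAttention(nn.Module):
    # Simplified single-head causal attention, included only so the snippet
    # runs on its own; the chapter's version differs in details.
    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(block_size, block_size), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape
        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], float("-inf")
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        return self.dropout(attn_weights) @ values


class MultiHeadAttentionWrapper(nn.Module):
    # Patched wrapper: concatenate the per-head context vectors, then
    # apply the newly added output projection.
    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias)
             for _ in range(num_heads)]
        )
        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)

    def forward(self, x):
        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.out_proj(context_vec)


# Example values (assumed for illustration): with d_out=embed_dim//12 and
# num_heads=12, the concatenated width is embed_dim again, so the wrapper's
# output matches the input embedding size.
embed_dim, context_len = 768, 1024
x = torch.randn(2, 6, embed_dim)  # (batch, tokens, embed_dim), tokens <= context_len

mha = MultiHeadAttentionWrapper(
    d_in=embed_dim,
    d_out=embed_dim//12,
    block_size=context_len,
    dropout=0.0,
    num_heads=12,
)
print(mha(x).shape)  # torch.Size([2, 6, 768])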