mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-12-04 19:08:10 +00:00
fix(GatedDeltaNet): Init param A from log of a uniform distrib (#906)
This commit is contained in:
parent
35354fac80
commit
7d92267170
@@ -166,7 +166,8 @@ class GatedDeltaNet(nn.Module):
|
||||
# A_log + W_alpha(x) + dt_bias
|
||||
self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
|
||||
self.dt_bias = nn.Parameter(torch.ones(num_heads))
|
||||
-self.A_log = nn.Parameter(torch.zeros(num_heads))
|
||||
+A_init = torch.empty(num_heads).uniform_(0, 16)
|
||||
+self.A_log = nn.Parameter(torch.log(A_init))
|
||||
# We could implement this as
|
||||
# W_alpha = nn.Linear(d_in, num_heads, bias=True)
|
||||
# but the bias is separate for interpretability and
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user