fix(GatedDeltaNet): Init param A from log of a uniform distrib (#906)

This commit is contained in:
casinca 2025-11-09 21:22:52 +01:00 committed by GitHub
parent 35354fac80
commit 7d92267170
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -166,7 +166,8 @@ class GatedDeltaNet(nn.Module):
# A_log + W_alpha(x) + dt_bias
self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
self.dt_bias = nn.Parameter(torch.ones(num_heads))
self.A_log = nn.Parameter(torch.zeros(num_heads))
A_init = torch.empty(num_heads).uniform_(0, 16)
self.A_log = nn.Parameter(torch.log(A_init))
# We could implement this as
# W_alpha = nn.Linear(d_in, num_heads, bias=True)
# but the bias is separate for interpretability and