diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py index 47fa4457..cd76673e 100644 --- a/alphafold3_pytorch/alphafold3.py +++ b/alphafold3_pytorch/alphafold3.py @@ -767,11 +767,8 @@ def __init__(self, *, heads, dim_pairwise, window_size=None, num_memory_kv=0, ** # line 8 of Algorithm 24 - to_attn_bias_linear = LinearNoBias(dim_pairwise, heads) - nn.init.zeros_(to_attn_bias_linear.weight) - self.to_attn_bias_norm = nn.LayerNorm(dim_pairwise) - self.to_attn_bias = nn.Sequential(to_attn_bias_linear, Rearrange("b ... h -> b h ...")) + self.to_attn_bias = nn.Sequential(LinearNoBias(dim_pairwise, heads), Rearrange("b ... h -> b h ...")) @typecheck def forward(