wrong tensor arrangement for conv #1
grouped_sims is never actually used :)
also, the ALiBi impl here has an issue with the backward call complaining about a missing/overwritten gradient; after replacing it with another impl from the x-transformers repo it works fine. causal d-convs should work on
i wasn't able to replicate the ALiBi error, so if you have a short script i can run, i can definitely fix that as well!
It seems the issue with ALiBi is due to a few network forward calls before the backward call; specifically i have
the mask needs to be moved to the correct device here:
i also think it's good to have it optional in case we don't need causal masking?
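As a minimal sketch of what that could look like (the helper name `apply_causal_mask` and the shapes are hypothetical illustrations, not the repo's actual code), with the mask built on the logits' device and gated by a flag:

```python
import torch

def apply_causal_mask(sim, causal = True):
    # sim: (b, h, n, n) attention logits
    # building the mask on sim.device avoids CPU/GPU device mismatches
    if not causal:
        return sim
    n = sim.shape[-1]
    causal_mask = torch.triu(torch.ones((n, n), device = sim.device), diagonal = 1).bool()
    return sim.masked_fill(causal_mask, torch.finfo(sim.dtype).min)
```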
sure! done in the latest! so i'm still not seeing the ALiBi error, below is what i'm running

```python
import torch
from tranception_pytorch import Tranception

model = Tranception(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

amino_acids = torch.randint(0, 21, (1, 512))

logits = model(amino_acids)

with torch.no_grad():
    _ = model(amino_acids)
    _ = model(amino_acids)
    _ = model(amino_acids)

logits.sum().backward()
```
i noticed that you apply alibi here inside forward:
however the interface of the function assumes it will not apply the bias but will return it instead, for manual addition in the attention class:
and here we follow this idea:
manually adding the returned bias

Update: yup, that seems to be an issue, we need to either return the bias or not accumulate it inside the attention class
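A minimal sketch of the interface being suggested, following the x-transformers convention of returning the bias; the class name `AlibiBias`, the slope formula, and the shapes are illustrative assumptions, not the repo's actual code:

```python
import torch
from torch import nn

class AlibiBias(nn.Module):
    """Computes an (h, n, n) ALiBi bias and returns it; the caller adds it to the attention logits."""
    def __init__(self, heads):
        super().__init__()
        # geometric slopes per head (assumes heads is a power of 2, for brevity)
        start = 2 ** (-8 / heads)
        slopes = torch.tensor([start ** (i + 1) for i in range(heads)])
        self.register_buffer('slopes', slopes.view(heads, 1, 1), persistent = False)

    def forward(self, n, device):
        pos = torch.arange(n, device = device)
        rel_dist = (pos[None, :] - pos[:, None]).abs()   # (n, n) absolute relative distances
        return -rel_dist * self.slopes                   # (h, n, n), more negative further away

# in the attention class the bias is added manually, never accumulated into stored state:
# sim: (b, h, n, n) attention logits
# sim = sim + alibi_bias(sim.shape[-1], sim.device)
```

Because the bias is recomputed and added fresh on every forward instead of being accumulated inside the attention class, several forward passes before a backward call no longer mutate anything the autograd graph depends on.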
@inspirit the repository is still missing the retrieved MSA's contribution to the prediction, but i'll get to that later next week!
Hey!
looks like there is a mess in the tensor arrangements in a few places here:
tranception-pytorch/tranception_pytorch/tranception_pytorch.py, line 140 in 610ebf2
rearranging to 'b h n d' does not make sense here, since the last dim is the channels dim
then next:
tranception-pytorch/tranception_pytorch/tranception_pytorch.py, line 148 in 610ebf2
we split the heads into groups, merge them with the batch, and then try to apply a convolution that expects channels as the second dim, which in our case is now the seq-len dim (see the sketch at the end of this issue).
there is also
tranception-pytorch/tranception_pytorch/tranception_pytorch.py, line 122 in 610ebf2
and the convolution layer is somehow set up to expect the full inner_dim as input channels.
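For reference, here is a minimal sketch of the layout such a grouped depthwise causal conv would expect, per head group. All names and sizes below (heads_per_group, dim_head, kernel, n) are made-up illustration values, not the repo's actual configuration: the point is only that Conv1d wants (batch, channels, seq) with channels equal to the group's heads_per_group * dim_head rather than the full inner_dim, and with the sequence axis last.

```python
import torch
from torch import nn
from einops import rearrange

# hypothetical per-group sizes, just for illustration
heads_per_group, dim_head, kernel, n = 2, 64, 3, 128
channels = heads_per_group * dim_head        # per-group channels, not the full inner_dim

# depthwise causal conv over the sequence axis: Conv1d expects (batch, channels, seq)
dconv = nn.Conv1d(channels, channels, kernel, padding = kernel - 1, groups = channels)

q = torch.randn(1, heads_per_group, n, dim_head)               # (b, h, n, d) for one head group
x = rearrange(q, 'b h n d -> b (h d) n')                       # channels second, seq-len last
x = dconv(x)[..., :n]                                          # trim right padding -> causal
q = rearrange(x, 'b (h d) n -> b h n d', h = heads_per_group)  # back to (b, h, n, d)
```

With the arrangement described above, the conv instead sees seq-len in the channel position and slides along dim_head, which is exactly the mismatch this issue points out.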