
wrong tensor arrange for conv #1

Open · inspirit opened this issue Jun 14, 2022 · 9 comments
@inspirit

Hey!
Looks like there is a mix-up in the tensor arrangements in a few places here:

q, k, v = rearrange_many((q, k, v), 'b (h d) n -> b h n d', h = self.heads)

It does not make sense to go to 'b h n d' here, since the last dim is the channels dim.

Then next:

projs = rearrange_many(projs.split(self.heads // self.groups, dim = 1), 'b h n d -> (b h) n d')

We split the heads into groups, merge them with the batch, and then try to apply a convolution that expects channels as the second dim, which in our case is now the sequence length.

There is also:

ds_convs.append(CausalDepthwiseConv1d(inner_dim, kernel_size))

which somehow sets up the convolution layer to expect the full inner_dim as input channels.
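
For illustration, a minimal sketch (not the repository's code; the shapes and kernel size below are made up) of the layout a depthwise Conv1d expects, with the head dim rather than the sequence length in the channel position:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

b, heads, dim_head, n = 2, 8, 64, 128
kernel_size = 3
q = torch.randn(b, heads * dim_head, n)               # 'b (h d) n'

# keep the per-head feature dim next to the batch so Conv1d sees it as channels
q = rearrange(q, 'b (h d) n -> (b h) d n', h = heads)

# causal depthwise conv over the sequence dim, with dim_head input channels
conv = nn.Conv1d(dim_head, dim_head, kernel_size, groups = dim_head)
q = conv(F.pad(q, (kernel_size - 1, 0)))              # left-pad for causality

q = rearrange(q, '(b h) d n -> b h n d', h = heads)   # back to attention layout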

@inspirit
Author

grouped_sims is actually never used :)

grouped_sims = torch.cat(grouped_sims, dim = 1)

Also, the ALiBi impl here has an issue with the backward call complaining about a missing/overwritten gradient; after replacing it with another impl from the x-transformers repo, it works fine.

The causal d-convs should work on dim_head instead of inner_dim, or maybe on (heads // groups) * dim_head; it's not clear from the paper. And if we refer to the Primer impl, they seem to run the d-convs after projecting to q, k, v but before splitting into heads...
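
My rough reading of that Primer-style option, as a hypothetical sketch (not this repo's implementation; the dims and kernel size are placeholders):

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

dim, heads, dim_head, kernel_size = 512, 8, 64, 3
inner_dim = heads * dim_head

to_q = nn.Linear(dim, inner_dim, bias = False)
# depthwise conv over the sequence, one filter per channel of the projection
q_conv = nn.Conv1d(inner_dim, inner_dim, kernel_size, groups = inner_dim)

x = torch.randn(2, 128, dim)                          # (batch, seq, dim)
q = to_q(x)                                           # project to q first
q = rearrange(q, 'b n c -> b c n')                    # channels first for Conv1d
q = q_conv(F.pad(q, (kernel_size - 1, 0)))            # causal left-padding
q = rearrange(q, 'b (h d) n -> b h n d', h = heads)   # only then split into heads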

@lucidrains
Owner

@inspirit oh hey! haha, this repository was far from ready, but thanks for giving it an early review

d3944ef do you want to see if most (all?) of the issues were addressed?

@lucidrains
Owner

i wasn't able to replicate the ALiBi error, so if you have a short script i can run, i can definitely fix that as well!

@inspirit
Author

inspirit commented Jun 19, 2022

It seems the issue with ALiBi is due to a few network forward calls happening before the backward call; specifically, I have something like:

p0 = net(x)              # forward pass with gradients
with torch.no_grad():
    p = net(y)           # extra forward pass without gradients
loss = ....
loss.backward()
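
A hypothetical self-contained version of this scenario (the model config, inputs, and loss below are placeholders, not the original training code):

import torch
from tranception_pytorch import Tranception

net = Tranception(dim = 512, depth = 6, heads = 8, dim_head = 64)

x = torch.randint(0, 21, (1, 512))
y = torch.randint(0, 21, (1, 512))

p0 = net(x)                      # forward pass with gradients
with torch.no_grad():
    p = net(y)                   # second forward pass without gradients

loss = (p0 - p).pow(2).mean()    # placeholder loss combining both outputs
loss.backward()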

@inspirit
Author

inspirit commented Jun 19, 2022

The mask needs to be moved to the correct device here:

causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

I also think it's good to have it optional, in case we don't need causal masking?
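
Something like this minimal sketch (not the repo's code) is what I mean, with the mask created on the right device and made optional:

import torch

def apply_causal_mask(sim, causal = True):
    if not causal:
        return sim
    i, j = sim.shape[-2:]
    causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
    return sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)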

@lucidrains
Owner

The mask needs to be moved to the correct device here:

causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

I also think it's good to have it optional, in case we don't need causal masking?

sure! done in the latest!

so i'm still not seeing the ALiBi error

below is what i'm running

import torch
from tranception_pytorch import Tranception

model = Tranception(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

amino_acids = torch.randint(0, 21, (1, 512))

logits = model(amino_acids)

# several forward passes under no_grad before calling backward
with torch.no_grad():
    _ = model(amino_acids)
    _ = model(amino_acids)
    _ = model(amino_acids)

logits.sum().backward()

@inspirit
Author

inspirit commented Jun 19, 2022

I noticed that you apply ALiBi here inside forward:

return qk_sim + self.bias[..., :i, :j]

However, the interface of the function assumes it will not apply the bias, but rather return it for manual addition in the attention class,

and here we follow this idea:

grouped_sims = [(alibi(sim_group) + sim_group) for alibi, sim_group in zip(self.learned_alibi_pos_biases, grouped_sims)]

manually adding the returned bias.

Update: yup, that seems to be the issue; it needs to either return the bias, or not also accumulate it inside the attention class.
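
As a rough sketch of the "return the bias" option (not the repository's code; the slopes and bias formula here are placeholders), the positional-bias module would produce only the bias, and the attention class would add it:

import torch
import torch.nn as nn

class AlibiPosBias(nn.Module):
    def __init__(self, heads):
        super().__init__()
        slopes = torch.logspace(0., -3., heads)           # placeholder slopes
        self.slopes = nn.Parameter(slopes.view(heads, 1, 1))

    def forward(self, qk_sim):
        i, j = qk_sim.shape[-2:]
        rel_pos = torch.arange(j, device = qk_sim.device)[None, None, :] - \
                  torch.arange(i, device = qk_sim.device)[None, :, None]
        return self.slopes * rel_pos                      # return the bias only, no addition here

# the caller (attention class) then adds the returned bias manually
alibi = AlibiPosBias(heads = 8)
sim = torch.randn(2, 8, 128, 128)                         # (batch, heads, i, j)
sim = alibi(sim) + sim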

@lucidrains
Owner

@inspirit thank you Eugene! bc50024

@lucidrains
Owner

@inspirit the repository is still missing the retrieved MSA's contribution to the prediction but i'll get to that later next week!
