Feature request: support return_mems in ContinuousTransformerWrapper #166

Open
pfeatherstone opened this issue Jul 17, 2023 · 16 comments

@pfeatherstone
Contributor

It would be great if ContinuousTransformerWrapper supported return_mems in the forward pass.
Thank you for the awesome repo!
Remarkably, it all works with torch.onnx.export()!
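
For context, a minimal sketch of the usage being asked for, assuming the flag would mirror the existing return_mems / mems interface on TransformerWrapper (the dimensions and segment sizes below are just placeholders):

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

model = ContinuousTransformerWrapper(
    dim_in = 32,        # continuous (already-embedded) inputs, no token ids
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x1 = torch.randn(1, 256, 32)    # segment (t) of the stream
x2 = torch.randn(1, 256, 32)    # segment (t + 1)

# requested: also return the per-layer hidden states (XL-style memories) ...
out1, mems = model(x1, return_mems = True)

# ... so they can be fed back in on the next segment
out2, mems = model(x2, mems = mems, return_mems = True)
```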

lucidrains added a commit that referenced this issue Jul 17, 2023
@lucidrains
Owner

@pfeatherstone oh sure! threw it in there quickly before starting my main work

how are you using it? 🧐

@lucidrains
Owner

lucidrains commented Jul 17, 2023

it is actually interesting how many people have told me they are using the continuous wrapper, although there's so little research on that. it works well?

@hugofloresgarcia

We use a continuous transformer in our new paper: https://arxiv.org/pdf/2307.04686.pdf for music generation and find that it works well! we use the continuous representation of the VQ-VAE latents as the continuous embeddings used as input for the transformer. Awesome work w/ this repo btw @lucidrains!

@lucidrains
Owner

@hugofloresgarcia congrats on the paper!

@pfeatherstone
Contributor Author

pfeatherstone commented Jul 18, 2023

> @pfeatherstone oh sure! threw it in there quickly before starting my main work
>
> how are you using it? 🧐

Oh, it's just that my inputs are already in normalized floating-point format, not tokenized. I think Wav2Vec2 is basically like that, no?

@pfeatherstone
Contributor Author

@lucidrains How do you train a non-autoregressive continuous transformer with mems and return_mems?
I can see in the code you have XLAutoregressiveWrapper, but that's only for autoregressive transformers, i.e. where the targets are simply the inputs left-shifted. I can also see NonAutoregressiveWrapper, but I can't quite tell whether that's appropriate for training recurrent transformers. Thank you in advance.

@lucidrains
Owner

lucidrains commented Jul 18, 2023

@pfeatherstone ohh, this repository is not well suited for custom recurrence

are you trying to do something like RMT, but non-autoregressive? does your idea resemble memformer?

@pfeatherstone
Contributor Author

pfeatherstone commented Jul 18, 2023

Yes, it's similar. To be honest I thought this repo would have done the job. Maybe I need to read up on this more to properly determine which architecture suits me best. To me, the mechanism provided in this repo (return_mems and mems=), RMT and Memformer all look like they are doing roughly the same thing...

Basically I want to output mems from running segment (t) and feed them, along with segment (t+1), into the next iteration, exactly like how return_mems works here. The difficulty is how you train: do you need to train with segments or not?
My architecture uses CTC loss.
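
For what it's worth, one way to picture that training setup (a rough sketch, not something the repo provides out of the box): split each training utterance into fixed-length segments, carry the memories across segments with a detach in between (truncated backprop through time), and apply the CTC loss to the concatenated segment outputs. The feature dimension, number of classes and segment length below are illustrative placeholders.

```python
import torch
import torch.nn.functional as F
from x_transformers import ContinuousTransformerWrapper, Decoder

num_classes, seg_len = 40, 256      # placeholders
model = ContinuousTransformerWrapper(
    dim_in = 80, dim_out = num_classes, max_seq_len = seg_len,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

def training_step(feats, targets, target_lengths):
    # feats: (batch, time, 80) continuous features; targets: (batch, max_target_len)
    mems, outs = None, []
    for seg in feats.split(seg_len, dim = 1):
        out, mems = model(seg, mems = mems, return_mems = True)
        mems = [m.detach() for m in mems]   # cut the graph between segments (truncated BPTT)
        outs.append(out)

    log_probs = torch.cat(outs, dim = 1).log_softmax(dim = -1).transpose(0, 1)   # (time, batch, classes)
    input_lengths = torch.full((feats.shape[0],), log_probs.shape[0], dtype = torch.long)
    return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = 0)
```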

@pfeatherstone
Contributor Author

Basically I want a kind of stream-aware transformer with causal attention, non-autoregressive, trained with CTC loss, with an effectively infinite response, a bit like how infinite impulse response (IIR) filters work. In transformer world, if you constantly feed mems from previous iterations, it should be able to "remember" information from the infinite past.
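
At inference time that IIR-style feedback would look roughly like the following (again just a sketch; stream_chunks is a hypothetical source of incoming feature chunks, and in practice XL-style memories are capped at some maximum length rather than growing forever):

```python
import torch

@torch.no_grad()
def streaming_decode(model, stream_chunks):
    # model: a ContinuousTransformerWrapper as above; each chunk: (1, chunk_len, dim_in)
    mems = None
    for chunk in stream_chunks:
        out, mems = model(chunk, mems = mems, return_mems = True)
        # the recurrent memories are the "feedback path": state from the
        # distant past keeps influencing the present output
        yield out.argmax(dim = -1)   # greedy per-frame CTC labels (blank/repeat collapsing omitted)
```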

@lucidrains
Owner

@pfeatherstone yea, i'm a big fan of the RMT architecture too

@lucidrains
Owner

let me think, yea, i think x-transformers is close, since it has the ability to prepend embeddings (like PaLI)

i can take a look at this later this week
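
For reference, a sketch of that prepending, treating a handful of learned vectors as RMT-style memory slots ahead of each segment. The prepend_embeds keyword name and the sizes here are assumptions rather than a confirmed interface:

```python
import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

dim, num_mem_tokens = 512, 16       # illustrative sizes
model = TransformerWrapper(
    num_tokens = 256, max_seq_len = 1024,
    attn_layers = Decoder(dim = dim, depth = 6, heads = 8)
)
memory_tokens = nn.Parameter(torch.randn(num_mem_tokens, dim))

x = torch.randint(0, 256, (1, 1024))

# prepend the learned embeddings ahead of the token sequence (PaLI-style);
# RMT repurposes slots like these as read/write memory carried across segments
logits = model(x, prepend_embeds = memory_tokens.unsqueeze(0))
```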

@pfeatherstone
Contributor Author

pfeatherstone commented Jul 18, 2023

So I can see three candidates:

I'm new to recurrent transformers. I'm pretty familiar with "normal" transformers (GPT-like, for example), where you basically feed in your entire input (text, image, or whatever). But though recurrence seems easy to design in the forward pass, I can't quite see how you train effectively (the backward pass). Do you need to randomly partition your input into segments during training and pretend you are "streaming", or is there a more elegant, less faffy way of training?

@lucidrains
Owner

lucidrains commented Jul 18, 2023

@pfeatherstone i think your best bet is to modify the RMT architecture

i also included the memory-replay-backprop technique from memformer, so the network can learn to formulate its memories better with little cost to hardware memory
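
To illustrate the memory-replay-backprop idea (from Memformer): do one cheap forward pass to record the memory entering each segment, then replay the segments in reverse, backpropagating through one segment's graph at a time while handing the memory gradient back to the previous segment. The sketch below is written against a hypothetical segment_forward(seg, mems) -> (loss, out_mems) interface where out_mems stay attached to the current segment's graph; it is not the API of either repo.

```python
import torch

def memory_replay_backprop(segment_forward, segments):
    # pass 1: forward only (no graphs), record the memories fed INTO each segment
    incoming = [None]
    with torch.no_grad():
        mems = None
        for seg in segments[:-1]:
            _, mems = segment_forward(seg, mems)
            incoming.append(mems)

    # pass 2: replay segments last-to-first, so only one segment's activations
    # are held in memory at any point
    mem_grads = None
    for seg, mems in zip(reversed(segments), reversed(incoming)):
        if mems is not None:
            mems = [m.detach().requires_grad_() for m in mems]

        loss, out_mems = segment_forward(seg, mems)

        if mem_grads is not None:
            # inject the gradient flowing back from the segment that consumed out_mems;
            # the dot-product term makes d(loss)/d(out_mems) equal exactly mem_grads
            loss = loss + sum((m * g).sum() for m, g in zip(out_mems, mem_grads))

        loss.backward()     # accumulates parameter grads across segments
        mem_grads = None if mems is None else [m.grad for m in mems]
```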

@pfeatherstone
Contributor Author

@lucidrains There is also this paper https://arxiv.org/pdf/2109.00301.pdf, which proposes the ∞-former.
It's a bit daunting how many ways there are to cache and reuse hidden states over time. There must be a canonical way of doing this stuff...

@pfeatherstone
Contributor Author

We just need IIR filters in neural networks...

@pfeatherstone
Contributor Author

@lucidrains have you looked at the RWKV architecture? It looks like it's solving something similar. Surely all these RNN + Transformer architectures are going to converge.
