This is my one-evening attempt to get more comfortable with jax and flax by reimplementing Mamba[1] on the basis of a torch implementation. It is closer to a fairly detailed interface of the model: training and inference code still have to be written around it. I hope this code helps you become more confident with jax, flax, or state-space models[2].
Feel free to contact me about any mistakes you find :)
I have also tried to implement the associative scan in the jax folder, but it probably contains mistakes.
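For reference when checking an associative scan, here is a minimal sketch (not the code in the jax folder) of the idea behind it. The state update has the linear recurrence form h_t = a_t * h_{t-1} + b_t. In Mamba, a_t and b_t would come from the discretized A and B, but here they are simply random arrays. Each step is an affine map h -> a*h + b, and composing such maps is associative, so `jax.lax.associative_scan` can evaluate the recurrence in parallel. The function names `sequential_scan`, `combine` and `parallel_scan` are hypothetical, chosen for illustration only.

```python
import jax
import jax.numpy as jnp


def sequential_scan(a, b):
    """Reference: h_t = a_t * h_{t-1} + b_t with h_{-1} = 0, one step at a time."""
    def step(h, ab):
        a_t, b_t = ab
        h = a_t * h + b_t
        return h, h  # carry the state and also emit it as output

    h0 = jnp.zeros_like(b[0])
    _, hs = jax.lax.scan(step, h0, (a, b))
    return hs


def combine(left, right):
    """Compose two affine maps h -> a*h + b; `left` covers the earlier positions."""
    a_l, b_l = left
    a_r, b_r = right
    # right(left(h)) = a_r * (a_l * h + b_l) + b_r
    return a_r * a_l, a_r * b_l + b_r


def parallel_scan(a, b):
    """Same recurrence, computed with a parallel prefix scan over the time axis."""
    _, hs = jax.lax.associative_scan(combine, (a, b), axis=0)
    return hs  # the b-part of each prefix equals h_t because h_{-1} = 0


# Illustrative shapes (hypothetical): sequence length, channels, state size.
L, D, N = 16, 4, 8
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (L, D, N))  # decay factors in [0, 1)
b = jax.random.normal(key_b, (L, D, N))

h_seq = sequential_scan(a, b)
h_par = parallel_scan(a, b)
print(jnp.allclose(h_seq, h_par, atol=1e-5))
```

Comparing the parallel result with the step-by-step `jax.lax.scan` version on random inputs is a cheap way to catch mistakes in a custom scan; small floating-point differences are expected, hence the tolerance.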
This repo is based on the following ones: annotated-mamba, mamba-minimal (in torch), and the official implementation.
[1] Gu and Dao (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces.
[2] Gu et al. (2022). Efficiently Modeling Long Sequences with Structured State Spaces.