FlaxDiff: A Diffusion library but with simple code and good explanations #4062
AshishKumar4 asked this question in Show and tell
Recently, to brush up on my skills, I started learning about diffusion-based generative models after my failed attempts at reproducing ProGAN and StyleGAN (GANs are really hard to train :( ). I read research papers and blogs, but I got confused by the explanations and formulations, which each paper presented in a different way, and it took me quite a while to reproduce everything up to the EDM paper by Karras et al. But I have mostly been a developer by trade, so I wrote all my code in a modular, well-separated manner (at least I think so). I have compiled all of that into this repository, which is still a work in progress, but I did want to get some feedback on it and to ask whether it would be acceptable for the Community examples page in Flax (I know there is already a diffusion example there, but bear with me).
Basically, I think that the Diffusers library, although very versatile and indeed very readable, is not trivial to understand because it is written to make everything general, and many other code examples (such as the community one) don't implement much beyond DDPM/DDIM and LMS samplers and VP noise schedules.
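For readers new to the terms above, here is a minimal JAX sketch (not code from FlaxDiff or Diffusers) of a variance-preserving (VP) noise schedule and a deterministic DDIM step. It assumes a cosine schedule for the cumulative signal fraction alpha_bar(t) with continuous t in [0, 1); the function names `cosine_alpha_bar`, `q_sample` and `ddim_step` are hypothetical and chosen for illustration.

```python
import jax
import jax.numpy as jnp


def cosine_alpha_bar(t, s=0.008):
    # Cumulative signal variance alpha_bar(t) of a cosine VP schedule, t in [0, 1).
    # alpha_bar -> 0 as t -> 1, so keep t strictly below 1 to avoid dividing by ~0.
    f = lambda u: jnp.cos((u + s) / (1.0 + s) * jnp.pi / 2) ** 2
    return f(t) / f(0.0)


def q_sample(x0, t, key):
    # Forward (noising) process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps
    eps = jax.random.normal(key, x0.shape)
    ab = cosine_alpha_bar(t)
    return jnp.sqrt(ab) * x0 + jnp.sqrt(1.0 - ab) * eps, eps


def ddim_step(x_t, eps_pred, t, t_prev):
    # Deterministic DDIM update (eta = 0) from time t to an earlier time t_prev:
    # estimate x0 from the predicted noise, then re-noise it to the level of t_prev.
    ab_t, ab_prev = cosine_alpha_bar(t), cosine_alpha_bar(t_prev)
    x0_pred = (x_t - jnp.sqrt(1.0 - ab_t) * eps_pred) / jnp.sqrt(ab_t)
    return jnp.sqrt(ab_prev) * x0_pred + jnp.sqrt(1.0 - ab_prev) * eps_pred


key = jax.random.PRNGKey(0)
x0 = jnp.linspace(-1.0, 1.0, 4)
x_t, eps = q_sample(x0, 0.5, key)
# With a perfect noise prediction, one DDIM step lands exactly on the noisy sample at t_prev
x_prev = ddim_step(x_t, eps, 0.5, 0.3)
```

In practice `eps_pred` would come from a trained denoiser (e.g. a UNet); passing the true noise here just makes the step easy to check.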
I wanted to build a framework in Flax that is general, i.e. one that makes it easy to extend and add new techniques, while still being clean, simple and, most importantly, easy to understand. But the library was first and foremost a product of my hobby experiments, so please forgive my naivety.
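One way to get that kind of extensibility is to have every noise schedule describe x_t = a(t) * x0 + b(t) * eps, so shared code only depends on the two rates. The sketch below illustrates the idea in JAX; the class names `NoiseSchedule`, `CosineVP` and `KarrasVE` are hypothetical and are not FlaxDiff's actual API. The VE schedule uses the sigma spacing and default values (sigma_min = 0.002, sigma_max = 80, rho = 7) from the EDM paper, reparameterized so that t = 0 is the least noisy end.

```python
from abc import ABC, abstractmethod

import jax.numpy as jnp


class NoiseSchedule(ABC):
    """Hypothetical base class: every schedule describes x_t = a(t) * x0 + b(t) * eps."""

    @abstractmethod
    def rates(self, t):
        """Return (signal_rate a(t), noise_rate b(t)) for t in [0, 1]."""

    def add_noise(self, x0, eps, t):
        # Shared by every subclass: a new schedule only has to implement `rates`
        signal_rate, noise_rate = self.rates(t)
        return signal_rate * x0 + noise_rate * eps


class CosineVP(NoiseSchedule):
    """Variance-preserving: a(t)^2 + b(t)^2 = 1."""

    def rates(self, t):
        return jnp.cos(t * jnp.pi / 2), jnp.sin(t * jnp.pi / 2)


class KarrasVE(NoiseSchedule):
    """Variance-exploding: the signal is never scaled, only noise with std sigma(t) is added."""

    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        self.sigma_min, self.sigma_max, self.rho = sigma_min, sigma_max, rho

    def rates(self, t):
        # Interpolate in sigma^(1/rho) space; t = 0 gives sigma_min, t = 1 gives sigma_max
        lo = self.sigma_min ** (1.0 / self.rho)
        hi = self.sigma_max ** (1.0 / self.rho)
        sigma = (lo + t * (hi - lo)) ** self.rho
        return jnp.ones_like(sigma), sigma


x0 = jnp.ones((3,))
eps = jnp.full((3,), 0.5)
for schedule in (CosineVP(), KarrasVE()):
    print(type(schedule).__name__, schedule.add_noise(x0, eps, 0.5))
```

Samplers and losses written against `rates` alone would then work unchanged for any new schedule added as a subclass.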
FlaxDiff Repository
I have also added a Jupyter notebook as a tutorial on diffusion, covering the various formulations, from the discrete Markov chain formulations in the DDPM/DDIM papers to the generalized SDE/ODE formulations, and how to implement them and train a simple UNet on them with Flax:
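To make the discrete-to-continuous connection concrete, here is a small numerical sketch in JAX (not taken from the notebook). It assumes the standard linear DDPM betas (1e-4 to 0.02 over N = 1000 steps) and the corresponding continuous VP SDE dx = -1/2 beta(t) x dt + sqrt(beta(t)) dw with beta(t) = N * beta_i, i.e. beta(t) linear from 0.1 to 20 on t in [0, 1]. For that SDE the cumulative signal fraction is alpha_bar(t) = exp(-integral_0^t beta(s) ds), and the discrete product of (1 - beta_i) should come out close to it.

```python
import jax.numpy as jnp

N = 1000
beta_min, beta_max = 0.1, 20.0  # continuous-time beta(t) range, t in [0, 1]

# Discrete DDPM chain: beta_i spaced linearly from 1e-4 to 0.02 over N steps.
betas = jnp.linspace(beta_min / N, beta_max / N, N)
# log of alpha_bar_N = prod_i (1 - beta_i), summed in log space for numerical stability
log_alpha_bar_discrete = jnp.sum(jnp.log1p(-betas))

# Continuous VP SDE with beta(t) linear in t:
# alpha_bar(1) = exp(-integral_0^1 beta(s) ds) = exp(-(beta_min + 0.5 * (beta_max - beta_min)))
log_alpha_bar_continuous = -(beta_min + 0.5 * (beta_max - beta_min))

print(float(log_alpha_bar_discrete), log_alpha_bar_continuous)
```

The two values agree to first order in beta_i; the small remaining gap comes from log(1 - beta) being slightly below -beta, and it shrinks as N grows.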
Simple Diffusion Flax.ipynb
I used to be a Keras and TensorFlow 2 user long ago, and on catching up with the field, I realized that it's basically ancient now; everything is mostly PyTorch or Flax. I loved the concept of JAX and Flax and their performance, which is why I chose to build it this way.
Disclaimers:
The term 'library' may be misleading: it's not quite an installable library yet, more a set of utils that work together. I may sort out the PyPI packaging on the weekend if I have enough time on hand.
This is really my first major contribution to open source; most other projects that I open-sourced were either from when I was a kid or just for the sake of putting stuff out there. I never really invested much effort in documenting them, or even in writing a proper README for that matter.