This repository contains a self-study project on learning JAX.
- I have several model implementations, with training, in JAX under `flax_mnist.ipynb`, in increasing order of complexity.
- Moreover, an implementation of k-means using gradient descent (on the average inertia) in JAX is available in `jax_kmeans_gd.ipynb`.
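
To give a rough idea of what k-means by gradient descent on the average inertia can look like, here is a minimal sketch. It is not the notebook's code: the function names (`average_inertia`, `gd_step`, `fit_kmeans_gd`), the learning rate, the step count, and the initialisation from randomly chosen data points are all assumptions made for illustration. The loss is the mean squared distance from each point to its nearest centroid, and the centroids are updated with plain gradient steps.

```python
import jax
import jax.numpy as jnp


def average_inertia(centroids, points):
    """Mean squared distance from each point to its nearest centroid."""
    # Pairwise squared Euclidean distances, shape (n_points, n_centroids).
    diffs = points[:, None, :] - centroids[None, :, :]
    sq_dists = jnp.sum(diffs ** 2, axis=-1)
    # Each point contributes only its distance to the closest centroid.
    return jnp.mean(jnp.min(sq_dists, axis=1))


@jax.jit
def gd_step(centroids, points, lr):
    """One gradient-descent step on the centroids; returns (new_centroids, loss)."""
    loss, grads = jax.value_and_grad(average_inertia)(centroids, points)
    return centroids - lr * grads, loss


def fit_kmeans_gd(points, k, steps=200, lr=0.5, seed=0):
    """Hypothetical helper: fit k centroids with plain gradient descent."""
    key = jax.random.PRNGKey(seed)
    # Assumed initialisation: pick k distinct data points as starting centroids.
    idx = jax.random.choice(key, points.shape[0], shape=(k,), replace=False)
    centroids = points[idx]
    loss = average_inertia(centroids, points)
    for _ in range(steps):
        centroids, loss = gd_step(centroids, points, lr)
    return centroids, loss
```

A possible call, on toy data, would be `centroids, loss = fit_kmeans_gd(points, k=2)` with `points` an array of shape `(n_points, n_features)`. Like standard k-means, this objective is non-convex, so the result depends on the initial centroids.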