A very basic implementation of the ViT (Vision Transformer) paper using the Flax neural network framework. The main goal of this project is to learn the device-agnostic framework, not to get the best results. All results are collected with `wandb.sweep` through a small custom logger wrapper (a minimal sketch is given below).
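The wrapper itself is not shown in this README, so here is only a minimal sketch of what such a logger class could look like, assuming a plain wrapper around `wandb.init`/`wandb.log`; the class name and interface are assumptions, not the repo's actual code:

```python
# Hypothetical sketch of the small wandb logger wrapper mentioned above;
# names and interface are assumptions, not the real implementation.
import wandb


class Logger:
    """Thin wrapper around wandb so the training loop stays framework-free."""

    def __init__(self, project: str, config: dict):
        self.run = wandb.init(project=project, config=config)

    def log(self, metrics: dict, step: int):
        # Forward scalar metrics (loss, accuracy, lr, ...) to the dashboard.
        wandb.log(metrics, step=step)

    def finish(self):
        wandb.finish()


# Usage inside a training loop (hypothetical names):
# logger = Logger(project="vit-flax", config={"lr": 3e-4})
# logger.log({"train/loss": 0.42}, step=100)
# logger.finish()
```

When running sweeps, the same wrapper can be driven by `wandb.agent`, with `wandb.init` picking up the sweep's hyperparameters.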
The model architecture is only suitable for classification tasks.
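A rough sketch of the overall shape a classification-only ViT takes in Flax: patchify the image, prepend a CLS token, and classify from the CLS output. Module names, the patch size, and the conv-based patchification are illustrative assumptions, not the repo's exact code:

```python
# Illustrative sketch of a classification-only ViT in Flax; names and
# hyperparameters here are assumptions, not this repo's exact code.
import jax.numpy as jnp
import flax.linen as nn


class ViTClassifier(nn.Module):
    num_classes: int
    hidden_dim: int = 768
    patch_size: int = 16  # assumed, as in the original ViT-Base

    @nn.compact
    def __call__(self, images):  # images: (batch, H, W, C)
        # Patchify via a strided convolution, then flatten to a sequence.
        x = nn.Conv(self.hidden_dim,
                    kernel_size=(self.patch_size, self.patch_size),
                    strides=(self.patch_size, self.patch_size))(images)
        x = x.reshape(x.shape[0], -1, self.hidden_dim)
        # Prepend a learnable CLS token whose final state feeds the head.
        cls = self.param('cls', nn.initializers.zeros, (1, 1, self.hidden_dim))
        x = jnp.concatenate([jnp.tile(cls, (x.shape[0], 1, 1)), x], axis=1)
        # ... transformer encoder blocks would go here ...
        return nn.Dense(self.num_classes)(x[:, 0])  # classify from CLS token
```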
Implementation details (illustrative code sketches for each point follow this list):

- Used the Adam optimizer with a cosine learning-rate schedule and gradient clipping;
- Used multi-head self-attention with `n = 8` heads and a hidden dimension of 768;
- Implemented both learnable and sinusoidal positional embeddings, but used the former (learnable).
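For the first point, a sketch of the optimizer setup with `optax`; the base learning rate, decay length, and clip norm below are assumed values, not this repo's settings:

```python
# Optimizer sketch: Adam + cosine learning-rate decay + gradient clipping.
# The base LR, step count, and clip norm are assumed values.
import optax

schedule = optax.cosine_decay_schedule(init_value=3e-4, decay_steps=100_000)
tx = optax.chain(
    optax.clip_by_global_norm(1.0),      # gradient clipping
    optax.adam(learning_rate=schedule),  # cosine-scheduled Adam
)
# Typical use: opt_state = tx.init(params)
#              updates, opt_state = tx.update(grads, opt_state, params)
#              params = optax.apply_updates(params, updates)
```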
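For the attention point, a pre-norm encoder block built on Flax's `nn.SelfAttention` with the stated 8 heads and 768 hidden dimensions; the 4x MLP expansion is an assumption (standard for ViT-Base), not confirmed by the repo:

```python
# Sketch of a pre-norm transformer encoder block using Flax's built-in
# self-attention, matching the stated n = 8 heads and 768 hidden dims.
import flax.linen as nn


class EncoderBlock(nn.Module):
    hidden_dim: int = 768
    num_heads: int = 8
    mlp_dim: int = 3072  # assumed 4x expansion, standard for ViT-Base

    @nn.compact
    def __call__(self, x):
        # Multi-head self-attention with a residual connection.
        y = nn.LayerNorm()(x)
        y = nn.SelfAttention(num_heads=self.num_heads,
                             qkv_features=self.hidden_dim)(y)
        x = x + y
        # Position-wise MLP with a residual connection.
        y = nn.LayerNorm()(x)
        y = nn.Dense(self.mlp_dim)(y)
        y = nn.gelu(y)
        y = nn.Dense(self.hidden_dim)(y)
        return x + y
```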
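And for the positional embeddings, a sketch showing both variants behind one module; the module name, initializer stddev, and the even-dimension assumption are illustrative:

```python
# Sketch of the two positional-embedding options; names and the
# initializer stddev are assumptions. Assumes an even embedding dim.
import numpy as np
import jax.numpy as jnp
import flax.linen as nn


def sinusoidal_embedding(seq_len: int, dim: int) -> jnp.ndarray:
    """Fixed sin/cos embeddings as in 'Attention Is All You Need'."""
    pos = np.arange(seq_len)[:, None]
    i = np.arange(dim // 2)[None, :]
    angles = pos / np.power(10000.0, 2 * i / dim)
    emb = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
    return jnp.asarray(emb[None, :, :])  # (1, seq_len, dim)


class AddPositionEmbedding(nn.Module):
    learnable: bool = True  # the README says the learnable variant is used

    @nn.compact
    def __call__(self, x):  # x: (batch, seq_len, dim)
        if self.learnable:
            pe = self.param('pos_embedding',
                            nn.initializers.normal(stddev=0.02),
                            (1, x.shape[1], x.shape[2]))
        else:
            pe = sinusoidal_embedding(x.shape[1], x.shape[2])
        return x + pe
```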
References:

- ViT-GPT2 in Flax: https://huggingface.co/flax-community/vit-gpt2/tree/main/vit_gpt2
- Flax ImageNet training example: https://github.com/google/flax/blob/main/examples/imagenet/train.py
- Official implementation
- Good set of JAX tutorials