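Sinusoidal representation networks (SIRENs) are multilayer perceptrons with sine activations. They can be used to parametrise scalar or vector fields f: ℝⁿ → ℝᵐ, for example an image that maps 2-D pixel coordinates to RGB values.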
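To make the idea concrete, here is a minimal sketch of a SIREN in plain JAX, without Equinox. It follows the initialisation scheme of Sitzmann et al. (2020, "Implicit Neural Representations with Periodic Activation Functions"): every hidden layer computes sin(ω(Wh + b)), and the final layer is linear. The function names `init_siren` and `siren_apply`, the default width and depth, and the bias initialisation are illustrative assumptions; they are not taken from `src.py`, whose `SIREN` class may differ in these details.

```python
import jax
import jax.numpy as jnp


def init_siren(key, num_in, num_out, num_hidden=64, num_sine_layers=3, omega=30.0):
    """Initialise SIREN parameters following Sitzmann et al. (2020).

    Returns a list of (W, b) pairs: `num_sine_layers` sine layers followed
    by one linear output layer.
    """
    sizes = [num_in] + [num_hidden] * num_sine_layers + [num_out]
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:])):
        # First layer: U(-1/n_in, 1/n_in). Later layers: U(-c, c) with
        # c = sqrt(6 / n_in) / omega, which the paper uses to keep the
        # activation distribution stable across depth.
        bound = 1.0 / n_in if i == 0 else (6.0 / n_in) ** 0.5 / omega
        k_w, k_b = jax.random.split(k)
        W = jax.random.uniform(k_w, (n_out, n_in), minval=-bound, maxval=bound)
        # Assumption: biases use the same range as the weights (one common choice).
        b = jax.random.uniform(k_b, (n_out,), minval=-bound, maxval=bound)
        params.append((W, b))
    return params


def siren_apply(params, x, omega=30.0):
    """Evaluate the field at a single point x of shape (num_in,)."""
    h = x
    for W, b in params[:-1]:
        h = jnp.sin(omega * (W @ h + b))  # sine activation with angular frequency omega
    W_out, b_out = params[-1]
    return W_out @ h + b_out  # linear output layer, no activation


params = init_siren(jax.random.PRNGKey(0), num_in=2, num_out=3)
print(siren_apply(params, jnp.array([0.1, -0.5])).shape)  # (3,)
```

Because `siren_apply` works on one point at a time, a batch of coordinates can be evaluated with `jax.vmap`.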
Make sure you have installed JAX and Equinox, then copy `src.py` into your own project. Instances of `SIREN` can be created as follows:
```python
import jax
from src import SIREN

siren = SIREN(
    num_channels_in=2,  # n (e.g. image grid)
    num_channels_out=3,  # m (e.g. RGB values)
    num_layers=4,
    num_latent_channels=1024,
    omega=30,  # angular frequency
    rng_key=jax.random.PRNGKey(420),
)
```
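With `num_channels_in=2`, the inputs are 2-D coordinates, so rendering an image means evaluating the field at every pixel of a coordinate grid. The sketch below shows one way to do that. It assumes coordinates normalised to [-1, 1] and a model that maps a single coordinate of shape `(2,)` to one output vector; both are assumptions, so check `src.py` and `main.py` for the actual call convention and normalisation. `make_coordinate_grid`, `render` and `toy_field` are hypothetical helpers, and `toy_field` only stands in for a trained model.

```python
import jax
import jax.numpy as jnp


def make_coordinate_grid(height, width):
    """Evenly spaced (row, col) coordinates in [-1, 1]^2, endpoints included.

    Returns an array of shape (height * width, 2) in row-major pixel order.
    """
    ys = jnp.linspace(-1.0, 1.0, height)
    xs = jnp.linspace(-1.0, 1.0, width)
    yy, xx = jnp.meshgrid(ys, xs, indexing="ij")  # each of shape (height, width)
    return jnp.stack([yy.ravel(), xx.ravel()], axis=-1)


def render(field, height, width):
    """Evaluate a per-point field R^2 -> R^m on every pixel and reshape to an image."""
    coords = make_coordinate_grid(height, width)
    values = jax.vmap(field)(coords)  # shape (height * width, m)
    return values.reshape(height, width, -1)


# Hypothetical stand-in for a trained model: any function mapping a
# coordinate of shape (2,) to an RGB value of shape (3,) works here.
def toy_field(p):
    return jnp.array([jnp.sin(jnp.pi * p[0]), jnp.cos(jnp.pi * p[1]), p[0] * p[1]])


image = render(toy_field, height=4, width=6)
print(image.shape)  # (4, 6, 3)
```

If the `SIREN` instance turns out to accept a single coordinate, it could be passed to `render` in place of `toy_field`; if it already handles batches, the `jax.vmap` call is unnecessary.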
For an example of how to train a SIREN with Equinox, see `main.py`. After installing Optax, scikit-image and tqdm, you can fit an image `img.png` via
```bash
python main.py --path_to_image img.png --num_epochs 1000
```
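Fitting an image amounts to regressing pixel values from pixel coordinates with a mean-squared-error loss. The following is a self-contained sketch of that loop in plain JAX, not a copy of `main.py`: it swaps Optax for a hand-written Adam update so it depends only on JAX (in a real project `optax.adam` would be the usual choice), and it skips Equinox in favour of a list of `(W, b)` pairs. The synthetic target image, layer sizes, learning rate and step count are illustrative assumptions, not the settings used by `main.py`.

```python
import jax
import jax.numpy as jnp

OMEGA, LR, B1, B2, EPS = 30.0, 1e-3, 0.9, 0.999, 1e-8  # assumed hyperparameters


def init_siren(key, sizes, omega):
    """SIREN initialisation after Sitzmann et al. (2020); returns (W, b) pairs."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:])):
        bound = 1.0 / n_in if i == 0 else (6.0 / n_in) ** 0.5 / omega
        k_w, k_b = jax.random.split(k)
        W = jax.random.uniform(k_w, (n_out, n_in), minval=-bound, maxval=bound)
        b = jax.random.uniform(k_b, (n_out,), minval=-bound, maxval=bound)
        params.append((W, b))
    return params


def siren_apply(params, x, omega):
    for W, b in params[:-1]:
        x = jnp.sin(omega * (W @ x + b))  # sine layers
    W, b = params[-1]
    return W @ x + b  # linear output layer


def loss_fn(params, coords, targets):
    preds = jax.vmap(lambda c: siren_apply(params, c, OMEGA))(coords)
    return jnp.mean((preds - targets) ** 2)  # mean squared error over pixels


@jax.jit
def train_step(params, m, v, t, coords, targets):
    """One Adam step; returns updated state and the loss before the update."""
    loss, grads = jax.value_and_grad(loss_fn)(params, coords, targets)
    t = t + 1
    m = jax.tree_util.tree_map(lambda m_, g: B1 * m_ + (1 - B1) * g, m, grads)
    v = jax.tree_util.tree_map(lambda v_, g: B2 * v_ + (1 - B2) * g * g, v, grads)
    params = jax.tree_util.tree_map(
        lambda p, m_, v_: p - LR * (m_ / (1 - B1**t)) / (jnp.sqrt(v_ / (1 - B2**t)) + EPS),
        params, m, v,
    )
    return params, m, v, t, loss


# Synthetic 16x16 RGB "image" on a [-1, 1]^2 grid (stand-in for img.png).
ys, xs = jnp.meshgrid(jnp.linspace(-1, 1, 16), jnp.linspace(-1, 1, 16), indexing="ij")
coords = jnp.stack([ys.ravel(), xs.ravel()], axis=-1)  # (256, 2)
targets = 0.5 + 0.5 * jnp.stack(
    [jnp.sin(3 * coords[:, 0]), jnp.cos(3 * coords[:, 1]), jnp.sin(3 * coords.sum(-1))],
    axis=-1,
)  # (256, 3), values in [0, 1]

params = init_siren(jax.random.PRNGKey(0), [2, 64, 64, 3], OMEGA)
m = jax.tree_util.tree_map(jnp.zeros_like, params)  # first-moment estimates
v = jax.tree_util.tree_map(jnp.zeros_like, params)  # second-moment estimates
t = jnp.array(0)  # Adam step counter

losses = []
for _ in range(300):
    params, m, v, t, loss = train_step(params, m, v, t, coords, targets)
    losses.append(float(loss))
print(f"initial loss {losses[0]:.4f}, final loss {losses[-1]:.4f}")
```

With Equinox, the same structure carries over: the parameters would live inside an `eqx.Module`, and `eqx.filter_value_and_grad` would take the place of `jax.value_and_grad`.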