This package implements SWAG (SWA-Gaussian) as an Optax transform for use with JAX.
Install from pip as:
pip install optax-swag
To install the latest version directly from source, run:
pip install git+https://github.com/activatedgeek/optax-swag.git
To start collecting the iterate statistics, chain the swag transform with the rest of the optimizer:
import optax
from optax_swag import swag
freq, rank = 1, 20  ## Placeholder values for illustration.
optimizer = optax.chain(
    optax.sgd(learning_rate=1e-2),  ## Example: other optimizer and transform config.
    swag(freq, rank),  ## Always add as the last transform.
)
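To make "iterate statistics" concrete, here is a minimal pure-JAX sketch of the bookkeeping SWAG performs: a running mean, a running second moment, and a buffer of the most recent deviations from the mean, collected every freq steps. It works on a flat parameter vector and is an illustrative stand-in, not this package's implementation; the helper names init_stats, collect and maybe_collect, and the reading of freq as a collection interval and rank as the number of stored deviations, are assumptions.

```python
import jax
import jax.numpy as jnp


def init_stats(num_params: int, rank: int):
    """Zero-initialized statistics for a flat parameter vector (hypothetical helper)."""
    return {
        "n": jnp.zeros((), jnp.int32),         # number of collected iterates
        "mean": jnp.zeros(num_params),         # running first moment
        "sq_mean": jnp.zeros(num_params),      # running second moment
        "dev": jnp.zeros((rank, num_params)),  # circular buffer of deviations
    }


def collect(stats, params):
    """Fold one parameter iterate into the running statistics."""
    n = stats["n"]
    mean = (n * stats["mean"] + params) / (n + 1)
    sq_mean = (n * stats["sq_mean"] + params**2) / (n + 1)
    rank = stats["dev"].shape[0]
    # Overwrite the oldest slot with the deviation from the updated mean.
    dev = stats["dev"].at[n % rank].set(params - mean)
    return {"n": n + 1, "mean": mean, "sq_mean": sq_mean, "dev": dev}


def maybe_collect(stats, params, step, freq):
    """Collect only on every freq-th step; otherwise return stats unchanged."""
    return jax.lax.cond(step % freq == 0, collect, lambda s, _: s, stats, params)


stats = init_stats(3, rank=2)
for step in range(4):
    stats = maybe_collect(stats, jnp.full(3, float(step)), step, freq=2)
```

With freq=2 only steps 0 and 2 are collected, so the running mean is 1.0 and the diagonal variance (sq_mean minus mean squared) is 1.0 in every coordinate.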
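Since optax.chain keeps its state as a tuple with one entry per chained transform, and swag is the last transform, the SWAGState object is the last element of the optimizer state and can be taken from there for downstream use.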
Reference code to generate samples from the collected statistics is provided below.
import jax
import jax.numpy as jnp
from optax_swag import sample_swag
swa_opt_state = opt_state[-1]  ## The SWAGState object from the optimizer state.
n_samples = 10
rng = jax.random.PRNGKey(42)
rng, *samples_rng = jax.random.split(rng, 1 + n_samples)
swag_sample_params = jax.vmap(sample_swag, in_axes=(0, None))(
    jnp.array(samples_rng), swa_opt_state)
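For intuition about what each draw is, the standard SWAG sample from Maddox et al. (2019) is theta = mean + (1/sqrt(2)) * sqrt(diag_var) * z1 + (1/sqrt(2(K-1))) * D z2, where diag_var = sq_mean - mean^2, D holds the K stored deviations, z1 ~ N(0, I_d) and z2 ~ N(0, I_K). The minimal sketch below implements that formula on a flat vector and vmaps it over keys in the same way as the reference code. Whether sample_swag uses exactly this scaling is an assumption; sample_swag_flat and its scale argument are hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_swag_flat(rng, mean, sq_mean, dev, scale=1.0):
    """Draw one SWAG sample for a flat parameter vector (illustrative sketch).

    mean, sq_mean: shape (d,), running first and second moments.
    dev: shape (K, d), rows are deviations theta_i - mean, with K >= 2.
    """
    k_diag, k_lowrank = jax.random.split(rng)
    rank = dev.shape[0]
    # Diagonal variance, clipped at zero to absorb floating-point round-off.
    var = jnp.maximum(sq_mean - mean**2, 0.0)
    z1 = jax.random.normal(k_diag, mean.shape)
    z2 = jax.random.normal(k_lowrank, (rank,))
    diag_term = jnp.sqrt(var) * z1
    lowrank_term = dev.T @ z2 / jnp.sqrt(rank - 1.0)
    # Both terms share the 1/sqrt(2) factor; scale multiplies the covariance.
    return mean + jnp.sqrt(scale / 2.0) * (diag_term + lowrank_term)


# Draw 10 samples at once by vmapping over the keys only.
mean = jnp.array([0.0, 1.0, 2.0])
sq_mean = jnp.array([0.5, 1.5, 4.5])
dev = jnp.array([[0.5, -0.5, 0.0], [-0.5, 0.5, 0.0]])
keys = jax.random.split(jax.random.PRNGKey(0), 10)
samples = jax.vmap(sample_swag_flat, in_axes=(0, None, None, None))(
    keys, mean, sq_mean, dev)
```

Mapping only the key axis (in_axes=(0, None, None, None)) shares the statistics across draws, so the output gains a leading axis of size 10.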
The resulting swag_sample_params, a parameter pytree whose leaves carry a leading axis of size n_samples, can now be used for downstream evaluation.
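One common downstream use is Bayesian model averaging: evaluate the model under every sampled parameter set and average the predictive probabilities. The sketch assumes the samples form a pytree with a leading sample axis, which is what jax.vmap produces, and uses a hypothetical linear classifier predict in place of a real model.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    """Hypothetical linear classifier returning class probabilities."""
    return jax.nn.softmax(x @ params["w"] + params["b"])


def bma_predict(sample_params, x):
    """Average predictive probabilities over the leading (sample) axis."""
    # Shape (n_samples, batch, classes): one prediction per parameter sample.
    probs = jax.vmap(predict, in_axes=(0, None))(sample_params, x)
    return probs.mean(axis=0)
```

Averaging probabilities rather than logits keeps the result a valid distribution over classes.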
NOTE: Make sure to update non-parameter variables (e.g. BatchNorm running statistics) for each generated sample.
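A minimal sketch of that BatchNorm update: for each sample, pass over training batches and recompute the statistics of the activations feeding the BatchNorm layer, using an equal-weight average over batches instead of an exponential moving average. The single dense layer h = x @ params["w"] and the helper recompute_bn_stats are hypothetical; a real model would collect these statistics from its own BatchNorm inputs.

```python
import jax
import jax.numpy as jnp


def recompute_bn_stats(params, batches):
    """Recompute mean/variance of pre-BatchNorm activations for one parameter sample.

    batches: shape (num_batches, batch_size, in_features), equal-sized batches.
    """
    def batch_moments(x):
        h = x @ params["w"]  # hypothetical layer feeding the BatchNorm
        return h.mean(axis=0), (h**2).mean(axis=0)

    means, sq_means = jax.vmap(batch_moments)(batches)
    mean = means.mean(axis=0)
    # Population variance over all examples: E[h^2] - E[h]^2.
    var = sq_means.mean(axis=0) - mean**2
    return mean, var
```

To get one set of statistics per SWAG sample, map over the sample axis, e.g. jax.vmap(recompute_bn_stats, in_axes=(0, None))(swag_sample_params, batches).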
License: Apache 2.0