Overhaul selection of trainable variables.
We now have a simple Trainable class which just wraps any value and comes from trainable, so the API is the same. This also overhauls the internal structure: we can now register non-trainable parameters which are callables too (for example, a non-trainable phase mask based on field shape).
GJBoth authored May 2, 2023
1 parent 5d691e0 commit a86c118
Showing 8 changed files with 219 additions and 209 deletions.
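The Trainable class and the trainable helper named in the commit message are not part of the hunks shown below. As a rough sketch only (the names and module layout here are assumptions, not the committed code), the wrapper could be as simple as:

# Hypothetical sketch of the Trainable wrapper described in the commit
# message; the actual definitions live outside the hunks shown on this page.
from dataclasses import dataclass
from typing import Any


@dataclass
class Trainable:
    """Marks a value (or an initializer callable) as trainable."""

    val: Any


def trainable(x: Any) -> Trainable:
    """Tag ``x`` so that ``register`` turns it into a Flax param."""
    return Trainable(x)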
120 changes: 85 additions & 35 deletions docs/training.ipynb
@@ -1,15 +1,13 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training Chromatix models"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -26,6 +24,7 @@
"import jax.numpy as jnp\n",
"from jax import random\n",
"import flax.linen as nn\n",
"from flax.core import freeze, unfreeze\n",
"import numpy as np\n",
"\n",
"import optax\n",
@@ -41,7 +40,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -52,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -76,46 +74,94 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We initialise the model, and simulate the data using some coefficients. We then define a loss function, which should return a (loss, metrics) pair:"
"When initialising the model, we get a dictionary consisting of both trainable parameters and a so-called state. The state contains all things we want calculated once and want to cache. Here it's just some of the other parameters, but it can also be a more complicated phasemask or a propagator."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 23,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FrozenDict({\n",
" state: {\n",
" ObjectivePointSource_0: {\n",
" _f: 100,\n",
" _n: 1.33,\n",
" _NA: 0.8,\n",
" _power: 10000000.0,\n",
" _amplitude: 1.0,\n",
" },\n",
" FFLens_0: {\n",
" _f: 100,\n",
" _n: 1.33,\n",
" _NA: None,\n",
" },\n",
" },\n",
" params: {\n",
" ZernikeAberrations_0: {\n",
" _coefficients: Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" },\n",
" },\n",
"})\n"
]
}
],
"source": [
"# Instantiating model\n",
"model = ZernikePSF()\n",
"variables = model.init(key)\n",
"print(variables)\n",
"\n",
"# Split into two\n",
"params, state = variables[\"params\"], variables[\"state\"]\n",
"del variables # delete for memory"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We make some synthetic data data using some coefficients. Note that the loss function has two inputs\n",
"\n",
" We then define a loss function, which should return a (loss, metrics) pair:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# Specify \"ground truth\" parameters for Zernike coefficients\n",
"coefficients_truth = jnp.array([2.0, 5.0, 3.0, 0, 1, 0, 1, 0, 1, 0])\n",
"params_true = jax.tree_map(lambda x: coefficients_truth, model.init(key)) # easiest to just do a tree_map\n",
"params_true = unfreeze(params)\n",
"params_true[\"ZernikeAberrations_0\"][\"_coefficients\"] = coefficients_truth\n",
"params_true = freeze(params_true)\n",
"\n",
"# Generating data\n",
"data = model.apply(params_true).intensity.squeeze()\n",
"data = model.apply({\"params\": params_true, \"state\": state}).intensity.squeeze()\n",
"\n",
"# Our loss function\n",
"def loss_fn(params, data):\n",
" psf_estimate = model.apply(params).intensity.squeeze()\n",
"def loss_fn(params, state, data):\n",
" psf_estimate = model.apply({\"params\": params, \"state\": state}).intensity.squeeze()\n",
" loss = jnp.mean((psf_estimate - data)**2) / jnp.mean(data**2)\n",
" return loss, {\"loss\": loss}"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training with Optax"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -124,13 +170,13 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# Setting the state which has the model, params and optimiser\n",
"state = TrainState.create(apply_fn=model.apply, \n",
" params=model.init(key), \n",
"trainstate = TrainState.create(apply_fn=model.apply, \n",
" params=params, \n",
" tx=optax.adam(learning_rate=0.5))\n",
"\n",
"# Defining the function which returns the gradients\n",
@@ -139,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 26,
"metadata": {},
"outputs": [
{
@@ -151,8 +197,8 @@
"200 {'loss': Array(0.02074304, dtype=float32)}\n",
"300 {'loss': Array(6.8586576e-07, dtype=float32)}\n",
"400 {'loss': Array(2.7676396e-11, dtype=float32)}\n",
"CPU times: user 4.22 s, sys: 292 ms, total: 4.51 s\n",
"Wall time: 3.02 s\n"
"CPU times: user 3.98 s, sys: 231 ms, total: 4.21 s\n",
"Wall time: 2.73 s\n"
]
}
],
@@ -161,16 +207,16 @@
"# Simple training loop\n",
"max_iterations = 500\n",
"for iteration in range(max_iterations):\n",
" grads, metrics = grad_fn(state.params, data) \n",
" state = state.apply_gradients(grads=grads)\n",
" grads, metrics = grad_fn(trainstate.params, state, data) \n",
" trainstate = trainstate.apply_gradients(grads=grads)\n",
"\n",
" if iteration % 100 == 0:\n",
" print(iteration, metrics)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 28,
"metadata": {},
"outputs": [
{
@@ -183,20 +229,18 @@
}
],
"source": [
"print(f\"Learned coefficients: {jnp.abs(jnp.around(state.params['params']['ZernikeAberrations_0']['zernike_coefficients'], 2))}\")\n",
"print(f\"Learned coefficients: {jnp.abs(jnp.around(trainstate.params['ZernikeAberrations_0']['_coefficients'], 2))}\")\n",
"print(f\"True Coefficients: {coefficients_truth}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training with Jaxopt"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -205,20 +249,20 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# Defining solver\n",
"solver = jaxopt.LBFGS(loss_fn, has_aux=True)\n",
"\n",
"# Running solver\n",
"res = solver.run(model.init(key), data)"
"res = solver.run(model.init(key)[\"params\"], state, data)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 31,
"metadata": {},
"outputs": [
{
@@ -231,14 +275,21 @@
}
],
"source": [
"print(f\"Learned coefficients: {jnp.abs(jnp.around(res.params['params']['ZernikeAberrations_0']['zernike_coefficients'], 2))}\")\n",
"print(f\"Learned coefficients: {jnp.abs(jnp.around(res.params['ZernikeAberrations_0']['_coefficients'], 2))}\")\n",
"print(f\"True Coefficients: {coefficients_truth}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -253,9 +304,8 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
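The hunk above elides the definition of grad_fn. Given the new (params, state, data) signature of loss_fn, it is presumably along these lines (a sketch consistent with the training loop, not the committed cell):

import jax

# Differentiate w.r.t. params only (argument 0); state and data are treated
# as constants. has_aux=True because loss_fn returns a (loss, metrics) pair.
grad_fn = jax.jit(jax.grad(loss_fn, argnums=0, has_aux=True))

# Matches the loop above: grads, metrics = grad_fn(trainstate.params, state, data)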
8 changes: 2 additions & 6 deletions src/chromatix/elements/amplitude_masks.py
@@ -6,6 +6,7 @@
from ..functional.amplitude_masks import amplitude_change
from ..utils import _broadcast_2d_to_spatial
from ..ops import binarize
from chromatix.elements.utils import register

__all__ = ["AmplitudeMask"]

@@ -31,12 +32,7 @@ class AmplitudeMask(nn.Module):
@nn.compact
def __call__(self, field: Field) -> Field:
"""Applies ``amplitude`` mask to incoming ``Field``."""
amplitude = (
self.param("amplitude_pixels", self.amplitude, field.spatial_shape)
if callable(self.amplitude)
else self.amplitude
)

amplitude = register(self, "amplitude", field.spatial_shape)
assert_rank(
amplitude, 2, custom_message="Amplitude must be array of shape (H W)"
)
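Both this file and lenses.py below now delegate to a single register helper from chromatix.elements.utils, whose definition is outside the hunks shown here. A minimal sketch consistent with the call sites and with the params/state split visible in the notebook output (the Trainable import path is an assumption):

from typing import Any

import flax.linen as nn

from chromatix.utils import Trainable  # hypothetical import path


def register(module: nn.Module, name: str, *init_args: Any) -> Any:
    """Sketch: expose attribute ``name`` as a trainable param if it was
    wrapped in ``Trainable``, otherwise as a cached "state" variable."""
    attribute = getattr(module, name)
    if isinstance(attribute, Trainable):
        # Trainable values become Flax params; plain values are wrapped in a
        # constant initializer (Flax passes a PRNGKey as the first argument).
        init = attribute.val if callable(attribute.val) else (lambda key, *_: attribute.val)
        return module.param(f"_{name}", init, *init_args)
    # Everything else, including callables such as a field-shape-dependent
    # phase mask, is evaluated once and cached in the "state" collection.
    init = attribute if callable(attribute) else (lambda *_: attribute)
    return module.variable("state", f"_{name}", init, *init_args).value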
44 changes: 17 additions & 27 deletions src/chromatix/elements/lenses.py
@@ -3,6 +3,7 @@
from chex import PRNGKey
from ..field import Field
from .. import functional as cf
from chromatix.elements.utils import register

__all__ = ["ThinLens", "FFLens", "DFLens"]

@@ -29,15 +30,12 @@ class ThinLens(nn.Module):
n: Union[float, Callable[[PRNGKey], float]]
NA: Optional[Union[float, Callable[[PRNGKey], float]]] = None

def setup(self):
self._f = self.param("_f", self.f) if isinstance(self.f, Callable) else self.f
self._n = self.param("_n", self.n) if isinstance(self.n, Callable) else self.n
self._NA = (
self.param("_NA", self.NA) if isinstance(self.NA, Callable) else self.NA
)

@nn.compact
def __call__(self, field: Field) -> Field:
return cf.thin_lens(field, self._f, self._n, self._NA)
f = register(self, "f")
n = register(self, "n")
NA = register(self, "NA")
return cf.thin_lens(field, f, n, NA)


class FFLens(nn.Module):
@@ -64,15 +62,12 @@
NA: Optional[Union[float, Callable[[PRNGKey], float]]] = None
inverse: bool = False

def setup(self):
self._f = self.param("_f", self.f) if isinstance(self.f, Callable) else self.f
self._n = self.param("_n", self.n) if isinstance(self.n, Callable) else self.n
self._NA = (
self.param("_NA", self.NA) if isinstance(self.NA, Callable) else self.NA
)

@nn.compact
def __call__(self, field: Field) -> Field:
return cf.ff_lens(field, self._f, self._n, self._NA, inverse=self.inverse)
f = register(self, "f")
n = register(self, "n")
NA = register(self, "NA")
return cf.ff_lens(field, f, n, NA, inverse=self.inverse)


class DFLens(nn.Module):
@@ -101,15 +96,10 @@
NA: Optional[Union[float, Callable[[PRNGKey], float]]] = None
inverse: bool = False

def setup(self):
self._d = self.param("_d", self.d) if isinstance(self.d, Callable) else self.d
self._f = self.param("_f", self.f) if isinstance(self.f, Callable) else self.f
self._n = self.param("_n", self.n) if isinstance(self.n, Callable) else self.n
self._NA = (
self.param("_NA", self.NA) if isinstance(self.NA, Callable) else self.NA
)

@nn.compact
def __call__(self, field: Field) -> Field:
return cf.df_lens(
field, self._d, self._f, self._n, self._NA, inverse=self.inverse
)
d = register(self, "d")
f = register(self, "f")
n = register(self, "n")
NA = register(self, "NA")
return cf.df_lens(field, d, f, n, NA, inverse=self.inverse)
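With this scheme, whether a constructor argument ends up in "params" or in the cached "state" collection is decided by the caller at construction time. A hedged usage sketch (again assuming the trainable helper from the commit message, and using some_field as a placeholder for a chromatix Field):

import jax

from chromatix.elements import FFLens
from chromatix.utils import trainable  # hypothetical import path

# f is trainable (ends up in "params"); n and NA are fixed (cached in "state").
lens = FFLens(f=trainable(100.0), n=1.33, NA=0.8)

# some_field is a placeholder Field used to trace shapes at initialisation.
variables = lens.init(jax.random.PRNGKey(0), some_field)
params, state = variables["params"], variables["state"]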
