From a86c118dec9183de995ab93458357c2052a38757 Mon Sep 17 00:00:00 2001 From: Gert-Jan Both <32122273+GJBoth@users.noreply.github.com> Date: Tue, 2 May 2023 14:23:04 +0200 Subject: [PATCH] Overhaul selection of trainable variables. We now have a simple Trainable class which just wraps any value, and is returned by trainable() so the API is the same. Also overhauls the internal structure - we can now register non-trainable parameters which are callables too (for example, a non-trainable phasemask based on field shape). --- docs/training.ipynb | 120 +++++++++++++++------- src/chromatix/elements/amplitude_masks.py | 8 +- src/chromatix/elements/lenses.py | 44 +++----- src/chromatix/elements/phase_masks.py | 46 +++------ src/chromatix/elements/propagation.py | 34 +++--- src/chromatix/elements/sources.py | 119 +++++++-------------- src/chromatix/elements/utils.py | 39 +++++++ src/chromatix/utils/utils.py | 18 ++-- 8 files changed, 219 insertions(+), 209 deletions(-) create mode 100644 src/chromatix/elements/utils.py diff --git a/docs/training.ipynb b/docs/training.ipynb index 811164c..2523a2f 100644 --- a/docs/training.ipynb +++ b/docs/training.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -9,7 +8,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -26,6 +24,7 @@ "import jax.numpy as jnp\n", "from jax import random\n", "import flax.linen as nn\n", + "from flax.core import freeze, unfreeze\n", "import numpy as np\n", "\n", "import optax\n", @@ -41,7 +40,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -52,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -76,38 +74,87 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "We initialise the model, and simulate the data using some coefficients. 
We then define a loss function, which should return a (loss, metrics) pair:" + "When initialising the model, we get a dictionary consisting of both trainable parameters and a so-called state. The state contains all things we want calculated once and want to cache. Here it's just some of the other parameters, but it can also be a more complicated phasemask or a propagator." ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 23, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FrozenDict({\n", + " state: {\n", + " ObjectivePointSource_0: {\n", + " _f: 100,\n", + " _n: 1.33,\n", + " _NA: 0.8,\n", + " _power: 10000000.0,\n", + " _amplitude: 1.0,\n", + " },\n", + " FFLens_0: {\n", + " _f: 100,\n", + " _n: 1.33,\n", + " _NA: None,\n", + " },\n", + " },\n", + " params: {\n", + " ZernikeAberrations_0: {\n", + " _coefficients: Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", + " },\n", + " },\n", + "})\n" + ] + } + ], "source": [ - "# Instantiating model\n", "model = ZernikePSF()\n", + "variables = model.init(key)\n", + "print(variables)\n", + "\n", + "# Split into two\n", + "params, state = variables[\"params\"], variables[\"state\"]\n", + "del variables # delete for memory" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We make some synthetic data using some coefficients. 
Note that the loss function now takes the parameters and the state as two separate inputs, alongside the data.\n", "\n", + " We then define a loss function, which should return a (loss, metrics) pair:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ "# Specify \"ground truth\" parameters for Zernike coefficients\n", "coefficients_truth = jnp.array([2.0, 5.0, 3.0, 0, 1, 0, 1, 0, 1, 0])\n", - "params_true = jax.tree_map(lambda x: coefficients_truth, model.init(key)) # easiest to just do a tree_map\n", + "params_true = unfreeze(params)\n", + "params_true[\"ZernikeAberrations_0\"][\"_coefficients\"] = coefficients_truth\n", + "params_true = freeze(params_true)\n", "\n", "# Generating data\n", - "data = model.apply(params_true).intensity.squeeze()\n", + "data = model.apply({\"params\": params_true, \"state\": state}).intensity.squeeze()\n", "\n", "# Our loss function\n", - "def loss_fn(params, data):\n", - " psf_estimate = model.apply(params).intensity.squeeze()\n", + "def loss_fn(params, state, data):\n", + " psf_estimate = model.apply({\"params\": params, \"state\": state}).intensity.squeeze()\n", " loss = jnp.mean((psf_estimate - data)**2) / jnp.mean(data**2)\n", " return loss, {\"loss\": loss}" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -115,7 +162,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -124,13 +170,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# Setting the state which has the model, params and optimiser\n", - "state = TrainState.create(apply_fn=model.apply, \n", - " params=model.init(key), \n", + "trainstate = TrainState.create(apply_fn=model.apply, \n", + " params=params, \n", " tx=optax.adam(learning_rate=0.5))\n", "\n", "# Defining the function which returns the gradients\n", @@ -139,7 +185,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 26, "metadata": {}, "outputs": [ { 
@@ -151,8 +197,8 @@ "200 {'loss': Array(0.02074304, dtype=float32)}\n", "300 {'loss': Array(6.8586576e-07, dtype=float32)}\n", "400 {'loss': Array(2.7676396e-11, dtype=float32)}\n", - "CPU times: user 4.22 s, sys: 292 ms, total: 4.51 s\n", - "Wall time: 3.02 s\n" + "CPU times: user 3.98 s, sys: 231 ms, total: 4.21 s\n", + "Wall time: 2.73 s\n" ] } ], @@ -161,8 +207,8 @@ "# Simple training loop\n", "max_iterations = 500\n", "for iteration in range(max_iterations):\n", - " grads, metrics = grad_fn(state.params, data) \n", - " state = state.apply_gradients(grads=grads)\n", + " grads, metrics = grad_fn(trainstate.params, state, data) \n", + " trainstate = trainstate.apply_gradients(grads=grads)\n", "\n", " if iteration % 100 == 0:\n", " print(iteration, metrics)" @@ -170,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -183,12 +229,11 @@ } ], "source": [ - "print(f\"Learned coefficients: {jnp.abs(jnp.around(state.params['params']['ZernikeAberrations_0']['zernike_coefficients'], 2))}\")\n", + "print(f\"Learned coefficients: {jnp.abs(jnp.around(trainstate.params['ZernikeAberrations_0']['_coefficients'], 2))}\")\n", "print(f\"True Coefficients: {coefficients_truth}\")" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -196,7 +241,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -205,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -213,12 +257,12 @@ "solver = jaxopt.LBFGS(loss_fn, has_aux=True)\n", "\n", "# Running solver\n", - "res = solver.run(model.init(key), data)" + "res = solver.run(model.init(key)[\"params\"], state, data)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -231,14 +275,21 @@ } ], "source": [ - "print(f\"Learned coefficients: 
{jnp.abs(jnp.around(res.params['params']['ZernikeAberrations_0']['zernike_coefficients'], 2))}\")\n", + "print(f\"Learned coefficients: {jnp.abs(jnp.around(res.params['ZernikeAberrations_0']['_coefficients'], 2))}\")\n", "print(f\"True Coefficients: {coefficients_truth}\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -253,9 +304,8 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/src/chromatix/elements/amplitude_masks.py b/src/chromatix/elements/amplitude_masks.py index 348eb42..908128c 100644 --- a/src/chromatix/elements/amplitude_masks.py +++ b/src/chromatix/elements/amplitude_masks.py @@ -6,6 +6,7 @@ from ..functional.amplitude_masks import amplitude_change from ..utils import _broadcast_2d_to_spatial from ..ops import binarize +from chromatix.elements.utils import register __all__ = ["AmplitudeMask"] @@ -31,12 +32,7 @@ class AmplitudeMask(nn.Module): @nn.compact def __call__(self, field: Field) -> Field: """Applies ``amplitude`` mask to incoming ``Field``.""" - amplitude = ( - self.param("amplitude_pixels", self.amplitude, field.spatial_shape) - if callable(self.amplitude) - else self.amplitude - ) - + amplitude = register(self, "amplitude", field.spatial_shape) assert_rank( amplitude, 2, custom_message="Amplitude must be array of shape (H W)" ) diff --git a/src/chromatix/elements/lenses.py b/src/chromatix/elements/lenses.py index 20cb1fb..55a2013 100644 --- a/src/chromatix/elements/lenses.py +++ b/src/chromatix/elements/lenses.py @@ -3,6 +3,7 @@ from chex import PRNGKey from ..field import Field from .. 
import functional as cf +from chromatix.elements.utils import register __all__ = ["ThinLens", "FFLens", "DFLens"] @@ -29,15 +30,12 @@ class ThinLens(nn.Module): n: Union[float, Callable[[PRNGKey], float]] NA: Optional[Union[float, Callable[[PRNGKey], float]]] = None - def setup(self): - self._f = self.param("_f", self.f) if isinstance(self.f, Callable) else self.f - self._n = self.param("_n", self.n) if isinstance(self.n, Callable) else self.n - self._NA = ( - self.param("_NA", self.NA) if isinstance(self.NA, Callable) else self.NA - ) - + @nn.compact def __call__(self, field: Field) -> Field: - return cf.thin_lens(field, self._f, self._n, self._NA) + f = register(self, "f") + n = register(self, "n") + NA = register(self, "NA") + return cf.thin_lens(field, f, n, NA) class FFLens(nn.Module): @@ -64,15 +62,12 @@ class FFLens(nn.Module): NA: Optional[Union[float, Callable[[PRNGKey], float]]] = None inverse: bool = False - def setup(self): - self._f = self.param("_f", self.f) if isinstance(self.f, Callable) else self.f - self._n = self.param("_n", self.n) if isinstance(self.n, Callable) else self.n - self._NA = ( - self.param("_NA", self.NA) if isinstance(self.NA, Callable) else self.NA - ) - + @nn.compact def __call__(self, field: Field) -> Field: - return cf.ff_lens(field, self._f, self._n, self._NA, inverse=self.inverse) + f = register(self, "f") + n = register(self, "n") + NA = register(self, "NA") + return cf.ff_lens(field, f, n, NA, inverse=self.inverse) class DFLens(nn.Module): @@ -101,15 +96,10 @@ class DFLens(nn.Module): NA: Optional[Union[float, Callable[[PRNGKey], float]]] = None inverse: bool = False - def setup(self): - self._d = self.param("_d", self.d) if isinstance(self.d, Callable) else self.d - self._f = self.param("_f", self.f) if isinstance(self.f, Callable) else self.f - self._n = self.param("_n", self.n) if isinstance(self.n, Callable) else self.n - self._NA = ( - self.param("_NA", self.NA) if isinstance(self.NA, Callable) else self.NA - ) - + 
@nn.compact def __call__(self, field: Field) -> Field: - return cf.df_lens( - field, self._d, self._f, self._n, self._NA, inverse=self.inverse - ) + d = register(self, "d") + f = register(self, "f") + n = register(self, "n") + NA = register(self, "NA") + return cf.df_lens(field, d, f, n, NA, inverse=self.inverse) diff --git a/src/chromatix/elements/phase_masks.py b/src/chromatix/elements/phase_masks.py index a4ffe64..90b5df5 100644 --- a/src/chromatix/elements/phase_masks.py +++ b/src/chromatix/elements/phase_masks.py @@ -12,6 +12,7 @@ zernike_aberrations, ) from ..utils import _broadcast_2d_to_spatial +from chromatix.elements.utils import register __all__ = [ "PhaseMask", @@ -63,18 +64,8 @@ def __call__(self, field: Field) -> Field: pupil_args = (self.n, self.f, self.NA) else: pupil_args = () - phase = ( - self.param( - "phase_pixels", - self.phase, - field.spatial_shape, - field.dx[..., 0, 0].squeeze(), - field.spectrum[..., 0, 0].squeeze(), - *pupil_args, - ) - if callable(self.phase) - else self.phase - ) + + phase = register(self, "phase", field, *pupil_args) assert_rank(phase, 2, custom_message="Phase must be array of shape (H W)") phase = _broadcast_2d_to_spatial(phase, field.ndim) phase = spectrally_modulate_phase( @@ -142,17 +133,14 @@ def __call__(self, field: Field) -> Field: pupil_args = (self.n, self.f, self.NA) else: pupil_args = () - phase = ( - self.param( - "slm_pixels", - self.phase, - self.shape, - self.spacing, - field.spectrum[..., 0, 0].squeeze(), - *pupil_args, - ) - if callable(self.phase) - else self.phase + + phase = register( + self, + "phase", + self.shape, + self.spacing, + field.spectrum[..., 0, 0].squeeze(), + *pupil_args, ) assert_rank(phase, 2, custom_message="Phase must be array of shape (H W)") assert ( @@ -206,11 +194,7 @@ class SeidelAberrations(nn.Module): @nn.compact def __call__(self, field: Field) -> Field: """Applies ``phase`` mask to incoming ``Field``.""" - coefficients = ( - self.param("seidel_coefficients", 
self.coefficients) - if callable(self.coefficients) - else self.coefficients - ) + coefficients = register(self, "coefficients") phase = seidel_aberrations( field.spatial_shape, field.dx[..., 0, 0].squeeze(), @@ -260,11 +244,7 @@ class ZernikeAberrations(nn.Module): @nn.compact def __call__(self, field: Field) -> Field: """Applies ``phase`` mask to incoming ``Field``.""" - coefficients = ( - self.param("zernike_coefficients", self.coefficients) - if callable(self.coefficients) - else self.coefficients - ) + coefficients = register(self, "coefficients") phase = zernike_aberrations( field.spatial_shape, diff --git a/src/chromatix/elements/propagation.py b/src/chromatix/elements/propagation.py index 550f8dc..8281e12 100644 --- a/src/chromatix/elements/propagation.py +++ b/src/chromatix/elements/propagation.py @@ -14,6 +14,8 @@ compute_asm_propagator, ) from ..ops.field import pad, crop +from chromatix.elements.utils import register +from chromatix.utils import Trainable __all__ = ["Propagate"] @@ -84,15 +86,15 @@ class Propagate(nn.Module): @nn.compact def __call__(self, field: Field) -> Field: if self.cache_propagator and ( - isinstance(self.z, Callable) or isinstance(self.n, Callable) + isinstance(self.z, Trainable) or isinstance(self.n, Trainable) ): raise ValueError("Cannot cache propagation kernel if z or n are trainable.") if self.cache_propagator and self.method not in ["transfer", "exact", "asm"]: raise ValueError( "Can only cache kernel for 'transfer', 'exact', or 'asm' methods." 
) - z = self.param("_z", self.z) if isinstance(self.z, Callable) else self.z - n = self.param("_n", self.n) if isinstance(self.n, Callable) else self.n + z = register(self, "z") + n = register(self, "n") if self.cache_propagator: field = pad(field, self.N_pad, cval=self.cval) propagator_args = ( @@ -103,19 +105,19 @@ def __call__(self, field: Field) -> Field: ) if self.method == "transfer": propagator = self.variable( - "propagation", + "state", "kernel", lambda: compute_transfer_propagator(*propagator_args), ) elif self.method == "exact": propagator = self.variable( - "propagation", + "state", "kernel", lambda: compute_exact_propagator(*propagator_args), ) elif self.method == "asm": propagator = self.variable( - "propagation", + "state", "kernel", lambda: compute_asm_propagator(*propagator_args), ) @@ -209,17 +211,15 @@ class KernelPropagate(nn.Module): @nn.compact def __call__(self, field: Field) -> Field: field = pad(field, self.N_pad, cval=self.cval) - if isinstance(self.propagator, Callable): - propagator = self.param( - "_propagator", - self.propagator, - field, - self.z, - self.n, - self.kykx, - ) - else: - propagator = self.propagator + propagator = register( + self, + "propagator", + field, + self.z, + self.n, + self.kykx, + ) + field = kernel_propagate(field, propagator) if self.mode == "same": field = crop(field, self.N_pad) diff --git a/src/chromatix/elements/sources.py b/src/chromatix/elements/sources.py index eb6d21c..40fd0f5 100644 --- a/src/chromatix/elements/sources.py +++ b/src/chromatix/elements/sources.py @@ -1,6 +1,6 @@ import jax.numpy as jnp import flax.linen as nn -from ..field import Field, ScalarField, VectorField +from ..field import Field from ..functional.sources import ( plane_wave, point_source, @@ -9,6 +9,7 @@ ) from typing import Optional, Callable, Tuple, Union from chex import PRNGKey, Array +from chromatix.elements.utils import register __all__ = ["PointSource", "ObjectivePointSource", "PlaneWave", "GenericField"] @@ -51,30 
+52,21 @@ class PointSource(nn.Module): pupil: Optional[Callable[[Field], Field]] = None scalar: bool = True - def setup(self): - self._z = self.param("_z", self.z) if isinstance(self.z, Callable) else self.z - self._n = self.param("_n", self.n) if isinstance(self.n, Callable) else self.n - self._power = ( - self.param("_power", self.power) - if isinstance(self.power, Callable) - else self.power - ) - self._amplitude = ( - self.param("_amplitude", self.amplitude) - if isinstance(self.amplitude, Callable) - else self.amplitude - ) - + @nn.compact def __call__(self) -> Field: + power = register(self, "power") + z = register(self, "z") + n = register(self, "n") + amplitude = register(self, "amplitude") return point_source( self.shape, self.dx, self.spectrum, self.spectral_density, - self._z, - self._n, - self._power, - self._amplitude, + z, + n, + power, + amplitude, self.pupil, self.scalar, ) @@ -118,35 +110,25 @@ class ObjectivePointSource(nn.Module): amplitude: Union[float, Array, Callable[[PRNGKey], Array]] = 1.0 scalar: bool = True - def setup(self): - self._f = self.param("_f", self.f) if isinstance(self.f, Callable) else self.f - self._n = self.param("_n", self.n) if isinstance(self.n, Callable) else self.n - self._NA = ( - self.param("_NA", self.NA) if isinstance(self.NA, Callable) else self.NA - ) - self._power = ( - self.param("_power", self.power) - if isinstance(self.power, Callable) - else self.power - ) - self._amplitude = ( - self.param("_amplitude", self.amplitude) - if isinstance(self.amplitude, Callable) - else self.amplitude - ) - + @nn.compact def __call__(self, z: float) -> Field: + f = register(self, "f") + n = register(self, "n") + NA = register(self, "NA") + power = register(self, "power") + amplitude = register(self, "amplitude") + return objective_point_source( self.shape, self.dx, self.spectrum, self.spectral_density, z, - self._f, - self._n, - self._NA, - self._power, - self._amplitude, + f, + n, + NA, + power, + amplitude, self.scalar, ) 
@@ -189,32 +171,19 @@ class PlaneWave(nn.Module): pupil: Optional[Callable[[Field], Field]] = None scalar: bool = True - def setup(self): - self._kykx = ( - self.param("_kykx", self.kykx) - if isinstance(self.kykx, Callable) - else self.kykx - ) - self._power = ( - self.param("_power", self.power) - if isinstance(self.power, Callable) - else self.power - ) - self._amplitude = ( - self.param("_amplitude", self.amplitude) - if isinstance(self.amplitude, Callable) - else self.amplitude - ) - + @nn.compact def __call__(self) -> Field: + kykx = register(self, "kykx") + power = register(self, "power") + amplitude = register(self, "amplitude") return plane_wave( self.shape, self.dx, self.spectrum, self.spectral_density, - self._power, - self._amplitude, - self._kykx, + power, + amplitude, + kykx, self.pupil, self.scalar, ) @@ -252,31 +221,19 @@ class GenericField(nn.Module): pupil: Optional[Callable[[Field], Field]] = None scalar: bool = True - def setup(self): - self._amplitude = ( - self.param("_amplitude", self.amplitude) - if isinstance(self.amplitude, Callable) - else self.amplitude - ) - self._phase = ( - self.param("_phase", self.phase) - if isinstance(self.phase, Callable) - else self.phase - ) - self._power = ( - self.param("_power", self.power) - if isinstance(self.power, Callable) - else self.power - ) - + @nn.compact def __call__(self) -> Field: + amplitude = register(self, "amplitude") + phase = register(self, "phase") + power = register(self, "power") + return generic_field( self.dx, self.spectrum, self.spectral_density, - self._amplitude, - self._phase, - self._power, + amplitude, + phase, + power, self.pupil, self.scalar, ) diff --git a/src/chromatix/elements/utils.py b/src/chromatix/elements/utils.py new file mode 100644 index 0000000..cc1ce76 --- /dev/null +++ b/src/chromatix/elements/utils.py @@ -0,0 +1,39 @@ +from flax import linen as nn +from chromatix.utils import Trainable + + +def register( + module: nn.Module, + name: str, + *args, +): + 
"""Registers the parameter `self.{name}` as a Flax parameter or variable depending + on whether the parameter is of type `Trainable`. Only used for internal ease-of-use. + + Name in Flax's parameterdict becomes `_{name}`, and if variable under collection + `state`. Supports initializing both with callables (*args are passed as + arguments) and fixed values. + + """ + try: + init = getattr(module, name) + except AttributeError: + print("Variable does not exist.") + + if isinstance(init, Trainable): + return module.param(f"_{name}", parse_init(init.val), *args) + else: + return module.variable( + "state", + f"_{name}", + parse_init(init), + None, + *args, + ).value + + +def parse_init(x): + def init(*args): + return x + + return x if callable(x) else init diff --git a/src/chromatix/utils/utils.py b/src/chromatix/utils/utils.py index aee929e..b4e8a38 100644 --- a/src/chromatix/utils/utils.py +++ b/src/chromatix/utils/utils.py @@ -3,9 +3,15 @@ from chex import Array, PRNGKey from einops import rearrange from typing import Any, Callable, Optional, Sequence, Tuple, Union +from dataclasses import dataclass -def trainable(x: Any) -> Callable: +@dataclass +class Trainable: + val: Any + + +def trainable(x: Any) -> Trainable: """ Returns a function with a valid signature for a Flax parameter initializer function (accepts a ``jax.random.PRNGKey`` as the first argument), which @@ -86,15 +92,7 @@ def trainable(x: Any) -> Callable: Returns: A function that takes a ``jax.random.PRNGKey`` as its first parameter. """ - - def init_fn(key: PRNGKey, *args, **kwargs) -> Any: - if callable(x): - y = x(*args, **kwargs) - else: - y = x - return y - - return init_fn + return Trainable(x) def next_order(val: int) -> int: