add flatten function to deal with recursive stacking when taking gradients of parameters
ahillsley committed Jan 11, 2024
1 parent 07bad70 commit 0db4b4d
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions blinx/parameters.py
@@ -2,6 +2,8 @@
 import jax.numpy as jnp
 from jax.tree_util import register_pytree_node_class
 
+from .fluorescence_model import p_norm
+
 
 def inv_sigmoid(p):
     return -jax.lax.log(1.0 / p - 1.0)
@@ -43,7 +45,9 @@ class Parameters:
     """
 
-    def __init__(self, r_e, r_bg, mu_ro, sigma_ro, gain, p_on, p_off, probs_are_logits=False):
+    def __init__(
+        self, r_e, r_bg, mu_ro, sigma_ro, gain, p_on, p_off, probs_are_logits=False
+    ):
         self.r_e = r_e
         self.r_bg = r_bg
         self.mu_ro = mu_ro
@@ -78,6 +82,28 @@ def reshape(self, shape):
             probs_are_logits=True,
         )
 
+    def flatten(self):
+        """Convert this class into just a single tensor."""
+
+        return jnp.array(
+            [
+                Parameters._flatten_rec(self.r_e),
+                Parameters._flatten_rec(self.r_bg),
+                Parameters._flatten_rec(self.mu_ro),
+                Parameters._flatten_rec(self.sigma_ro),
+                Parameters._flatten_rec(self.gain),
+                Parameters._flatten_rec(self._p_on_logit),
+                Parameters._flatten_rec(self._p_off_logit),
+            ]
+        )
+
+    @staticmethod
+    def _flatten_rec(parameters):
+        if isinstance(parameters, Parameters):
+            return parameters.flatten()
+        else:
+            return parameters
+
     def __getitem__(self, key):
         return Parameters(
             self.r_e[key],
@@ -96,8 +122,8 @@ def __repr__(self):
             f"r_bg={self.r_bg}\t"
             f"μ_ro={self.mu_ro}\t"
             f"o_ro={self.sigma_ro}\t"
-            f"p_on={self.p_on}\t"
-            f"p_off={self.p_off}\t"
+            # f"p_on={self.p_on}\t"
+            # f"p_off={self.p_off}\t"
             f"gain={self.gain}\t"
             f"p_on logits={self._p_on_logit}\t"
             f"p_off logits={self._p_off_logit}"
@@ -119,7 +145,7 @@ def tree_flatten(self):
     @classmethod
     def tree_unflatten(cls, aux, children):
         return cls(*children, probs_are_logits=True)
 
     @classmethod
     def stack(cls, parameters):
         return Parameters(
@@ -130,6 +156,5 @@ def stack(cls, parameters):
             jnp.stack([p.gain for p in parameters]),
             jnp.stack([p._p_on_logit for p in parameters]),
             jnp.stack([p._p_off_logit for p in parameters]),
-            probs_are_logits=True
+            probs_are_logits=True,
         )
-
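For context, a hedged usage sketch (not part of this commit) of the situation flatten() addresses. It assumes the Parameters class above with its pytree registration; the loss function and all numeric values are hypothetical stand-ins. Because Parameters is a registered JAX pytree, jax.grad of a scalar loss returns the gradient as another Parameters object, and, per the commit message, recursive stacking of such gradients can leave Parameters objects nested inside Parameters fields; flatten() unwraps any such nesting into a single tensor via _flatten_rec.

# Hypothetical usage sketch, not from the repository.
import jax

from blinx.parameters import Parameters

params = Parameters(
    r_e=5.0, r_bg=1.0, mu_ro=100.0, sigma_ro=2.0, gain=1.9, p_on=0.1, p_off=0.1
)

def loss(p):
    # stand-in scalar objective; any differentiable function of the fields works
    return p.r_e**2 + p.r_bg**2

# Parameters is a registered pytree, so the gradient comes back as
# another Parameters object with one gradient per field.
grads = jax.grad(loss)(params)

# flatten() collects the seven fields into a single jnp array;
# _flatten_rec recurses into any field that is itself a Parameters
# object, which is what recursive stacking of gradients can produce.
flat = grads.flatten()  # shape (7,); p_on/p_off enter as logits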