From 3832e27e252276368775d997a594c99f69b2f69e Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 14 Aug 2024 13:09:10 +0200 Subject: [PATCH 1/8] * update API with incremental * update to accept arbitrary dt --- essm_jax/essm.py | 300 +++++++++++++++++++++++++++++------------------ 1 file changed, 186 insertions(+), 114 deletions(-) diff --git a/essm_jax/essm.py b/essm_jax/essm.py index c0d2b1f..71fb768 100644 --- a/essm_jax/essm.py +++ b/essm_jax/essm.py @@ -36,6 +36,13 @@ class FilterResult(NamedTuple): observation_cov: jax.Array # [num_timesteps, observation_size, observation_size] The covariance of p(x[t] | x[:t-1]) +class IncrementalFilterState(NamedTuple): + t: jax.Array # the time index + log_cumulative_marginal_likelihood: jax.Array # log marginal likelihood prod_t p(x[t] | x[:t-1]) + filtered_mean: jax.Array # [latent_size] mean of p(z[t+1] | x[:t]) + filtered_cov: jax.Array # [latent_size, latent_size] covariance of p(z[t+1] | x[:t]) + + class SmoothingResult(NamedTuple): """ Represents the result of a backward smoothing pass. @@ -90,12 +97,14 @@ class ExtendedStateSpaceModel: Must be a MultivariateNormalLinearOperator. more_data_than_params: If True, the observation function has more outputs than inputs. materialise_jacobians: If True, the Jacobians are materialised as dense matrices. + dt: The time step size, default is 1. """ transition_fn: Callable[[jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator] observation_fn: Callable[[jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator] initial_state_prior: tfpd.MultivariateNormalLinearOperator more_data_than_params: bool = False materialise_jacobians: bool = False + dt: float = 1. 
def __post_init__(self): if not callable(self.transition_fn): @@ -165,7 +174,7 @@ def _sample_latents_op(latent, y): init = self.initial_state_prior.sample(seed=init_key) xs = ( jax.random.split(latent_key, num_time), - jnp.arange(1, num_time + 1) + t0 + jnp.arange(1, num_time + 1) * self.dt + t0 ) _, samples = lax.scan( _sample_latents_op, @@ -197,24 +206,34 @@ def _check_shapes(self, observations: jax.Array, mask: Optional[jax.Array] = Non raise ValueError('mask and observations must have the same length') def forward_simulate(self, key: jax.Array, num_time: int, - observations: jax.Array, mask: Optional[jax.Array] = None) -> SampleResult: + filter_result: Union[FilterResult | IncrementalFilterState]) -> SampleResult: """ - Simulate from the model, from the end of the forward filtering pass. + Simulate from the model, from the end of a forward filtering pass. Args: key: a PRNGKey num_time: the number of time steps to simulate - observations: [num_time, observation_size] array of observations - mask: [num_time] array of masks, True for missing observations + filter_result: the result of the forward filtering pass, or incremental filter update. 
Returns: sample result: num_time timesteps of forward simulated """ - filter_result = self.forward_filter(observations, mask) - initial_state_prior = tfpd.MultivariateNormalTriL( - loc=filter_result.filtered_mean[-1], - scale_tril=jnp.linalg.cholesky(_efficient_add_scalar_diag(filter_result.filtered_cov[-1], 1e-6)) - ) + if isinstance(filter_result, FilterResult): + initial_state_prior = tfpd.MultivariateNormalTriL( + loc=filter_result.filtered_mean[-1], + scale_tril=lax.linalg.cholesky(_efficient_add_scalar_diag(filter_result.filtered_cov[-1], 1e-6), + symmetrize_input=False) + ) + t0 = filter_result.t[-1] + elif isinstance(filter_result, IncrementalFilterState): + initial_state_prior = tfpd.MultivariateNormalTriL( + loc=filter_result.filtered_mean, + scale_tril=lax.linalg.cholesky(_efficient_add_scalar_diag(filter_result.filtered_cov, 1e-6), + symmetrize_input=False) + ) + t0 = filter_result.t + else: + raise ValueError('filter_result must be a FilterResult or IncrementalFilter instance.') new_essm = ExtendedStateSpaceModel( transition_fn=self.transition_fn, observation_fn=self.observation_fn, @@ -222,11 +241,142 @@ def forward_simulate(self, key: jax.Array, num_time: int, more_data_than_params=self.more_data_than_params, materialise_jacobians=self.materialise_jacobians ) - return new_essm.sample(key=key, num_time=num_time, t0=filter_result.t[-1]) + return new_essm.sample(key=key, num_time=num_time, t0=t0) + + def incremental_update(self, filter_state: IncrementalFilterState, observation: jax.Array, + mask: Optional[jax.Array] = None) -> Tuple[ + IncrementalFilterState, tfpd.MultivariateNormalLinearOperator]: + """ + Perform an incremental update of the filter state. Does not advance the time index. I.e. 
produces + + p(z[t] | x[:t]) from p(z[t] | x[:t-1]) and p(x[t] | x[:t-1]) + + Args: + filter_state: the current filter state + observation: [n] the observation at the current time + mask: scalar, the mask at the current time, True for missing observations + + Returns: + the updated filter state, and the marginal distribution of the observation p(x[t] | x[:t-1]) + """ + + Hop = self.get_observation_jacobian(t=filter_state.t, observation_size=observation.size) + H = Hop(filter_state.filtered_mean) + if self.materialise_jacobians: + H = H.to_dense() + + # Update step, compute p(z[t] | x[:t]) from p(z[t] | x[:t-1]) + observation_dist = self.observation_fn(filter_state.filtered_mean, filter_state.t) + R = observation_dist.covariance() + # Push-forward the prior (i.e. predictive) distribution to the observation space + x_expectation = observation_dist.mean() # [observation_size] + tmp_H_P = H @ filter_state.filtered_cov + S = tmp_H_P @ H.T + R # [observation_size, observation_size] + S_chol = lax.linalg.cholesky(S, symmetrize_input=False) # [observation_size, observation_size] + + marginal_dist = tfpd.MultivariateNormalTriL(x_expectation, S_chol) + # return_marginal = tfpd.MultivariateNormalFullCovariance(x_expectation, S) + + # Compute the log marginal likelihood p(x[t] | x[:t-1]) + log_marginal_likelihood = marginal_dist.log_prob(observation) + + # Compute the Kalman gain + # K = predict_cov @ H.T @ inv(S) = predict_cov.T @ H.T @ inv(S) + K = hpsd_solve(S, tmp_H_P, cholesky_matrix=S_chol).T # [latent_size, observation_size] + + # Update the state estimate + filtered_mean = filter_state.filtered_mean + K @ (observation - x_expectation) # [latent_size] + + # Update the state covariance using Joseph's form to ensure positive semi-definite + # tmp_factor = (I - K @ H) + if self.more_data_than_params: + tmp_factor = _efficient_add_scalar_diag((- K) @ H, 1.) + else: + tmp_factor = _efficient_add_scalar_diag(K @ (-H), 1.) 
+ filtered_cov = tmp_factor @ filter_state.filtered_cov @ tmp_factor.T + K @ R @ K.T # [latent_size, latent_size] + + # When masked, then the filtered state is the predicted state. + if mask is not None: + filtered_mean = lax.select(*jnp.broadcast_arrays(mask, + filter_state.filtered_mean, filtered_mean)) + filtered_cov = lax.select(*jnp.broadcast_arrays(mask, + filter_state.filtered_cov, filtered_cov)) + log_marginal_likelihood = lax.select(*jnp.broadcast_arrays(mask, + jnp.zeros_like(log_marginal_likelihood), + log_marginal_likelihood)) + + log_cumulative_marginal_likelihood = filter_state.log_cumulative_marginal_likelihood + log_marginal_likelihood + return IncrementalFilterState( + t=filter_state.t, + log_cumulative_marginal_likelihood=log_cumulative_marginal_likelihood, + filtered_mean=filtered_mean, + filtered_cov=filtered_cov + ), marginal_dist + + def incremental_predict(self, filter_state: IncrementalFilterState) -> IncrementalFilterState: + """ + Perform an incremental prediction step of the filter state, advancing the time index. I.e. 
produces + + p(z[t+1] | x[:t]) from p(z[t] | x[:t]) + + Args: + filter_state: the current filter state + + Returns: + the predicted filter state, with time index advanced + """ + # Predict step, compute p(z[t+1] | x[:t]) + + t = filter_state.t + jnp.asarray(self.dt, filter_state.t.dtype) + + Fop = self.get_transition_jacobian(t=t) + F = Fop(filter_state.filtered_mean) + if self.materialise_jacobians: + F = F.to_dense() + + predicted_dist = self.transition_fn(filter_state.filtered_mean, t) + predicted_mean = predicted_dist.mean() # [latent_size] + Q = predicted_dist.covariance() # [latent_size, latent_size] + predicted_cov = F @ filter_state.filtered_cov @ F.T + Q # [latent_size, latent_size] + + return IncrementalFilterState( + t=t, + log_cumulative_marginal_likelihood=filter_state.log_cumulative_marginal_likelihood, + filtered_mean=predicted_mean, + filtered_cov=predicted_cov + ) + + def create_initial_filter_state(self, t0: Union[jax.Array, float] = 0.) -> IncrementalFilterState: + """ + Create an incremental filter at the initial time. + + Args: + t0: the time of prior state (before the first observation) + + Returns: + the initial incremental filter state at the first possible observation time, i.e. t0+1. 
+ """ + # Push forward initial state to create p(z[1] | z[0]) + t0 = jnp.asarray(t0, jnp.float32) + t1 = t0 + jnp.asarray(self.dt, t0.dtype) + init_predict_dist = self.transition_fn(self.initial_state_prior.mean(), t1) + + init_Fop = self.get_transition_jacobian(t1) + init_F = init_Fop(self.initial_state_prior.mean()) + if self.materialise_jacobians: + init_F = init_F.to_dense() + init_predicted_mean = init_predict_dist.mean() + init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance() + return IncrementalFilterState( + t=t1, + log_cumulative_marginal_likelihood=jnp.asarray(0.), + filtered_mean=init_predicted_mean, + filtered_cov=init_predicted_cov + ) def forward_filter(self, observations: jax.Array, mask: Optional[jax.Array] = None, marginal_likelihood_only: bool = False, - t0: Union[jax.Array, int] = 0) -> Union[FilterResult, jax.Array]: + t0: Union[jax.Array, float] = 0.) -> Union[FilterResult, jax.Array]: """ Run the forward filtering pass, computing the total marginal likelihood @@ -248,19 +398,13 @@ def forward_filter(self, observations: jax.Array, mask: Optional[jax.Array] = No """ self._check_shapes(observations=observations, mask=mask) - class Carry(NamedTuple): - log_cumulative_marginal_likelihood: jax.Array # log marginal likelihood - predicted_mean: jax.Array # [latent_size] mean of p(z[t+1] | x[:t]) - predicted_cov: jax.Array # [latent_size, latent_size] covariance of p(z[t+1] | x[:t]) - class YType(NamedTuple): observation: jax.Array # [observation_size] observation at time t mask: jax.Array # mask at time t - t: jax.Array # time index num_time = np.shape(observations)[0] - def _filter_op(carry: Carry, y: YType) -> Tuple[Carry, FilterResult]: + def _filter_op(filter_state: IncrementalFilterState, y: YType) -> Tuple[IncrementalFilterState, FilterResult]: """ A single step of the forward equations. 
""" @@ -268,111 +412,38 @@ def _filter_op(carry: Carry, y: YType) -> Tuple[Carry, FilterResult]: # Note: We perform update FIRST, then predict which is contrary to the usual order. # This is so that the filter results naturally align with the smoothing operation. - Hop = self.get_observation_jacobian(t=y.t, observation_size=y.observation.size) - H = Hop(carry.predicted_mean) - if self.materialise_jacobians: - H = H.to_dense() - - # Update step, compute p(z[t] | x[:t]) from p(z[t] | x[:t-1]) - observation_dist = self.observation_fn(carry.predicted_mean, y.t) - R = observation_dist.covariance() - # Push-forward the prior (i.e. predictive) distribution to the observation space - x_expectation = observation_dist.mean() # [observation_size] - tmp_H_P = H @ carry.predicted_cov - S = tmp_H_P @ H.T + R # [observation_size, observation_size] - S_chol = jnp.linalg.cholesky(S) # [observation_size, observation_size] - - marginal_dist = tfpd.MultivariateNormalTriL(x_expectation, S_chol) - - # Compute the log marginal likelihood p(x[t] | x[:t-1]) - log_marginal_likelihood = marginal_dist.log_prob(y.observation) - - # Compute the Kalman gain - # K = predict_cov @ H.T @ inv(S) = predict_cov.T @ H.T @ inv(S) - K = hpsd_solve(S, tmp_H_P, cholesky_matrix=S_chol).T # [latent_size, observation_size] - - # Update the state estimate - filtered_mean = carry.predicted_mean + K @ (y.observation - x_expectation) # [latent_size] - - # Update the state covariance using Joseph's form to ensure positive semi-definite - # tmp_factor = (I - K @ H) - if self.more_data_than_params: - tmp_factor = _efficient_add_scalar_diag((- K) @ H, 1.) - else: - tmp_factor = _efficient_add_scalar_diag(K @ (-H), 1.) - filtered_cov = tmp_factor @ carry.predicted_cov @ tmp_factor.T + K @ R @ K.T # [latent_size, latent_size] - - # When masked, then the filtered state is the predicted state. 
- if mask is not None: - filtered_mean = lax.select(*jnp.broadcast_arrays(y.mask, - carry.predicted_mean, filtered_mean)) - filtered_cov = lax.select(*jnp.broadcast_arrays(y.mask, - carry.predicted_cov, filtered_cov)) - log_marginal_likelihood = lax.select(*jnp.broadcast_arrays(y.mask, - jnp.zeros_like(log_marginal_likelihood), - log_marginal_likelihood)) + updated_filter_state, marginal_dist = self.incremental_update( + filter_state=filter_state, + observation=y.observation, + mask=y.mask if mask is not None else None + ) # Predict step, compute p(z[t+1] | x[:t]) - - Fop = self.get_transition_jacobian(t=y.t + 1) - F = Fop(filtered_mean) - if self.materialise_jacobians: - F = F.to_dense() - - predicted_dist = self.transition_fn(filtered_mean, y.t + 1) - predicted_mean = predicted_dist.mean() # [latent_size] - Q = predicted_dist.covariance() # [latent_size, latent_size] - predicted_cov = F @ filtered_cov @ F.T + Q # [latent_size, latent_size] - - # Update cumulative marginal likelihood - log_cumulative_marginal_likelihood = carry.log_cumulative_marginal_likelihood + log_marginal_likelihood - - return Carry( - log_cumulative_marginal_likelihood=log_cumulative_marginal_likelihood, - predicted_mean=predicted_mean, - predicted_cov=predicted_cov - ), FilterResult( - t=y.t, - log_cumulative_marginal_likelihood=log_cumulative_marginal_likelihood, - filtered_mean=filtered_mean, - filtered_cov=filtered_cov, - predicted_mean=predicted_mean, - predicted_cov=predicted_cov, - observation_mean=x_expectation, - observation_cov=S + predicted_filter_state = self.incremental_predict(filter_state=updated_filter_state) + + return predicted_filter_state, FilterResult( + t=updated_filter_state.t, + log_cumulative_marginal_likelihood=updated_filter_state.log_cumulative_marginal_likelihood, + filtered_mean=updated_filter_state.filtered_mean, + filtered_cov=updated_filter_state.filtered_cov, + predicted_mean=predicted_filter_state.filtered_mean, + 
predicted_cov=predicted_filter_state.filtered_cov, + observation_mean=marginal_dist.mean(), + observation_cov=marginal_dist.covariance() ) - # Push forward initial state to create p(z[1] | z[0]) - t0 = jnp.asarray(t0, jnp.int32) - t1 = t0 + jnp.asarray(1, jnp.int32) - init_predict_dist = self.transition_fn(self.initial_state_prior.mean(), t1) - - init_Fop = self.get_transition_jacobian(t1) - init_F = init_Fop(self.initial_state_prior.mean()) - if self.materialise_jacobians: - init_F = init_F.to_dense() - init_predicted_mean = init_predict_dist.mean() - init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance() - - # init_predicted_mean = self.initial_state_prior.mean() - # init_predicted_cov = self.initial_state_prior.covariance() - init_result = Carry( - log_cumulative_marginal_likelihood=jnp.asarray(0.), - predicted_mean=init_predicted_mean, - predicted_cov=init_predicted_cov - ) + filter_state = self.create_initial_filter_state(t0=t0) if mask is None: _mask = jnp.zeros(num_time, dtype=jnp.bool_) # dummy variable (we skip the mask select) else: _mask = mask xs = YType( observation=observations, - mask=_mask, - t=jnp.arange(1, num_time + 1, dtype=jnp.int32) + t0 + mask=_mask ) final_accumulate, filter_results = lax.scan( f=_filter_op, - init=init_result, + init=filter_state, xs=xs ) if marginal_likelihood_only: @@ -393,7 +464,7 @@ def log_prob(self, observations: jax.Array, mask: Optional[jax.Array] = None) -> return self.forward_filter(observations, mask, marginal_likelihood_only=True) def posterior_marginals(self, observations: jax.Array, mask: Optional[jax.Array] = None, - t0: Union[jax.Array, int] = 0) -> Union[ + t0: Union[jax.Array, float] = 0.) -> Union[ SmoothingResult, Tuple[SmoothingResult, InitialPrior]]: """ Compute the posterior marginal distributions of the latents, p(z[t] | x[:T]). 
@@ -443,7 +514,8 @@ def _smooth_op(carry: Carry, y: FilterResult) -> Tuple[Carry, SmoothingResult]: # Compute the RTS smoother gain # J = y.filtered_cov @ F.T @ jnp.linalg.inv(y.predicted_cov) - predicted_cov_chol = jnp.linalg.cholesky(y.predicted_cov) # Possibly need to add a small diagonal jitter + predicted_cov_chol = lax.linalg.cholesky(y.predicted_cov, + symmetrize_input=False) # Possibly need to add a small diagonal jitter tmp_F_P = F @ y.filtered_cov J = hpsd_solve(y.predicted_cov, tmp_F_P, cholesky_matrix=predicted_cov_chol).T @@ -493,7 +565,7 @@ def _smooth_op(carry: Carry, y: FilterResult) -> Tuple[Carry, SmoothingResult]: init_predicted_mean = init_predict_dist.mean() init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance() - t0 = t1 - jnp.asarray(1, jnp.int32) + t0 = t1 - jnp.asarray(self.dt, t1.dtype) y = FilterResult( t=t0, log_cumulative_marginal_likelihood=jnp.asarray(0.), From 3489c36274516eaea2a447391b371ffb3450caac Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 14 Aug 2024 13:09:19 +0200 Subject: [PATCH 2/8] * pytree utils --- essm_jax/pytee_utils.py | 43 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 essm_jax/pytee_utils.py diff --git a/essm_jax/pytee_utils.py b/essm_jax/pytee_utils.py new file mode 100644 index 0000000..71626d9 --- /dev/null +++ b/essm_jax/pytee_utils.py @@ -0,0 +1,43 @@ +from typing import Tuple, Callable, TypeVar + +import jax +import jax.numpy as jnp +import numpy as np +from jax import lax + +PT = TypeVar('PT') + + +def pytree_unravel(example_tree: PT) -> Tuple[Callable[[PT], jax.Array], Callable[[jax.Array], PT]]: + """ + Returns functions to ravel and unravel a pytree. 
+ + Args: + example_tree: a pytree to be unravelled + + Returns: + ravel_fun: a function to ravel the pytree + unravel_fun: a function to unravel + """ + leaf_list, tree_def = jax.tree.flatten(example_tree) + + sizes = [np.size(leaf) for leaf in leaf_list] + shapes = [np.shape(leaf) for leaf in leaf_list] + dtypes = [leaf.dtype for leaf in leaf_list] + + def ravel_fun(pytree: PT) -> jax.Array: + leaf_list, tree_def = jax.tree.flatten(pytree) + # promote types to common one + common_dtype = jnp.result_type(*dtypes) + leaf_list = [leaf.astype(common_dtype) for leaf in leaf_list] + return jnp.concatenate([lax.reshape(leaf, (size,)) for leaf, size in zip(leaf_list, sizes)]) + + def unravel_fun(flat_array: jax.Array) -> PT: + leaf_list = [] + start = 0 + for size, shape, dtype in zip(sizes, shapes, dtypes): + leaf_list.append(lax.reshape(flat_array[start:start + size], shape).astype(dtype)) + start += size + return jax.tree.unflatten(tree_def, leaf_list) + + return ravel_fun, unravel_fun From 30d48223c9c140da53348c35307ec4b575f3fb45 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 14 Aug 2024 13:11:34 +0200 Subject: [PATCH 3/8] * update change log --- README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f1fa846..e668498 100644 --- a/README.md +++ b/README.md @@ -67,8 +67,7 @@ print(smooth_result) forward_samples = essm.forward_simulate( key=jax.random.PRNGKey(0), num_time=25, - observations=samples.observation, - mask=mask + filter_result=filter_result ) import pylab as plt @@ -86,9 +85,14 @@ plt.legend() plt.show() ``` +## Online Filtering + +Take a look at [examples](./docs/examples) to learn how to do online filtering, for interactive application. + # Change Log 13 August 2024: Initial release 1.0.0. +14 August 2024: 1.0.1 released. Added sparse util. Add incremental API for online filtering. Arbitrary dt. 
## Star History From cb0ac8146de6656d78a0b4048f2a292f838cc731 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 14 Aug 2024 13:11:46 +0200 Subject: [PATCH 4/8] * add sparse utils --- essm_jax/sparse.py | 68 +++++++++++++++++ essm_jax/tests/test_essm.py | 133 ++++++++++++++++++++++++++++++++-- essm_jax/tests/test_sparse.py | 31 ++++++++ 3 files changed, 226 insertions(+), 6 deletions(-) create mode 100644 essm_jax/sparse.py create mode 100644 essm_jax/tests/test_sparse.py diff --git a/essm_jax/sparse.py b/essm_jax/sparse.py new file mode 100644 index 0000000..4b83001 --- /dev/null +++ b/essm_jax/sparse.py @@ -0,0 +1,68 @@ +from typing import Tuple, NamedTuple + +import jax +import jax.numpy as jnp +import numpy as np + + +class SparseRepresentation(NamedTuple): + shape: Tuple[int, ...] + rows: jax.Array + cols: jax.Array + vals: jax.Array + + +def create_sparse_rep(m: np.ndarray) -> SparseRepresentation: + """ + Creates a sparse rep from matrix m. Use in linear models with materialise_jacobian=False for 2x speed up. + + Args: + m: [N,M] matrix + + Returns: + sparse rep + """ + rows, cols = np.where(m) + sort_indices = np.lexsort((cols, rows)) + rows = rows[sort_indices] + cols = cols[sort_indices] + return SparseRepresentation( + shape=np.shape(m), + rows=jnp.asarray(rows), + cols=jnp.asarray(cols), + vals=jnp.asarray(m[rows, cols]) + ) + + +def to_dense(m: SparseRepresentation, out: jax.Array | None = None) -> jax.Array: + """ + Form dense matrix. + + Args: + m: sparse rep + out: output buffer + + Returns: + out + M + """ + if out is None: + out = jnp.zeros(m.shape, m.vals.dtype) + + return out.at[m.rows, m.cols].add(m.vals, unique_indices=True, indices_are_sorted=True) + + +def matvec_sparse(m: SparseRepresentation, v: jax.Array, out: jax.Array | None = None) -> jax.Array: + """ + Compute matmul for sparse rep. Speeds up large sparse linear models by about 2x. + + Args: + m: sparse rep + v: vec + out: output buffer to add to. 
+ + Returns: + out + M @ v + """ + if out is None: + out = jnp.zeros(m.shape[0]) + return out.at[m.rows].add(m.vals * v[m.cols], unique_indices=True, indices_are_sorted=True) diff --git a/essm_jax/tests/test_essm.py b/essm_jax/tests/test_essm.py index c8d3dc2..da7f0cf 100644 --- a/essm_jax/tests/test_essm.py +++ b/essm_jax/tests/test_essm.py @@ -1,6 +1,11 @@ import time import jax +import pytest + +from essm_jax.sparse import create_sparse_rep, matvec_sparse + +jax.config.update('jax_enable_x64', True) import numpy as np import tensorflow_probability.substrates.jax as tfp from jax import numpy as jnp @@ -217,19 +222,21 @@ def observation_fn(z, t): ) sample = essm.sample(jax.random.PRNGKey(0), 1000) - filter_fn = jax.jit(lambda: essm.forward_filter(sample.observation)).lower().compile() - filter_jvp_fn = jax.jit(lambda: essm_jvp.forward_filter(sample.observation)).lower().compile() + filter_fn = jax.jit( + lambda: essm.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile() + filter_jvp_fn = jax.jit( + lambda: essm_jvp.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile() t0 = time.time() filter_results = filter_fn() - filter_results.t.block_until_ready() + filter_results.block_until_ready() t1 = time.time() dt1 = t1 - t0 print(f"Time for essm: {t1 - t0}") t0 = time.time() filter_results_jvp = filter_jvp_fn() - filter_results_jvp.t.block_until_ready() + filter_results_jvp.block_until_ready() t1 = time.time() dt2 = t1 - t0 print(f"Time for essm_jvp: {t1 - t0}") @@ -282,8 +289,7 @@ def observation_fn(z, t): forward_samples = essm.forward_simulate( key=jax.random.PRNGKey(0), num_time=25, - observations=samples.observation, - mask=mask + filter_result=filter_result ) try: @@ -313,3 +319,118 @@ def test__efficienct_add_scalar_diag(): A = jnp.eye(100) c = 1. 
assert jnp.all(_efficient_add_scalar_diag(A, c) == A + c * jnp.eye(100)) + + +def test_incremental_filtering(): + def transition_fn(z, t): + mean = z + z * jnp.sin(2 * jnp.pi * t / 10) + cov = 0.1 * jnp.eye(np.size(z)) + return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) + + def observation_fn(z, t): + mean = z + cov = t * 0.01 * jnp.eye(np.size(z)) + return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) + + n = 1 + + initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n)) + + essm = ExtendedStateSpaceModel( + transition_fn=transition_fn, + observation_fn=observation_fn, + initial_state_prior=initial_state_prior, + materialise_jacobians=False, # Fast + more_data_than_params=False # if observation is bigger than latent we can speed it up. + ) + samples = essm.sample(jax.random.PRNGKey(0), 100) + + filter_result = essm.forward_filter(samples.observation) + + filter_state = essm.create_initial_filter_state() + + import pylab as plt + + for i in range(100): + filter_state, _ = essm.incremental_update(filter_state, samples.observation[i]) + plt.scatter(filter_state.t, filter_state.filtered_mean, c='black') + assert filter_state.t == filter_result.t[i] + filter_state = essm.incremental_predict(filter_state) + # print(i, np.abs( + # filter_state.log_cumulative_marginal_likelihood - filter_result.log_cumulative_marginal_likelihood[i])) + # print(i, np.max(np.abs(filter_state.filtered_mean - filter_result.predicted_mean[i]))) + # print(i, np.max(np.abs(filter_state.filtered_cov - filter_result.predicted_cov[i]))) + # print(i, np.max(np.abs(filter_state.filtered_cov))) + np.testing.assert_allclose(filter_state.log_cumulative_marginal_likelihood, + filter_result.log_cumulative_marginal_likelihood[i], atol=1e-5) + np.testing.assert_allclose(filter_state.filtered_mean, filter_result.predicted_mean[i], atol=1e-5) + np.testing.assert_allclose(filter_state.filtered_cov, filter_result.predicted_cov[i], atol=1e-5) + + 
plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent') + plt.legend() + plt.show() + + +@pytest.mark.parametrize('use_sparse', [False, True]) +def test_performance_sparse(use_sparse: bool): + # Show that using sparse rep speeds up linear system + n = 128 + k = 10 + m = np.zeros((n, n)) + rows = np.random.randint(n, size=k) + cols = np.random.randint(n, size=k) + m[rows, cols] += 1. + + if use_sparse: + m = create_sparse_rep(m) + else: + m = jnp.asarray(m) + + def transition_fn(z, t): + if use_sparse: + mean = matvec_sparse(m, z) + else: + mean = m @ z + scale = jnp.ones(np.size(z)) + return tfpd.MultivariateNormalDiag(mean, scale) + + def observation_fn(z, t): + mean = z + scale = jnp.ones(np.size(z)) + return tfpd.MultivariateNormalDiag(mean, scale) + + initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n)) + + essm = ExtendedStateSpaceModel( + transition_fn=transition_fn, + observation_fn=observation_fn, + initial_state_prior=initial_state_prior, + materialise_jacobians=True + ) + + essm_jvp = ExtendedStateSpaceModel( + transition_fn=transition_fn, + observation_fn=observation_fn, + initial_state_prior=initial_state_prior, + materialise_jacobians=False + ) + + sample = essm.sample(jax.random.PRNGKey(0), 512) + filter_fn = jax.jit( + lambda: essm.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile() + filter_jvp_fn = jax.jit( + lambda: essm_jvp.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile() + + t0 = time.time() + filter_results = filter_fn() + filter_results.block_until_ready() + t1 = time.time() + dt1 = t1 - t0 + print(f"Time for essm(use_sparse={use_sparse}): {t1 - t0}") + + t0 = time.time() + filter_results_jvp = filter_jvp_fn() + filter_results_jvp.block_until_ready() + t1 = time.time() + dt2 = t1 - t0 + print(f"Time for essm_jvp(use_sparse={use_sparse}): {t1 - t0}") diff --git a/essm_jax/tests/test_sparse.py b/essm_jax/tests/test_sparse.py 
new file mode 100644 index 0000000..32ef1c2 --- /dev/null +++ b/essm_jax/tests/test_sparse.py @@ -0,0 +1,31 @@ +import jax +import numpy as np +from jax import numpy as jnp + +jax.config.update('jax_enable_x64', True) + +from essm_jax.sparse import create_sparse_rep, matvec_sparse, to_dense + + +def test_sparse_rep(): + m = np.asarray([[1., 0, 0], + [-1., 2., 0.], + [0., 0., 5.]]) + rep = create_sparse_rep(m) + v = jnp.asarray([1, 1, 1]) + np.testing.assert_allclose(matvec_sparse(rep, v), m @ v) + + m = np.random.normal(size=(100, 100)) + v = np.random.normal(size=100) + + rep = create_sparse_rep(m) + np.testing.assert_allclose(matvec_sparse(rep, v), m @ v) + + +def test_to_dense(): + m = np.asarray([[1., 0, 0], + [-1., 2., 0.], + [0., 0., 5.]]) + rep = create_sparse_rep(m) + M = to_dense(rep) + np.testing.assert_allclose(M, m) From bf755c10d671612731c31ccea6f3f10498b6a908 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Thu, 15 Aug 2024 01:36:05 +0300 Subject: [PATCH 5/8] * Add dt * Add examples --- README.md | 2 +- ...excitable_damped_harmonic_oscillator.ipynb | 317 ++++++++++++++++++ docs/examples/online_filtering.ipynb | 196 +++++++++++ essm_jax/essm.py | 168 +++++----- essm_jax/tests/test_essm.py | 44 +-- 5 files changed, 624 insertions(+), 103 deletions(-) create mode 100644 docs/examples/excitable_damped_harmonic_oscillator.ipynb create mode 100644 docs/examples/online_filtering.ipynb diff --git a/README.md b/README.md index e668498..af2d27e 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ from essm_jax.essm import ExtendedStateSpaceModel tfpd = tfp.distributions -def transition_fn(z, t): +def transition_fn(z, t, t_next): mean = z + jnp.sin(2 * jnp.pi * t / 10 * z) cov = 0.1 * jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) diff --git a/docs/examples/excitable_damped_harmonic_oscillator.ipynb b/docs/examples/excitable_damped_harmonic_oscillator.ipynb new file mode 100644 index 0000000..7188573 --- /dev/null +++ 
b/docs/examples/excitable_damped_harmonic_oscillator.ipynb @@ -0,0 +1,317 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ef3eaff3797489e6", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-14T22:28:24.816871897Z", + "start_time": "2024-08-14T22:28:23.676057239Z" + } + }, + "outputs": [], + "source": [ + "import dataclasses\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import pylab as plt\n", + "import tensorflow_probability.substrates.jax as tfp\n", + "\n", + "from essm_jax.essm import ExtendedStateSpaceModel\n", + "\n", + "tfpd = tfp.distributions\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Let's define the damped harmonic oscillator\n", + "\n", + "$$m \\ddot x= - c \\dot x - k x + f(t)$$\n", + "$$\\iff \\ddot x = - 2 \\zeta \\omega \\dot x - \\omega^2 x + \\frac{f(t)}{m}$$\n", + "\n", + "with $\\omega = \\sqrt{\\frac{k}{m}} \\in \\mathbb{R}^+$, $\\zeta = \\frac{c}{2 \\sqrt{mk}} \\in \\mathbb{R}^+$ and $f(t) \\sim \\mathcal{N}[0, \\sigma_f^2]$.\n", + "\n", + "Let us make this non-linear by supposing that $M = \\log m \\sim \\mathcal{N}[\\bar{M}, \\sigma_M^2]$.\n", + "\n", + "This defines a state,\n", + "\n", + "$$ z = [x, v, M]$$\n", + "\n", + "transition mean function,\n", + "\n", + "$$ z_{t}, t \\to [x + v \\Delta t, v + (-c v - k x) \\Delta t, \\bar M]$$\n", + "\n", + "transition noise scale function,\n", + "\n", + "$$ z_{t}, t \\to [0, \\sigma_f, \\sigma_M]$$\n", + "\n", + "observation mean function,\n", + "\n", + "$$ z_{t}, t \\to [x, 0, 0]$$\n", + "\n", + "and observation noise scale function,\n", + "\n", + "$$ z_{t}, t \\to [\\sigma_x, 0, 0]$$\n" + ], + "metadata": { + "collapsed": false + }, + "id": "716c02a70cfecdff" + }, + { + "cell_type": "code", + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "from essm_jax.pytee_utils import pytree_unravel\n", + "from typing import NamedTuple\n", + "\n", + "\n", + "class State(NamedTuple):\n", + " \"\"\"\n", + " State for a single timestep of the Kalman filter.\n", + " \"\"\"\n", + " x: jax.Array # [n]\n", + " v: jax.Array # [n]\n", + " a: jax.Array # [n]\n", + " M: jax.Array # [n]\n", + "\n", + "\n", + "class Observation(NamedTuple):\n", + " \"\"\"\n", + " Observables for a single timestep of the Kalman filter.\n", + " \"\"\"\n", + " x: jax.Array # [n]\n", + "\n", + "\n", + "n = 1 # Single dimension\n", + "\n", + "\n", + "@dataclasses.dataclass(eq=False)\n", + "class ExcitedDampedHarmonicOscillator:\n", + " n: int = 1\n", + " dt: float = 0.01\n", + " sigma_f: float = 0.1\n", + " sigma_x: float = 0.1\n", + " zeta: float = 0.1\n", + " omega: float = 1.\n", + " m_bar: float = 1.\n", + "\n", + " def __post_init__(self):\n", + " self.c = 2 * self.zeta * self.omega\n", + " self.k = self.omega ** 2\n", + " self.M_bar = jnp.log(self.m_bar)\n", + "\n", + " example_state = State(\n", + " x=jnp.ones((n,)),\n", + " v=jnp.ones((n,)),\n", + " a=jnp.ones((n,)),\n", + " M=jnp.ones((n,))\n", + " )\n", + "\n", + " example_observation = Observation(\n", + " x=jnp.ones((n,))\n", + " )\n", + "\n", + " self.state_ravel_fn, self.state_unravel_fn = pytree_unravel(example_state)\n", + " self.obs_ravel_fn, self.obs_unravel_fn = pytree_unravel(example_observation)\n", + "\n", + " def batched_state_ravel_fn(self, state: State) -> jax.Array:\n", + " \"\"\"\n", + " Ravel a state with batched time axis into batched flat states.\n", + " \"\"\"\n", + " # This is needed because calling ravel is only defined on the per element basis.\n", + " return jax.vmap(self.state_ravel_fn)(state)\n", + "\n", + " def batched_state_unravel_fn(self, flat_state: jax.Array) -> State:\n", + " \"\"\"\n", + " Unravel a batched flat state into a state with batched 
time axis.\n", + " \"\"\"\n", + " # This is needed because calling unravel is only defined on the per element basis.\n", + " return jax.vmap(self.state_unravel_fn)(flat_state)\n", + "\n", + " def batched_observables_ravel_fn(self, observation: Observation) -> jax.Array:\n", + " \"\"\"\n", + " Ravel observables with batched time axis into batched flat observables.\n", + " \"\"\"\n", + " # This is needed because calling ravel is only defined on the per element basis.\n", + " return jax.vmap(self.obs_ravel_fn)(observation)\n", + "\n", + " def batched_observables_unravel_fn(self, flat_observables: jax.Array) -> Observation:\n", + " \"\"\"\n", + " Unravel a batched flat observables into observables with batched time axis.\n", + " \"\"\"\n", + " # This is needed because calling unravel is only defined on the per element basis.\n", + " return jax.vmap(self.obs_unravel_fn)(flat_observables)\n", + "\n", + " def transition_fn(self, z: jax.Array, t: jax.Array, t_next) -> (tfpd.MultivariateNormalLinearOperator):\n", + " dt = t_next - t\n", + " state = self.state_unravel_fn(z)\n", + " next_state_mean = State(\n", + " x=state.x + state.v * dt,\n", + " v=state.v + (-self.c * state.v - self.k * state.x) * dt,\n", + " a=(-self.c * state.v - self.k * state.x),\n", + " M=self.M_bar * jnp.ones_like(state.M)\n", + " )\n", + " next_state_scale = State(\n", + " x=jnp.zeros_like(state.x),\n", + " v=self.sigma_f * jnp.sqrt(dt) * jnp.ones_like(state.v),\n", + " a=jnp.zeros_like(state.a),\n", + " M=jnp.zeros_like(state.M)\n", + " )\n", + " return tfpd.MultivariateNormalDiag(self.state_ravel_fn(next_state_mean), self.state_ravel_fn(next_state_scale))\n", + "\n", + " def observation_fn(self, z: jax.Array, t: jax.Array) -> tfpd.MultivariateNormalLinearOperator:\n", + " state = self.state_unravel_fn(z)\n", + " obs_mean = Observation(\n", + " x=state.x,\n", + " )\n", + " obs_scale = Observation(\n", + " x=self.sigma_x * jnp.ones_like(state.x),\n", + " )\n", + " return 
tfpd.MultivariateNormalDiag(self.obs_ravel_fn(obs_mean), self.obs_ravel_fn(obs_scale))\n", + "\n", + " def create_initial_state_prior(self) -> tfpd.MultivariateNormalLinearOperator:\n", + " initial_state = State(\n", + " x=jnp.zeros((n,)),\n", + " v=jnp.zeros((n,)),\n", + " a=jnp.zeros((n,)),\n", + " M=self.M_bar * jnp.ones((n,))\n", + " )\n", + " initial_state_scale = State(\n", + " x=jnp.zeros((n,)),\n", + " v=jnp.zeros((n,)),\n", + " a=jnp.zeros((n,)),\n", + " M=jnp.zeros((n,))\n", + " )\n", + " return tfpd.MultivariateNormalDiag(self.state_ravel_fn(initial_state), self.state_ravel_fn(initial_state_scale))\n", + "\n", + " def build_essm(self):\n", + " return ExtendedStateSpaceModel(\n", + " transition_fn=self.transition_fn,\n", + " observation_fn=self.observation_fn,\n", + " initial_state_prior=self.create_initial_state_prior(),\n", + " more_data_than_params=False,\n", + " dt=self.dt\n", + " )\n", + "\n", + "\n", + "model = ExcitedDampedHarmonicOscillator(\n", + " n=1,\n", + " dt=0.01,\n", + " sigma_f=0.01,\n", + " sigma_x=0.001,\n", + " zeta=0.1,\n", + " omega=1.,\n", + " m_bar=1.\n", + ")\n", + "\n", + "essm = model.build_essm()\n", + "\n", + "samples = essm.sample(key=jax.random.PRNGKey(0), num_time=1000)\n", + "\n", + "obs = model.batched_observables_unravel_fn(samples.observation)\n", + "state = model.batched_state_unravel_fn(samples.latent)\n", + "plt.plot(samples.t, state.x, label='x', c='r')\n", + "plt.plot(samples.t, state.v, label='v', c='b')\n", + "plt.plot(samples.t, state.a, label='a', c='g')\n", + "plt.plot(samples.t, state.M, label='M', c='k')\n", + "plt.plot(samples.t, obs.x, label='observed', alpha=0.5)\n", + "\n", + "filter_result = essm.forward_filter(observations=samples.observation)\n", + "\n", + "future_samples = essm.forward_simulate(key=jax.random.PRNGKey(0), num_time=100, filter_result=filter_result)\n", + "future_latent = model.batched_state_unravel_fn(future_samples.latent)\n", + "plt.plot(future_samples.t, future_latent.x, c='r')\n", 
+ "plt.plot(future_samples.t, future_latent.v, c='b')\n", + "plt.plot(future_samples.t, future_latent.a, c='g')\n", + "plt.plot(future_samples.t, future_latent.M, c='k')\n", + "plt.legend()\n", + "plt.show()\n", + "\n", + "filter_latent = model.batched_state_unravel_fn(filter_result.filtered_mean)\n", + "plt.plot(filter_result.t, filter_latent.x, label='x', c='r')\n", + "plt.plot(filter_result.t, filter_latent.v, label='v', c='b')\n", + "plt.plot(filter_result.t, filter_latent.a, label='a', c='g')\n", + "plt.plot(filter_result.t, filter_latent.M, label='M', c='k')\n", + "\n", + "filter_state = essm.create_filter_state(filter_result=filter_result)\n", + "for _ in range(100):\n", + " filter_state = essm.incremental_predict(filter_state=filter_state)\n", + " latent = model.state_unravel_fn(filter_state.filtered_mean)\n", + " plt.scatter(filter_state.t, latent.x, c='r', s=1)\n", + " plt.scatter(filter_state.t, latent.v, c='b', s=1)\n", + " plt.scatter(filter_state.t, latent.a, c='g', s=1)\n", + " plt.scatter(filter_state.t, latent.M, c='k', s=1)\n", + "\n", + "plt.legend()\n", + "plt.show()\n", + "\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-14T22:31:00.060073207Z", + "start_time": "2024-08-14T22:30:53.098775507Z" + } + }, + "id": "5f0b000d8cfe6d2b", + "execution_count": 4 + }, + { + "cell_type": "code", + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2024-08-14T22:28:27.399534612Z" + } + }, + "id": "6626b28fc500ba6d", + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff 
--git a/docs/examples/online_filtering.ipynb b/docs/examples/online_filtering.ipynb new file mode 100644 index 0000000..0658db6 --- /dev/null +++ b/docs/examples/online_filtering.ipynb @@ -0,0 +1,196 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ef3eaff3797489e6", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-14T22:35:52.089737670Z", + "start_time": "2024-08-14T22:35:51.100189045Z" + } + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import tensorflow_probability.substrates.jax as tfp\n", + "import pylab as plt\n", + "\n", + "from essm_jax.essm import ExtendedStateSpaceModel\n", + "\n", + "tfpd = tfp.distributions\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Let's define a non-linear transition function that forces the state proportionally to its magnitude,\n", + "\n", + "$$p(z_{t+1} | z_t, t) = \\mathcal{N}[T(z_t, t), \\sigma_t^2]$$\n", + "\n", + "with,\n", + "\n", + "$$T(z_t, t) = z_t \\left(1 + |z_t| \\sin \\left(2 \\pi \\frac{t}{10}\\right)\\right)$$\n", + "\n", + "and \n", + "\n", + "$$\\sigma_t = 0.1$$\n", + "\n", + "Now for the observation function, let's define also a non-linear one, that takes the absolute value.\n", + "\n", + "$$p(x_t | z_t, t) = \\mathcal{N}[O(z_t, t), \\epsilon_t^2]$$\n", + "\n", + "with \n", + "\n", + "$$O(z_t, t) = |z_t|$$\n", + "\n", + "and noise that oscillates with time,\n", + "\n", + "$$\\epsilon_t = 0.01 + 0.1 \\cos\\left(2\\pi \\frac{t}{5}\\right)$$\n" + ], + "metadata": { + "collapsed": false + }, + "id": "716c02a70cfecdff" + }, + { + "cell_type": "code", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + }, + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "\n", + "def transition_fn(z, t, t_next):\n", + " mean = z * (1 + jnp.abs(z) * jnp.sin(2 * jnp.pi * t / 10))\n", + " scale = 0.01 * jnp.ones(np.size(z))\n", + " return tfpd.MultivariateNormalDiag(mean, scale)\n", + "\n", + "\n", + "def observation_fn(z, t):\n", + " mean = jnp.abs(z)\n", + " scale = 0.001 * jnp.ones(np.size(z)) + 0.01 * (2 * jnp.pi * t / 5.)\n", + " return tfpd.MultivariateNormalDiag(mean, scale)\n", + "\n", + "\n", + "n = 1\n", + "\n", + "initial_state_prior = tfpd.MultivariateNormalDiag(jnp.zeros(n), jnp.ones(n))\n", + "\n", + "essm = ExtendedStateSpaceModel(\n", + " transition_fn=transition_fn,\n", + " observation_fn=observation_fn,\n", + " initial_state_prior=initial_state_prior,\n", + " more_data_than_params=False, # if observation is bigger than latent we can speed it up.\n", + " dt=1.\n", + ")\n", + "samples = essm.sample(jax.random.PRNGKey(0), 100)\n", + "\n", + "plt.plot(samples.t, samples.observation)\n", + "plt.show()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-14T22:35:52.766401446Z", + "start_time": "2024-08-14T22:35:52.094140163Z" + } + }, + "id": "5f0b000d8cfe6d2b", + "execution_count": 2 + }, + { + "cell_type": "code", + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "filter_result = essm.forward_filter(samples.observation)\n", + "\n", + "filter_state = essm.create_initial_filter_state()\n", + "\n", + "\n", + "for i in range(100):\n", + " filter_state = essm.incremental_predict(filter_state)\n", + " # incorperate new data\n", + " filter_state, _ = essm.incremental_update(filter_state, samples.observation[i])\n", + " \n", + " plt.scatter(filter_state.t, filter_state.filtered_mean, c='black')\n", + " \n", + "\n", + "plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent')\n", + "plt.legend()\n", + "plt.show()" + ], + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-08-14T22:35:58.070899470Z", + "start_time": "2024-08-14T22:35:52.776945243Z" + } + }, + "id": "initial_id", + "execution_count": 3 + }, + { + "cell_type": "code", + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-14T22:35:58.071206840Z", + "start_time": "2024-08-14T22:35:58.066994695Z" + } + }, + "id": "6626b28fc500ba6d", + "execution_count": 3 + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/essm_jax/essm.py b/essm_jax/essm.py index 71fb768..93ba669 100644 --- a/essm_jax/essm.py +++ b/essm_jax/essm.py @@ -87,19 +87,19 @@ class ExtendedStateSpaceModel: Args: transition_fn: A function that computes the state transition distribution - p(z[t] | z[t-1], t). Must return a MultivariateNormalLinearOperator. 
- Call signature is `transition_fn(z[t-1], t)`, where z[t-1] is the previous state. + p(z(t'), t' | z(t), t). Must return a MultivariateNormalLinearOperator. + Call signature is `transition_fn(z(t), t, t')`, where z(t) is the previous state. observation_fn: A function that computes the observation distribution - p(x[t] | z[t], t). Must return a MultivariateNormalLinearOperator. - Call signature is `observation_fn(z[t], t)`, where z[t] is the current state. - Note: t in [1, num_time] is the observation time index, with t=0 being the initial state. - initial_state_prior: A distribution over the initial state p(z[0]). + p(x(t) | z(t), t). Must return a MultivariateNormalLinearOperator. + Call signature is `observation_fn(z(t), t)`, where z(t) is the current state. + Note: t is the observation time, with t=t0 being the initial state. + initial_state_prior: A distribution over the initial state p(z(t0)). Must be a MultivariateNormalLinearOperator. more_data_than_params: If True, the observation function has more outputs than inputs. materialise_jacobians: If True, the Jacobians are materialised as dense matrices. - dt: The time step size, default is 1. + dt: The time step size, default is 1. 
t[i] = t0 + (i+1) * dt """ - transition_fn: Callable[[jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator] + transition_fn: Callable[[jax.Array, jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator] observation_fn: Callable[[jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator] initial_state_prior: tfpd.MultivariateNormalLinearOperator more_data_than_params: bool = False @@ -119,9 +119,9 @@ def __post_init__(self): self.latent_size = np.size(_initial_state_prior_mean) self.latent_shape = np.shape(_initial_state_prior_mean) - def get_transition_jacobian(self, t: jax.Array) -> JVPLinearOp: + def get_transition_jacobian(self, t: jax.Array, t_next: jax.Array) -> JVPLinearOp: def _transition_fn(z): - return self.transition_fn(z, t).mean() + return self.transition_fn(z, t, t_next).mean() return JVPLinearOp(_transition_fn, more_outputs_than_inputs=False) @@ -134,15 +134,15 @@ def _observation_fn(z): more_data_than_params = self.latent_size < observation_size return JVPLinearOp(_observation_fn, more_outputs_than_inputs=more_data_than_params) - def transition_matrix(self, z, t): - Fop = self.get_transition_jacobian(t) + def transition_matrix(self, z, t, t_next): + Fop = self.get_transition_jacobian(t, t_next) return Fop(z).to_dense() def observation_matrix(self, z, t): Hop = self.get_observation_jacobian(t) return Hop(z).to_dense() - def sample(self, key, num_time: int, t0: Union[jax.Array, int] = 0) -> SampleResult: + def sample(self, key, num_time: int, t0: Union[jax.Array, float] = 0.) -> SampleResult: """ Sample from the model. 
@@ -162,19 +162,22 @@ def sample(self, key, num_time: int, t0: Union[jax.Array, int] = 0) -> SampleRes # observation: S -> O def _sample_latents_op(latent, y): - (key, t) = y + (key, t, t_next) = y new_latent_key, obs_key = jax.random.split(key, 2) - transition_dist = self.transition_fn(latent, t) + transition_dist = self.transition_fn(latent, t, t_next) new_latent = transition_dist.sample(seed=new_latent_key) observation_dist = self.observation_fn(new_latent, t) new_observation = observation_dist.sample(seed=obs_key) - return new_latent, SampleResult(t=t, latent=new_latent, observation=new_observation) + return new_latent, SampleResult(t=t_next, latent=new_latent, observation=new_observation) - # Sample at t0 + # Sample at t0 forming initial state init = self.initial_state_prior.sample(seed=init_key) + t_from = jnp.arange(0, num_time) * self.dt + t0 + t_to = t_from + self.dt xs = ( jax.random.split(latent_key, num_time), - jnp.arange(1, num_time + 1) * self.dt + t0 + t_from, + t_to ) _, samples = lax.scan( _sample_latents_op, @@ -239,7 +242,8 @@ def forward_simulate(self, key: jax.Array, num_time: int, observation_fn=self.observation_fn, initial_state_prior=initial_state_prior, more_data_than_params=self.more_data_than_params, - materialise_jacobians=self.materialise_jacobians + materialise_jacobians=self.materialise_jacobians, + dt=self.dt ) return new_essm.sample(key=key, num_time=num_time, t0=t0) @@ -327,25 +331,42 @@ def incremental_predict(self, filter_state: IncrementalFilterState) -> Increment """ # Predict step, compute p(z[t+1] | x[:t]) - t = filter_state.t + jnp.asarray(self.dt, filter_state.t.dtype) + t_next = filter_state.t + jnp.asarray(self.dt, filter_state.t.dtype) - Fop = self.get_transition_jacobian(t=t) + Fop = self.get_transition_jacobian(t=filter_state.t, t_next=t_next) F = Fop(filter_state.filtered_mean) if self.materialise_jacobians: F = F.to_dense() - predicted_dist = self.transition_fn(filter_state.filtered_mean, t) + predicted_dist = 
self.transition_fn(filter_state.filtered_mean, filter_state.t, t_next) predicted_mean = predicted_dist.mean() # [latent_size] Q = predicted_dist.covariance() # [latent_size, latent_size] predicted_cov = F @ filter_state.filtered_cov @ F.T + Q # [latent_size, latent_size] return IncrementalFilterState( - t=t, + t=t_next, log_cumulative_marginal_likelihood=filter_state.log_cumulative_marginal_likelihood, filtered_mean=predicted_mean, filtered_cov=predicted_cov ) + def create_filter_state(self, filter_result: FilterResult) -> IncrementalFilterState: + """ + Create an incremental filter state from a filter result. + + Args: + filter_result: the filter result + + Returns: + the incremental filter state + """ + return IncrementalFilterState( + t=filter_result.t[-1], + log_cumulative_marginal_likelihood=filter_result.log_cumulative_marginal_likelihood[-1], + filtered_mean=filter_result.filtered_mean[-1], + filtered_cov=filter_result.filtered_cov[-1] + ) + def create_initial_filter_state(self, t0: Union[jax.Array, float] = 0.) -> IncrementalFilterState: """ Create an incremental filter at the initial time. @@ -358,20 +379,11 @@ def create_initial_filter_state(self, t0: Union[jax.Array, float] = 0.) 
-> Incre """ # Push forward initial state to create p(z[1] | z[0]) t0 = jnp.asarray(t0, jnp.float32) - t1 = t0 + jnp.asarray(self.dt, t0.dtype) - init_predict_dist = self.transition_fn(self.initial_state_prior.mean(), t1) - - init_Fop = self.get_transition_jacobian(t1) - init_F = init_Fop(self.initial_state_prior.mean()) - if self.materialise_jacobians: - init_F = init_F.to_dense() - init_predicted_mean = init_predict_dist.mean() - init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance() return IncrementalFilterState( - t=t1, + t=t0, log_cumulative_marginal_likelihood=jnp.asarray(0.), - filtered_mean=init_predicted_mean, - filtered_cov=init_predicted_cov + filtered_mean=self.initial_state_prior.mean(), + filtered_cov=self.initial_state_prior.covariance() ) def forward_filter(self, observations: jax.Array, mask: Optional[jax.Array] = None, @@ -433,6 +445,8 @@ def _filter_op(filter_state: IncrementalFilterState, y: YType) -> Tuple[Incremen ) filter_state = self.create_initial_filter_state(t0=t0) + filter_state = self.incremental_predict(filter_state) # Advance to first update time + if mask is None: _mask = jnp.zeros(num_time, dtype=jnp.bool_) # dummy variable (we skip the mask select) else: @@ -503,33 +517,40 @@ class Carry(NamedTuple): smoothed_cov=filter_result.predicted_cov[-1] # [latent_size, latent_size] covariance of p(z[T] | x[:T]) ) - def _smooth_op(carry: Carry, y: FilterResult) -> Tuple[Carry, SmoothingResult]: + class XType(NamedTuple): + t: jax.Array # time t + filtered_mean: jax.Array # [latent_size] mean of p(z[t] | x[:t]) + filtered_cov: jax.Array # [latent_size, latent_size] covariance of p(z[t] | x[:t]) + predicted_mean: jax.Array # [latent_size] mean of p(z[t+1] | x[:t]) + predicted_cov: jax.Array # [latent_size, latent_size] covariance of p(z[t+1] | x[:t]) + + def _smooth_op(carry: Carry, x: XType) -> Tuple[Carry, SmoothingResult]: """ A single step of the backward equations. 
""" - Fop = self.get_transition_jacobian(t=y.t) - F = Fop(y.filtered_mean) + Fop = self.get_transition_jacobian(t=x.t - self.dt, t_next=x.t) + F = Fop(x.filtered_mean) if self.materialise_jacobians: F = F.to_dense() - # Compute the RTS smoother gain + # Compute the C # J = y.filtered_cov @ F.T @ jnp.linalg.inv(y.predicted_cov) - predicted_cov_chol = lax.linalg.cholesky(y.predicted_cov, + predicted_cov_chol = lax.linalg.cholesky(_efficient_add_scalar_diag(x.predicted_cov, 1e-6), symmetrize_input=False) # Possibly need to add a small diagonal jitter - tmp_F_P = F @ y.filtered_cov - J = hpsd_solve(y.predicted_cov, tmp_F_P, cholesky_matrix=predicted_cov_chol).T + tmp_F_P = F @ x.filtered_cov + J = hpsd_solve(x.predicted_cov, tmp_F_P, cholesky_matrix=predicted_cov_chol).T # Update the state estimate - smoothed_mean = y.filtered_mean + J @ (carry.smoothed_mean - y.predicted_mean) + smoothed_mean = x.filtered_mean + J @ (carry.smoothed_mean - x.predicted_mean) # Update the state covariance - smoothed_cov = y.filtered_cov + J @ (carry.smoothed_cov - y.predicted_cov) @ J.T + smoothed_cov = x.filtered_cov + J @ (carry.smoothed_cov - x.predicted_cov) @ J.T # Push-forward the smoothed distribution to the observation space - observation_dist = self.observation_fn(smoothed_mean, y.t) + observation_dist = self.observation_fn(smoothed_mean, x.t) R = observation_dist.covariance() - Hop = self.get_observation_jacobian(t=y.t, observation_size=y.observation_mean.size) + Hop = self.get_observation_jacobian(t=x.t, observation_size=np.shape(filter_result.observation_mean)[1]) H = Hop(smoothed_mean) if self.materialise_jacobians: H = H.to_dense() @@ -540,49 +561,46 @@ def _smooth_op(carry: Carry, y: FilterResult) -> Tuple[Carry, SmoothingResult]: smoothed_mean=smoothed_mean, smoothed_cov=smoothed_cov ), SmoothingResult( - t=y.t, + t=x.t, smoothed_mean=smoothed_mean, smoothed_cov=smoothed_cov, smoothed_obs_mean=smoothed_obs_mean, smoothed_obs_cov=smoothed_obs_cov ) + xs = XType( + 
t=filter_result.t, + filtered_mean=filter_result.filtered_mean, + filtered_cov=filter_result.filtered_cov, + predicted_mean=filter_result.predicted_mean, + predicted_cov=filter_result.predicted_cov + ) + if include_prior: + # prepend the initial state + t0 = xs.t[0] - self.dt + init_filter_state = self.create_initial_filter_state(t0=t0) + init_predict_state = self.incremental_predict(init_filter_state) + xs = XType( + t=jnp.concatenate([jnp.asarray(t0)[None], xs.t], axis=0), + filtered_mean=jnp.concatenate([init_filter_state.filtered_mean[None], xs.filtered_mean], axis=0), + filtered_cov=jnp.concatenate([init_filter_state.filtered_cov[None], xs.filtered_cov], axis=0), + predicted_mean=jnp.concatenate([init_predict_state.filtered_mean[None], xs.predicted_mean], axis=0), + predicted_cov=jnp.concatenate([init_predict_state.filtered_cov[None], xs.predicted_cov], axis=0) + ) + final_carry, smooth_results = lax.scan( f=_smooth_op, init=init_carry, - xs=filter_result, + xs=xs, reverse=True ) if include_prior: - # Transition prior to compute predicted p(z1 | z0) - t1 = filter_result.t[0] - init_predict_dist = self.transition_fn(self.initial_state_prior.mean(), t1) - init_Fop = self.get_transition_jacobian(t=t1) - init_F = init_Fop(self.initial_state_prior.mean()) - if self.materialise_jacobians: - init_F = init_F.to_dense() - init_predicted_mean = init_predict_dist.mean() - init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance() - - t0 = t1 - jnp.asarray(self.dt, t1.dtype) - y = FilterResult( - t=t0, - log_cumulative_marginal_likelihood=jnp.asarray(0.), - filtered_mean=self.initial_state_prior.mean(), - filtered_cov=self.initial_state_prior.covariance(), - predicted_mean=init_predicted_mean, - predicted_cov=init_predicted_cov, - observation_mean=filter_result.observation_mean[0], # dummy unused - observation_cov=filter_result.observation_cov[0] # dummy unused - ) - final_initial_prior_carry, _ = _smooth_op( - 
carry=final_carry, - y=y - ) + # Trim + smooth_results = jax.tree.map(lambda x: x[1:], smooth_results) smoothed_prior = InitialPrior( - mean=final_initial_prior_carry.smoothed_mean, - covariance=final_initial_prior_carry.smoothed_cov + mean=final_carry.smoothed_mean, + covariance=final_carry.smoothed_cov ) return smooth_results, smoothed_prior diff --git a/essm_jax/tests/test_essm.py b/essm_jax/tests/test_essm.py index da7f0cf..f55b72e 100644 --- a/essm_jax/tests/test_essm.py +++ b/essm_jax/tests/test_essm.py @@ -18,7 +18,7 @@ def test_extended_state_space_model(): num_time = 10 - def transition_fn(z, t): + def transition_fn(z, t, t_next): mean = 2 * z cov = jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) @@ -137,7 +137,7 @@ def _compare(key): def test_jvp_essm(): - def transition_fn(z, t): + def transition_fn(z, t, t_next): mean = jnp.sin(2 * z) cov = jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) @@ -193,7 +193,7 @@ def observation_fn(z, t): def test_speed_test_jvp_essm(): - def transition_fn(z, t): + def transition_fn(z, t, t_next): mean = jnp.sin(2 * z + t) cov = jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) @@ -245,14 +245,14 @@ def observation_fn(z, t): def test_essm_forward_simulation(): - def transition_fn(z, t): + def transition_fn(z, t, t_next): mean = z + jnp.sin(2 * jnp.pi * t / 10 * z) cov = 0.1 * jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) def observation_fn(z, t): mean = z - cov = t * 0.01 * jnp.eye(np.size(z)) + cov = 0.01 * jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) n = 1 @@ -273,17 +273,19 @@ def observation_fn(z, t): # Suppose we only observe every 3rd observation mask = jnp.arange(T) % 3 != 0 - # Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1]) - log_prob = essm.log_prob(samples.observation, mask=mask) - print(log_prob) - # Filtered 
latent distribution, p(z[t] | x[:t]) filter_result = essm.forward_filter(samples.observation, mask=mask) + assert np.all(np.isfinite(filter_result.log_cumulative_marginal_likelihood)) + assert np.all(np.isfinite(filter_result.filtered_mean)) + + # Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1]) + log_prob = essm.log_prob(samples.observation, mask=mask) + assert log_prob == filter_result.log_cumulative_marginal_likelihood[-1] # Smoothed latent distribution, p(z[t] | x[:]), i.e. past latents given all future observations # Including new estimate for prior state p(z[0]) smooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True) - print(smooth_result) + assert np.all(np.isfinite(smooth_result.smoothed_mean)) # Forward simulate the model forward_samples = essm.forward_simulate( @@ -322,7 +324,7 @@ def test__efficienct_add_scalar_diag(): def test_incremental_filtering(): - def transition_fn(z, t): + def transition_fn(z, t, t_next): mean = z + z * jnp.sin(2 * jnp.pi * t / 10) cov = 0.1 * jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) @@ -349,26 +351,14 @@ def observation_fn(z, t): filter_state = essm.create_initial_filter_state() - import pylab as plt - for i in range(100): + filter_state = essm.incremental_predict(filter_state) filter_state, _ = essm.incremental_update(filter_state, samples.observation[i]) - plt.scatter(filter_state.t, filter_state.filtered_mean, c='black') assert filter_state.t == filter_result.t[i] - filter_state = essm.incremental_predict(filter_state) - # print(i, np.abs( - # filter_state.log_cumulative_marginal_likelihood - filter_result.log_cumulative_marginal_likelihood[i])) - # print(i, np.max(np.abs(filter_state.filtered_mean - filter_result.predicted_mean[i]))) - # print(i, np.max(np.abs(filter_state.filtered_cov - filter_result.predicted_cov[i]))) - # print(i, np.max(np.abs(filter_state.filtered_cov))) 
np.testing.assert_allclose(filter_state.log_cumulative_marginal_likelihood, filter_result.log_cumulative_marginal_likelihood[i], atol=1e-5) - np.testing.assert_allclose(filter_state.filtered_mean, filter_result.predicted_mean[i], atol=1e-5) - np.testing.assert_allclose(filter_state.filtered_cov, filter_result.predicted_cov[i], atol=1e-5) - - plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent') - plt.legend() - plt.show() + np.testing.assert_allclose(filter_state.filtered_mean, filter_result.filtered_mean[i], atol=1e-5) + np.testing.assert_allclose(filter_state.filtered_cov, filter_result.filtered_cov[i], atol=1e-5) @pytest.mark.parametrize('use_sparse', [False, True]) @@ -386,7 +376,7 @@ def test_performance_sparse(use_sparse: bool): else: m = jnp.asarray(m) - def transition_fn(z, t): + def transition_fn(z, t, t_next): if use_sparse: mean = matvec_sparse(m, z) else: From d4948f66886b0d657c4c6ec663deea6b13d50da9 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Thu, 15 Aug 2024 01:36:57 +0300 Subject: [PATCH 6/8] * bump to 1.0.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 83220de..bf6d268 100755 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ long_description = fh.read() setup(name='essm_jax', - version='1.0.0', + version='1.0.1', description='Extended State Spapce Model in JAX', long_description=long_description, long_description_content_type="text/markdown", From 0c51ac825aca8f07f6e6bef591bf539655110273 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Thu, 15 Aug 2024 01:39:47 +0300 Subject: [PATCH 7/8] * fix test --- essm_jax/tests/test_essm.py | 2 +- essm_jax/tests/test_jvp_op.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/essm_jax/tests/test_essm.py b/essm_jax/tests/test_essm.py index f55b72e..e9ae0b8 100644 --- a/essm_jax/tests/test_essm.py +++ b/essm_jax/tests/test_essm.py @@ -3,9 +3,9 @@ import jax import pytest 
+jax.config.update('jax_enable_x64', True) from essm_jax.sparse import create_sparse_rep, matvec_sparse -jax.config.update('jax_enable_x64', True) import numpy as np import tensorflow_probability.substrates.jax as tfp from jax import numpy as jnp diff --git a/essm_jax/tests/test_jvp_op.py b/essm_jax/tests/test_jvp_op.py index fe3aaf6..5022f45 100644 --- a/essm_jax/tests/test_jvp_op.py +++ b/essm_jax/tests/test_jvp_op.py @@ -1,4 +1,7 @@ +import jax import jax.numpy as jnp + +jax.config.update('jax_enable_x64', True) import numpy as np import pytest From 0aa3bd4d154ea6805067049eca5dd55a365939c2 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Thu, 15 Aug 2024 01:43:29 +0300 Subject: [PATCH 8/8] * fix test --- essm_jax/tests/test_jvp_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/essm_jax/tests/test_jvp_op.py b/essm_jax/tests/test_jvp_op.py index 5022f45..1799632 100644 --- a/essm_jax/tests/test_jvp_op.py +++ b/essm_jax/tests/test_jvp_op.py @@ -16,7 +16,7 @@ def test_jvp_linear_op(): def fn(x): return jnp.asarray([jnp.sum(jnp.sin(x) ** i) for i in range(m)]) - x = jnp.arange(n).astype(jnp.float32) + x = jnp.arange(n).astype(float) jvp_op = JVPLinearOp(fn) jvp_op = jvp_op(x) @@ -73,8 +73,8 @@ def test_multiple_primals(init_primals: bool): def fn(x, y): return jnp.stack([x * y, y, -y], axis=-1) # [n, 3] - x = jnp.arange(n).astype(jnp.float32) - y = jnp.arange(n).astype(jnp.float32) + x = jnp.arange(n).astype(float) + y = jnp.arange(n).astype(float) if init_primals: jvp_op = JVPLinearOp(fn, primals=(x, y)) else: