Commit bf755c1
* Add dt
* Add examples
Joshuaalbert committed Aug 14, 2024
1 parent cb0ac81 commit bf755c1
Showing 5 changed files with 624 additions and 103 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -21,7 +21,7 @@ from essm_jax.essm import ExtendedStateSpaceModel
 tfpd = tfp.distributions
 
 
-def transition_fn(z, t):
+def transition_fn(z, t, t_next):
     mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)
     cov = 0.1 * jnp.eye(np.size(z))
     return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
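For context, here is how the new signature slots into the README's example — a minimal runnable sketch assuming the README's surrounding setup (the latent size `n`, the observation model, and the `dt` value are illustrative, not from the repo):

```python
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp

from essm_jax.essm import ExtendedStateSpaceModel

tfpd = tfp.distributions

n = 2  # illustrative latent size


def transition_fn(z, t, t_next):
    # t_next - t equals dt, so the dynamics may use either form
    mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)
    cov = 0.1 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))


def observation_fn(z, t):
    # illustrative identity observation model
    mean = z
    cov = jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))


essm = ExtendedStateSpaceModel(
    transition_fn=transition_fn,
    observation_fn=observation_fn,
    initial_state_prior=tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n)),
    dt=0.5  # new in this commit; sample times become t0 + (i + 1) * dt
)

samples = essm.sample(jax.random.PRNGKey(0), num_time=10, t0=0.)
print(samples.t)  # 0.5, 1.0, ..., 5.0
```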
317 changes: 317 additions & 0 deletions docs/examples/excitable_damped_harmonic_oscillator.ipynb

Large diffs are not rendered by default.

196 changes: 196 additions & 0 deletions docs/examples/online_filtering.ipynb

Large diffs are not rendered by default.

168 changes: 93 additions & 75 deletions essm_jax/essm.py
@@ -87,19 +87,19 @@ class ExtendedStateSpaceModel:
     Args:
         transition_fn: A function that computes the state transition distribution
-            p(z[t] | z[t-1], t). Must return a MultivariateNormalLinearOperator.
-            Call signature is `transition_fn(z[t-1], t)`, where z[t-1] is the previous state.
+            p(z(t') | z(t), t, t'). Must return a MultivariateNormalLinearOperator.
+            Call signature is `transition_fn(z(t), t, t')`, where z(t) is the previous state.
         observation_fn: A function that computes the observation distribution
-            p(x[t] | z[t], t). Must return a MultivariateNormalLinearOperator.
-            Call signature is `observation_fn(z[t], t)`, where z[t] is the current state.
-            Note: t in [1, num_time] is the observation time index, with t=0 being the initial state.
-        initial_state_prior: A distribution over the initial state p(z[0]).
+            p(x(t) | z(t), t). Must return a MultivariateNormalLinearOperator.
+            Call signature is `observation_fn(z(t), t)`, where z(t) is the current state.
+            Note: t is the observation time, with t=t0 being the initial state.
+        initial_state_prior: A distribution over the initial state p(z(t0)).
             Must be a MultivariateNormalLinearOperator.
         more_data_than_params: If True, the observation function has more outputs than inputs.
         materialise_jacobians: If True, the Jacobians are materialised as dense matrices.
-        dt: The time step size, default is 1.
+        dt: The time step size, default is 1. Observation times are t[i] = t0 + (i + 1) * dt.
     """
-    transition_fn: Callable[[jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator]
+    transition_fn: Callable[[jax.Array, jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator]
     observation_fn: Callable[[jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator]
     initial_state_prior: tfpd.MultivariateNormalLinearOperator
     more_data_than_params: bool = False
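The extra `t_next` argument is what makes non-unit `dt` meaningful: a transition density can now depend on the actual step size instead of assuming unit steps. A hypothetical illustration (not from this repo): an Ornstein-Uhlenbeck transition discretised with Euler-Maruyama, where `theta` and `sigma` are illustrative constants:

```python
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

tfpd = tfp.distributions

theta = 0.5  # mean-reversion rate (illustrative)
sigma = 0.2  # diffusion scale (illustrative)


def ou_transition_fn(z, t, t_next):
    dt = t_next - t  # the step size is now available explicitly
    mean = z - theta * z * dt
    cov = sigma ** 2 * dt * jnp.eye(z.size)
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
```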
@@ -119,9 +119,9 @@ def __post_init__(self):
         self.latent_size = np.size(_initial_state_prior_mean)
         self.latent_shape = np.shape(_initial_state_prior_mean)
 
-    def get_transition_jacobian(self, t: jax.Array) -> JVPLinearOp:
+    def get_transition_jacobian(self, t: jax.Array, t_next: jax.Array) -> JVPLinearOp:
         def _transition_fn(z):
-            return self.transition_fn(z, t).mean()
+            return self.transition_fn(z, t, t_next).mean()
 
         return JVPLinearOp(_transition_fn, more_outputs_than_inputs=False)
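Since `get_transition_jacobian` linearises the transition mean through a JVP-based operator, a quick sanity check is to compare it against a dense forward-mode Jacobian — a sketch reusing the `essm` instance from the README example above:

```python
import jax
import jax.numpy as jnp

z = jnp.ones(2)
t, t_next = jnp.asarray(0.), jnp.asarray(0.5)

Fop = essm.get_transition_jacobian(t, t_next)
F = Fop(z).to_dense()  # materialised Jacobian of the transition mean
F_ref = jax.jacfwd(lambda z_: essm.transition_fn(z_, t, t_next).mean())(z)
assert jnp.allclose(F, F_ref, atol=1e-6)
```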

@@ -134,15 +134,15 @@ def _observation_fn(z):
         more_data_than_params = self.latent_size < observation_size
         return JVPLinearOp(_observation_fn, more_outputs_than_inputs=more_data_than_params)
 
-    def transition_matrix(self, z, t):
-        Fop = self.get_transition_jacobian(t)
+    def transition_matrix(self, z, t, t_next):
+        Fop = self.get_transition_jacobian(t, t_next)
         return Fop(z).to_dense()
 
     def observation_matrix(self, z, t):
         Hop = self.get_observation_jacobian(t)
         return Hop(z).to_dense()
 
-    def sample(self, key, num_time: int, t0: Union[jax.Array, int] = 0) -> SampleResult:
+    def sample(self, key, num_time: int, t0: Union[jax.Array, float] = 0.) -> SampleResult:
         """
         Sample from the model.
@@ -162,19 +162,22 @@ def sample(self, key, num_time: int, t0: Union[jax.Array, int] = 0) -> SampleResult:
         # observation: S -> O
 
         def _sample_latents_op(latent, y):
-            (key, t) = y
+            (key, t, t_next) = y
             new_latent_key, obs_key = jax.random.split(key, 2)
-            transition_dist = self.transition_fn(latent, t)
+            transition_dist = self.transition_fn(latent, t, t_next)
             new_latent = transition_dist.sample(seed=new_latent_key)
             observation_dist = self.observation_fn(new_latent, t)
             new_observation = observation_dist.sample(seed=obs_key)
-            return new_latent, SampleResult(t=t, latent=new_latent, observation=new_observation)
+            return new_latent, SampleResult(t=t_next, latent=new_latent, observation=new_observation)
 
-        # Sample at t0
+        # Sample at t0 forming initial state
         init = self.initial_state_prior.sample(seed=init_key)
+        t_from = jnp.arange(0, num_time) * self.dt + t0
+        t_to = t_from + self.dt
         xs = (
             jax.random.split(latent_key, num_time),
-            jnp.arange(1, num_time + 1) * self.dt + t0
+            t_from,
+            t_to
         )
         _, samples = lax.scan(
             _sample_latents_op,
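To make the new time bookkeeping concrete, a worked instance of the grid built above (values illustrative):

```python
import jax.numpy as jnp

t0, dt, num_time = 0., 0.5, 3
t_from = jnp.arange(0, num_time) * dt + t0  # [0.0, 0.5, 1.0]
t_to = t_from + dt                          # [0.5, 1.0, 1.5]
# Step i transitions z(t_from[i]) -> z(t_to[i]) and stamps the emitted
# sample with t_to[i], i.e. samples.t == t0 + (i + 1) * dt.
```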
@@ -239,7 +242,8 @@ def forward_simulate(self, key: jax.Array, num_time: int,
             observation_fn=self.observation_fn,
             initial_state_prior=initial_state_prior,
             more_data_than_params=self.more_data_than_params,
-            materialise_jacobians=self.materialise_jacobians
+            materialise_jacobians=self.materialise_jacobians,
+            dt=self.dt
         )
         return new_essm.sample(key=key, num_time=num_time, t0=t0)
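The `dt=self.dt` line fixes a subtle omission: the rebuilt model used for forward simulation previously fell back to the default unit step. What the fixed construction amounts to, as a sketch (the `essm` instance is assumed from the README example; `forward_simulate` supplies `initial_state_prior` from the filtered state, which this hunk does not show):

```python
# Sketch of the copy performed above: every configuration field, now
# including dt, is carried over to the child model.
new_essm = ExtendedStateSpaceModel(
    transition_fn=essm.transition_fn,
    observation_fn=essm.observation_fn,
    initial_state_prior=essm.initial_state_prior,  # stand-in for the filtered-state prior
    more_data_than_params=essm.more_data_than_params,
    materialise_jacobians=essm.materialise_jacobians,
    dt=essm.dt  # the field this commit adds to the copy
)
assert new_essm.dt == essm.dt  # previously new_essm.dt was the default 1
```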

@@ -327,25 +331,42 @@ def incremental_predict(self, filter_state: IncrementalFilterState) -> IncrementalFilterState:
         """
         # Predict step, compute p(z[t+1] | x[:t])
 
-        t = filter_state.t + jnp.asarray(self.dt, filter_state.t.dtype)
+        t_next = filter_state.t + jnp.asarray(self.dt, filter_state.t.dtype)
 
-        Fop = self.get_transition_jacobian(t=t)
+        Fop = self.get_transition_jacobian(t=filter_state.t, t_next=t_next)
         F = Fop(filter_state.filtered_mean)
         if self.materialise_jacobians:
             F = F.to_dense()
 
-        predicted_dist = self.transition_fn(filter_state.filtered_mean, t)
+        predicted_dist = self.transition_fn(filter_state.filtered_mean, filter_state.t, t_next)
         predicted_mean = predicted_dist.mean()  # [latent_size]
         Q = predicted_dist.covariance()  # [latent_size, latent_size]
         predicted_cov = F @ filter_state.filtered_cov @ F.T + Q  # [latent_size, latent_size]
 
         return IncrementalFilterState(
-            t=t,
+            t=t_next,
             log_cumulative_marginal_likelihood=filter_state.log_cumulative_marginal_likelihood,
             filtered_mean=predicted_mean,
             filtered_cov=predicted_cov
         )
 
+    def create_filter_state(self, filter_result: FilterResult) -> IncrementalFilterState:
+        """
+        Create an incremental filter state from a filter result.
+
+        Args:
+            filter_result: the filter result
+
+        Returns:
+            the incremental filter state
+        """
+        return IncrementalFilterState(
+            t=filter_result.t[-1],
+            log_cumulative_marginal_likelihood=filter_result.log_cumulative_marginal_likelihood[-1],
+            filtered_mean=filter_result.filtered_mean[-1],
+            filtered_cov=filter_result.filtered_cov[-1]
+        )
+
     def create_initial_filter_state(self, t0: Union[jax.Array, float] = 0.) -> IncrementalFilterState:
         """
         Create an incremental filter at the initial time.
@@ -358,20 +379,11 @@ def create_initial_filter_state(self, t0: Union[jax.Array, float] = 0.) -> IncrementalFilterState:
         """
-        # Push forward initial state to create p(z[1] | z[0])
         t0 = jnp.asarray(t0, jnp.float32)
-        t1 = t0 + jnp.asarray(self.dt, t0.dtype)
-        init_predict_dist = self.transition_fn(self.initial_state_prior.mean(), t1)
-
-        init_Fop = self.get_transition_jacobian(t1)
-        init_F = init_Fop(self.initial_state_prior.mean())
-        if self.materialise_jacobians:
-            init_F = init_F.to_dense()
-        init_predicted_mean = init_predict_dist.mean()
-        init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance()
         return IncrementalFilterState(
-            t=t1,
+            t=t0,
             log_cumulative_marginal_likelihood=jnp.asarray(0.),
-            filtered_mean=init_predicted_mean,
-            filtered_cov=init_predicted_cov
+            filtered_mean=self.initial_state_prior.mean(),
+            filtered_cov=self.initial_state_prior.covariance()
         )
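With these two changes the incremental API reads naturally: `create_initial_filter_state` returns the prior at `t0` itself (no longer pushed forward one step), and each `incremental_predict` advances the state by one `dt`. A minimal open-loop forecasting sketch (the `essm` instance is assumed from earlier; the measurement-update step is omitted):

```python
state = essm.create_initial_filter_state(t0=0.)
# state.t == 0.; mean and covariance are the prior itself

predicted_means = []
for _ in range(5):
    state = essm.incremental_predict(state)  # advances state.t by essm.dt
    predicted_means.append(state.filtered_mean)
# state.t == 5 * essm.dt

# To resume incrementally after a batch pass, use the new create_filter_state:
# filter_result = essm.forward_filter(observations)
# state = essm.create_filter_state(filter_result)
```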

def forward_filter(self, observations: jax.Array, mask: Optional[jax.Array] = None,
Expand Down Expand Up @@ -433,6 +445,8 @@ def _filter_op(filter_state: IncrementalFilterState, y: YType) -> Tuple[Incremen
)

filter_state = self.create_initial_filter_state(t0=t0)
filter_state = self.incremental_predict(filter_state) # Advance to first update time

if mask is None:
_mask = jnp.zeros(num_time, dtype=jnp.bool_) # dummy variable (we skip the mask select)
else:
@@ -503,33 +517,40 @@ class Carry(NamedTuple):
             smoothed_cov=filter_result.predicted_cov[-1]  # [latent_size, latent_size] covariance of p(z[T] | x[:T])
         )
 
-        def _smooth_op(carry: Carry, y: FilterResult) -> Tuple[Carry, SmoothingResult]:
+        class XType(NamedTuple):
+            t: jax.Array  # time t
+            filtered_mean: jax.Array  # [latent_size] mean of p(z[t] | x[:t])
+            filtered_cov: jax.Array  # [latent_size, latent_size] covariance of p(z[t] | x[:t])
+            predicted_mean: jax.Array  # [latent_size] mean of p(z[t+1] | x[:t])
+            predicted_cov: jax.Array  # [latent_size, latent_size] covariance of p(z[t+1] | x[:t])
+
+        def _smooth_op(carry: Carry, x: XType) -> Tuple[Carry, SmoothingResult]:
             """
             A single step of the backward equations.
             """
-            Fop = self.get_transition_jacobian(t=y.t)
-            F = Fop(y.filtered_mean)
+            Fop = self.get_transition_jacobian(t=x.t - self.dt, t_next=x.t)
+            F = Fop(x.filtered_mean)
             if self.materialise_jacobians:
                 F = F.to_dense()
 
-            # Compute the RTS smoother gain
+            # Compute the Cholesky factor of the (jittered) predicted covariance
             # J = y.filtered_cov @ F.T @ jnp.linalg.inv(y.predicted_cov)
-            predicted_cov_chol = lax.linalg.cholesky(y.predicted_cov,
+            predicted_cov_chol = lax.linalg.cholesky(_efficient_add_scalar_diag(x.predicted_cov, 1e-6),
                                                      symmetrize_input=False)  # Possibly need to add a small diagonal jitter
-            tmp_F_P = F @ y.filtered_cov
-            J = hpsd_solve(y.predicted_cov, tmp_F_P, cholesky_matrix=predicted_cov_chol).T
+            tmp_F_P = F @ x.filtered_cov
+            J = hpsd_solve(x.predicted_cov, tmp_F_P, cholesky_matrix=predicted_cov_chol).T
 
             # Update the state estimate
-            smoothed_mean = y.filtered_mean + J @ (carry.smoothed_mean - y.predicted_mean)
+            smoothed_mean = x.filtered_mean + J @ (carry.smoothed_mean - x.predicted_mean)
 
             # Update the state covariance
-            smoothed_cov = y.filtered_cov + J @ (carry.smoothed_cov - y.predicted_cov) @ J.T
+            smoothed_cov = x.filtered_cov + J @ (carry.smoothed_cov - x.predicted_cov) @ J.T
 
             # Push-forward the smoothed distribution to the observation space
-            observation_dist = self.observation_fn(smoothed_mean, y.t)
+            observation_dist = self.observation_fn(smoothed_mean, x.t)
             R = observation_dist.covariance()
 
-            Hop = self.get_observation_jacobian(t=y.t, observation_size=y.observation_mean.size)
+            Hop = self.get_observation_jacobian(t=x.t, observation_size=np.shape(filter_result.observation_mean)[1])
             H = Hop(smoothed_mean)
             if self.materialise_jacobians:
                 H = H.to_dense()
@@ -540,49 +561,46 @@ def _smooth_op(carry: Carry, y: FilterResult) -> Tuple[Carry, SmoothingResult]:
                 smoothed_mean=smoothed_mean,
                 smoothed_cov=smoothed_cov
             ), SmoothingResult(
-                t=y.t,
+                t=x.t,
                 smoothed_mean=smoothed_mean,
                 smoothed_cov=smoothed_cov,
                 smoothed_obs_mean=smoothed_obs_mean,
                 smoothed_obs_cov=smoothed_obs_cov
             )
 
+        xs = XType(
+            t=filter_result.t,
+            filtered_mean=filter_result.filtered_mean,
+            filtered_cov=filter_result.filtered_cov,
+            predicted_mean=filter_result.predicted_mean,
+            predicted_cov=filter_result.predicted_cov
+        )
+        if include_prior:
+            # prepend the initial state
+            t0 = xs.t[0] - self.dt
+            init_filter_state = self.create_initial_filter_state(t0=t0)
+            init_predict_state = self.incremental_predict(init_filter_state)
+            xs = XType(
+                t=jnp.concatenate([jnp.asarray(t0)[None], xs.t], axis=0),
+                filtered_mean=jnp.concatenate([init_filter_state.filtered_mean[None], xs.filtered_mean], axis=0),
+                filtered_cov=jnp.concatenate([init_filter_state.filtered_cov[None], xs.filtered_cov], axis=0),
+                predicted_mean=jnp.concatenate([init_predict_state.filtered_mean[None], xs.predicted_mean], axis=0),
+                predicted_cov=jnp.concatenate([init_predict_state.filtered_cov[None], xs.predicted_cov], axis=0)
+            )
+
         final_carry, smooth_results = lax.scan(
             f=_smooth_op,
             init=init_carry,
-            xs=filter_result,
+            xs=xs,
             reverse=True
         )
 
-        if include_prior:
-            # Transition prior to compute predicted p(z1 | z0)
-            t1 = filter_result.t[0]
-            init_predict_dist = self.transition_fn(self.initial_state_prior.mean(), t1)
-            init_Fop = self.get_transition_jacobian(t=t1)
-            init_F = init_Fop(self.initial_state_prior.mean())
-            if self.materialise_jacobians:
-                init_F = init_F.to_dense()
-            init_predicted_mean = init_predict_dist.mean()
-            init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance()
-
-            t0 = t1 - jnp.asarray(self.dt, t1.dtype)
-            y = FilterResult(
-                t=t0,
-                log_cumulative_marginal_likelihood=jnp.asarray(0.),
-                filtered_mean=self.initial_state_prior.mean(),
-                filtered_cov=self.initial_state_prior.covariance(),
-                predicted_mean=init_predicted_mean,
-                predicted_cov=init_predicted_cov,
-                observation_mean=filter_result.observation_mean[0],  # dummy unused
-                observation_cov=filter_result.observation_cov[0]  # dummy unused
-            )
-            final_initial_prior_carry, _ = _smooth_op(
-                carry=final_carry,
-                y=y
-            )
+        # Trim
+        smooth_results = jax.tree.map(lambda x: x[1:], smooth_results)
         smoothed_prior = InitialPrior(
-            mean=final_initial_prior_carry.smoothed_mean,
-            covariance=final_initial_prior_carry.smoothed_cov
+            mean=final_carry.smoothed_mean,
+            covariance=final_carry.smoothed_cov
        )
        return smooth_results, smoothed_prior
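For reference, the backward pass above is the standard (extended) Rauch-Tung-Striebel recursion, with the transition Jacobian F linearised over (t - dt, t) and the gain computed by a jittered Cholesky solve rather than an explicit inverse:

```latex
J_t     = P_{t|t}\, F_t^\top \left(P_{t+1|t} + 10^{-6} I\right)^{-1} \\
m_{t|T} = m_{t|t} + J_t \left(m_{t+1|T} - m_{t+1|t}\right) \\
P_{t|T} = P_{t|t} + J_t \left(P_{t+1|T} - P_{t+1|t}\right) J_t^\top
```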

