1.0.1 #5

Merged: 9 commits, Aug 14, 2024
Changes from all commits
10 changes: 7 additions & 3 deletions README.md
@@ -21,7 +21,7 @@ from essm_jax.essm import ExtendedStateSpaceModel
tfpd = tfp.distributions


def transition_fn(z, t):
def transition_fn(z, t, t_next):
    mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)
    cov = 0.1 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
@@ -67,8 +67,7 @@ print(smooth_result)
forward_samples = essm.forward_simulate(
    key=jax.random.PRNGKey(0),
    num_time=25,
    observations=samples.observation,
    mask=mask
    filter_result=filter_result
)

import pylab as plt
@@ -86,9 +85,14 @@ plt.legend()
plt.show()
```
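The three-argument `transition_fn(z, t, t_next)` shown above is what carries the new arbitrary-dt support. A minimal sketch of a transition that actually consumes the step size; scaling the drift and diffusion by `dt` is an assumption for illustration, not taken from the diff:

```python
def transition_fn(z, t, t_next):
    dt = t_next - t  # assumed use of the extra argument: non-unit time steps
    mean = z + dt * jnp.sin(2 * jnp.pi * t / 10 * z)  # drift scaled by dt (assumption)
    cov = 0.1 * dt * jnp.eye(np.size(z))              # diffusion scaled by dt (assumption)
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
```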

## Online Filtering

Take a look at the [examples](./docs/examples) to learn how to do online filtering, e.g. for interactive applications.
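A minimal sketch of that incremental loop, using the API exercised in `test_incremental_filtering` further down this diff (the `essm` instance is assumed to be constructed as in the example above, and `observation_stream` is a hypothetical online source):

```python
filter_state = essm.create_initial_filter_state()
for obs in observation_stream:  # hypothetical stream of incoming observations
    filter_state = essm.incremental_predict(filter_state)         # predict step
    filter_state, _ = essm.incremental_update(filter_state, obs)  # update with new datum
```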

# Change Log

13 August 2024: Initial release 1.0.0.
14 August 2024: 1.0.1 released. Added sparse utils. Added an incremental API for online filtering. Transition functions now take `t_next`, enabling arbitrary time steps (dt).

## Star History

317 changes: 317 additions & 0 deletions docs/examples/excitable_damped_harmonic_oscillator.ipynb

Large diffs are not rendered by default.

196 changes: 196 additions & 0 deletions docs/examples/online_filtering.ipynb

Large diffs are not rendered by default.

426 changes: 258 additions & 168 deletions essm_jax/essm.py

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions essm_jax/pytee_utils.py
@@ -0,0 +1,43 @@
from typing import Tuple, Callable, TypeVar

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

PT = TypeVar('PT')


def pytree_unravel(example_tree: PT) -> Tuple[Callable[[PT], jax.Array], Callable[[jax.Array], PT]]:
    """
    Returns functions to ravel and unravel a pytree.

    Args:
        example_tree: a pytree to be unravelled

    Returns:
        ravel_fun: a function to ravel a pytree into a flat array
        unravel_fun: a function to unravel a flat array back into a pytree
    """
    leaf_list, tree_def = jax.tree.flatten(example_tree)

    sizes = [np.size(leaf) for leaf in leaf_list]
    shapes = [np.shape(leaf) for leaf in leaf_list]
    dtypes = [leaf.dtype for leaf in leaf_list]

    def ravel_fun(pytree: PT) -> jax.Array:
        leaf_list, tree_def = jax.tree.flatten(pytree)
        # promote types to a common one
        common_dtype = jnp.result_type(*dtypes)
        leaf_list = [leaf.astype(common_dtype) for leaf in leaf_list]
        return jnp.concatenate([lax.reshape(leaf, (size,)) for leaf, size in zip(leaf_list, sizes)])

    def unravel_fun(flat_array: jax.Array) -> PT:
        leaf_list = []
        start = 0
        for size, shape, dtype in zip(sizes, shapes, dtypes):
            leaf_list.append(lax.reshape(flat_array[start:start + size], shape).astype(dtype))
            start += size
        return jax.tree.unflatten(tree_def, leaf_list)

    return ravel_fun, unravel_fun
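A hypothetical usage sketch (not part of the diff; the module path follows the file header above):

```python
import jax.numpy as jnp
from essm_jax.pytee_utils import pytree_unravel  # path as in the diff header

tree = {'a': jnp.zeros((2, 3)), 'b': jnp.ones(4)}
ravel_fun, unravel_fun = pytree_unravel(tree)

flat = ravel_fun(tree)        # shape (10,): the 2*3 and 4 sized leaves, concatenated
restored = unravel_fun(flat)  # same structure, shapes and dtypes as `tree`
```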
68 changes: 68 additions & 0 deletions essm_jax/sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Tuple, NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class SparseRepresentation(NamedTuple):
    shape: Tuple[int, ...]
    rows: jax.Array
    cols: jax.Array
    vals: jax.Array


def create_sparse_rep(m: np.ndarray) -> SparseRepresentation:
    """
    Creates a sparse rep from matrix m. Use in linear models with materialise_jacobians=False for a ~2x speed-up.

    Args:
        m: [N, M] matrix

    Returns:
        sparse rep
    """
    rows, cols = np.where(m)
    sort_indices = np.lexsort((cols, rows))
    rows = rows[sort_indices]
    cols = cols[sort_indices]
    return SparseRepresentation(
        shape=np.shape(m),
        rows=jnp.asarray(rows),
        cols=jnp.asarray(cols),
        vals=jnp.asarray(m[rows, cols])
    )


def to_dense(m: SparseRepresentation, out: jax.Array | None = None) -> jax.Array:
    """
    Form dense matrix.

    Args:
        m: sparse rep
        out: output buffer

    Returns:
        out + M
    """
    if out is None:
        out = jnp.zeros(m.shape, m.vals.dtype)

    return out.at[m.rows, m.cols].add(m.vals, unique_indices=True, indices_are_sorted=True)


def matvec_sparse(m: SparseRepresentation, v: jax.Array, out: jax.Array | None = None) -> jax.Array:
    """
    Compute matmul for sparse rep. Speeds up large sparse linear models by about 2x.

    Args:
        m: sparse rep
        v: vec
        out: output buffer to add to.

    Returns:
        out + M @ v
    """
    if out is None:
        out = jnp.zeros(m.shape[0])
    return out.at[m.rows].add(m.vals * v[m.cols], unique_indices=True, indices_are_sorted=True)
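A hypothetical round-trip sketch (not part of the diff), checking `matvec_sparse` and `to_dense` against a dense reference:

```python
import numpy as np
import jax.numpy as jnp
from essm_jax.sparse import create_sparse_rep, matvec_sparse, to_dense

m = np.zeros((4, 4))
m[0, 1] = 2.0
m[3, 2] = -1.0
rep = create_sparse_rep(m)

v = jnp.arange(4.0)
np.testing.assert_allclose(matvec_sparse(rep, v), m @ v)  # out + M @ v, with out = 0
np.testing.assert_allclose(to_dense(rep), m)              # densify the rep
```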
143 changes: 127 additions & 16 deletions essm_jax/tests/test_essm.py
@@ -1,6 +1,11 @@
import time

import jax
import pytest

jax.config.update('jax_enable_x64', True)
from essm_jax.sparse import create_sparse_rep, matvec_sparse

import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp
@@ -13,7 +18,7 @@
def test_extended_state_space_model():
    num_time = 10

    def transition_fn(z, t):
    def transition_fn(z, t, t_next):
        mean = 2 * z
        cov = jnp.eye(np.size(z))
        return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
@@ -132,7 +137,7 @@ def _compare(key):


def test_jvp_essm():
    def transition_fn(z, t):
    def transition_fn(z, t, t_next):
        mean = jnp.sin(2 * z)
        cov = jnp.eye(np.size(z))
        return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
@@ -188,7 +193,7 @@ def observation_fn(z, t):


def test_speed_test_jvp_essm():
    def transition_fn(z, t):
    def transition_fn(z, t, t_next):
        mean = jnp.sin(2 * z + t)
        cov = jnp.eye(np.size(z))
        return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
@@ -217,19 +222,21 @@ def observation_fn(z, t):
)

    sample = essm.sample(jax.random.PRNGKey(0), 1000)
    filter_fn = jax.jit(lambda: essm.forward_filter(sample.observation)).lower().compile()
    filter_jvp_fn = jax.jit(lambda: essm_jvp.forward_filter(sample.observation)).lower().compile()
    filter_fn = jax.jit(
        lambda: essm.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile()
    filter_jvp_fn = jax.jit(
        lambda: essm_jvp.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile()

    t0 = time.time()
    filter_results = filter_fn()
    filter_results.t.block_until_ready()
    filter_results.block_until_ready()
    t1 = time.time()
    dt1 = t1 - t0
    print(f"Time for essm: {t1 - t0}")

    t0 = time.time()
    filter_results_jvp = filter_jvp_fn()
    filter_results_jvp.t.block_until_ready()
    filter_results_jvp.block_until_ready()
    t1 = time.time()
    dt2 = t1 - t0
    print(f"Time for essm_jvp: {t1 - t0}")
Expand All @@ -238,14 +245,14 @@ def observation_fn(z, t):


def test_essm_forward_simulation():
    def transition_fn(z, t):
    def transition_fn(z, t, t_next):
        mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)
        cov = 0.1 * jnp.eye(np.size(z))
        return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))

    def observation_fn(z, t):
        mean = z
        cov = t * 0.01 * jnp.eye(np.size(z))
        cov = 0.01 * jnp.eye(np.size(z))
        return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))

    n = 1
@@ -266,24 +273,25 @@ def observation_fn(z, t):
    # Suppose we only observe every 3rd observation
    mask = jnp.arange(T) % 3 != 0

    # Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1])
    log_prob = essm.log_prob(samples.observation, mask=mask)
    print(log_prob)

    # Filtered latent distribution, p(z[t] | x[:t])
    filter_result = essm.forward_filter(samples.observation, mask=mask)
    assert np.all(np.isfinite(filter_result.log_cumulative_marginal_likelihood))
    assert np.all(np.isfinite(filter_result.filtered_mean))

    # Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1])
    log_prob = essm.log_prob(samples.observation, mask=mask)
    assert log_prob == filter_result.log_cumulative_marginal_likelihood[-1]

    # Smoothed latent distribution, p(z[t] | x[:]), i.e. past latents given all future observations
    # Including new estimate for prior state p(z[0])
    smooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True)
    print(smooth_result)
    assert np.all(np.isfinite(smooth_result.smoothed_mean))

    # Forward simulate the model
    forward_samples = essm.forward_simulate(
        key=jax.random.PRNGKey(0),
        num_time=25,
        observations=samples.observation,
        mask=mask
        filter_result=filter_result
    )

    try:
@@ -313,3 +321,106 @@ def test__efficienct_add_scalar_diag():
    A = jnp.eye(100)
    c = 1.
    assert jnp.all(_efficient_add_scalar_diag(A, c) == A + c * jnp.eye(100))
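The property under test pins down what the helper must compute: `_efficient_add_scalar_diag(A, c) == A + c * I`. A minimal sketch of such a helper, assuming the obvious in-place diagonal update (the real implementation lives in `essm_jax.essm`):

```python
import jax.numpy as jnp

def _efficient_add_scalar_diag_sketch(A: jnp.ndarray, c: float) -> jnp.ndarray:
    # Add c to the diagonal without materialising c * jnp.eye(n).
    n = A.shape[-1]
    return A.at[jnp.diag_indices(n)].add(c)
```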


def test_incremental_filtering():
    def transition_fn(z, t, t_next):
        mean = z + z * jnp.sin(2 * jnp.pi * t / 10)
        cov = 0.1 * jnp.eye(np.size(z))
        return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))

    def observation_fn(z, t):
        mean = z
        cov = t * 0.01 * jnp.eye(np.size(z))
        return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))

    n = 1

    initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))

    essm = ExtendedStateSpaceModel(
        transition_fn=transition_fn,
        observation_fn=observation_fn,
        initial_state_prior=initial_state_prior,
        materialise_jacobians=False,  # Fast
        more_data_than_params=False  # if observation is bigger than latent we can speed it up.
    )
    samples = essm.sample(jax.random.PRNGKey(0), 100)

    filter_result = essm.forward_filter(samples.observation)

    filter_state = essm.create_initial_filter_state()

    for i in range(100):
        filter_state = essm.incremental_predict(filter_state)
        filter_state, _ = essm.incremental_update(filter_state, samples.observation[i])
        assert filter_state.t == filter_result.t[i]
        np.testing.assert_allclose(filter_state.log_cumulative_marginal_likelihood,
                                   filter_result.log_cumulative_marginal_likelihood[i], atol=1e-5)
        np.testing.assert_allclose(filter_state.filtered_mean, filter_result.filtered_mean[i], atol=1e-5)
        np.testing.assert_allclose(filter_state.filtered_cov, filter_result.filtered_cov[i], atol=1e-5)


@pytest.mark.parametrize('use_sparse', [False, True])
def test_performance_sparse(use_sparse: bool):
    # Show that using sparse rep speeds up linear system
    n = 128
    k = 10
    m = np.zeros((n, n))
    rows = np.random.randint(n, size=k)
    cols = np.random.randint(n, size=k)
    m[rows, cols] += 1.

    if use_sparse:
        m = create_sparse_rep(m)
    else:
        m = jnp.asarray(m)

    def transition_fn(z, t, t_next):
        if use_sparse:
            mean = matvec_sparse(m, z)
        else:
            mean = m @ z
        scale = jnp.ones(np.size(z))
        return tfpd.MultivariateNormalDiag(mean, scale)

    def observation_fn(z, t):
        mean = z
        scale = jnp.ones(np.size(z))
        return tfpd.MultivariateNormalDiag(mean, scale)

    initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))

    essm = ExtendedStateSpaceModel(
        transition_fn=transition_fn,
        observation_fn=observation_fn,
        initial_state_prior=initial_state_prior,
        materialise_jacobians=True
    )

    essm_jvp = ExtendedStateSpaceModel(
        transition_fn=transition_fn,
        observation_fn=observation_fn,
        initial_state_prior=initial_state_prior,
        materialise_jacobians=False
    )

    sample = essm.sample(jax.random.PRNGKey(0), 512)
    filter_fn = jax.jit(
        lambda: essm.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile()
    filter_jvp_fn = jax.jit(
        lambda: essm_jvp.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile()

    t0 = time.time()
    filter_results = filter_fn()
    filter_results.block_until_ready()
    t1 = time.time()
    dt1 = t1 - t0
    print(f"Time for essm(use_sparse={use_sparse}): {t1 - t0}")

    t0 = time.time()
    filter_results_jvp = filter_jvp_fn()
    filter_results_jvp.block_until_ready()
    t1 = time.time()
    dt2 = t1 - t0
    print(f"Time for essm_jvp(use_sparse={use_sparse}): {t1 - t0}")
9 changes: 6 additions & 3 deletions essm_jax/tests/test_jvp_op.py
@@ -1,4 +1,7 @@
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)
import numpy as np
import pytest

@@ -13,7 +16,7 @@ def test_jvp_linear_op():
    def fn(x):
        return jnp.asarray([jnp.sum(jnp.sin(x) ** i) for i in range(m)])

    x = jnp.arange(n).astype(jnp.float32)
    x = jnp.arange(n).astype(float)

    jvp_op = JVPLinearOp(fn)
    jvp_op = jvp_op(x)
@@ -70,8 +73,8 @@ def test_multiple_primals(init_primals: bool):
    def fn(x, y):
        return jnp.stack([x * y, y, -y], axis=-1)  # [n, 3]

    x = jnp.arange(n).astype(jnp.float32)
    y = jnp.arange(n).astype(jnp.float32)
    x = jnp.arange(n).astype(float)
    y = jnp.arange(n).astype(float)
    if init_primals:
        jvp_op = JVPLinearOp(fn, primals=(x, y))
    else: