Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Nov 9, 2024
2 parents 76b3ae3 + 6bae822 commit 47798ea
Show file tree
Hide file tree
Showing 14 changed files with 1,445 additions and 159 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ JAXNS is:

What can you do with JAXNS?

1) Compute the Bayesian evidence of a model or hypothesis (the ultimate scientific method);
1) Compute the Bayesian evidence of a model or hypothesis (the ultimate scientific method);
2) Produce high-quality samples from the posterior distribution;
3) Easily handle degenerate difficult multi-modal posteriors;
4) Model both discrete and continuous priors and likelihoods;
Expand Down Expand Up @@ -359,6 +359,9 @@ before importing JAXNS.

# Change Log

9 Nov, 2024 -- JAXNS 2.6.5 released. Added gradient guided nested sampling. Removed `num_parallel_workers` in favour
`devices`.

4 Nov, 2024 -- JAXNS 2.6.4 released. Resolved bias when using phantom points.

1 Oct, 2024 -- JAXNS 2.6.3 released. Enable pytrees in context.
Expand Down
171 changes: 171 additions & 0 deletions benchmarks/gh136/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pkg_resources
import tensorflow_probability.substrates.jax as tfp
from jax._src.scipy.linalg import solve_triangular

from jaxns import Model, Prior

try:
from jaxns import NestedSampler
except ImportError:
from jaxns import DefaultNestedSampler as NestedSampler

tfpd = tfp.distributions


def build_run_model(num_slices, gradient_guided, ndims):
def run_model(key, prior_mu, prior_cov, data_mu, data_cov):
def log_normal(x, mean, cov):
L = jnp.linalg.cholesky(cov)
dx = x - mean
dx = solve_triangular(L, dx, lower=True)
return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) \
- 0.5 * dx @ dx

true_logZ = log_normal(data_mu, prior_mu, prior_cov + data_cov)

J = jnp.linalg.solve(data_cov + prior_cov, prior_cov)
post_mu = prior_mu + J.T @ (data_mu - prior_mu)
post_cov = prior_cov - J.T @ (prior_cov + data_cov) @ J

# print("True logZ={}".format(true_logZ))
# print("True post_mu={}".format(post_mu))
# print("True post_cov={}".format(post_cov))

# KL posterior || prior
dist_posterior = tfpd.MultivariateNormalFullCovariance(loc=post_mu, covariance_matrix=post_cov)
dist_prior = tfpd.MultivariateNormalFullCovariance(loc=prior_mu, covariance_matrix=prior_cov)
H_true = -tfp.distributions.kl_divergence(dist_posterior, dist_prior)

# print("True H={}".format(H_true))

def prior_model():
x = yield Prior(
tfpd.MultivariateNormalTriL(loc=prior_mu, scale_tril=jnp.linalg.cholesky(prior_cov)),
name='x')
return x

def log_likelihood(x):
return tfpd.MultivariateNormalTriL(loc=data_mu, scale_tril=jnp.linalg.cholesky(data_cov)).log_prob(x)

model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

ns = NestedSampler(model=model, verbose=False, k=0, num_slices=num_slices, gradient_guided=gradient_guided)

termination_reason, state = ns(key)
results = ns.to_results(termination_reason=termination_reason, state=state, trim=False)

error = results.H_mean - H_true
log_Z_error = results.log_Z_mean - true_logZ
return results.H_mean, H_true, error, log_Z_error

return run_model


def get_data(ndims):
prior_mu = 15 * jnp.ones(ndims)
prior_cov = jnp.diag(jnp.ones(ndims)) ** 2

data_mu = jnp.zeros(ndims)
data_cov = jnp.diag(jnp.ones(ndims)) ** 2
data_cov = jnp.where(data_cov == 0., 0.99, data_cov)
return prior_mu, prior_cov, data_mu, data_cov


def main():
jaxns_version = pkg_resources.get_distribution("jaxns").version
m = 3
d = 32

data = get_data(d)

# Row 1: Plot logZ error for gradient guided vs baseline for different s, with errorbars
# Row 2: Plot H error for gradient guided vs baseline for different s, with errorbars
# Row 3: Plot time taken for gradient guided vs baseline for different s, with errorbars

s_array = [10, 20, 30, 40, 80, 120]

run_model_baseline_aot_array = [
jax.jit(build_run_model(num_slices=s, gradient_guided=False, ndims=d)).lower(jax.random.PRNGKey(0), *data).compile() for
s in
s_array]
run_model_gg_aot_array = [
jax.jit(build_run_model(num_slices=s, gradient_guided=True, ndims=d)).lower(jax.random.PRNGKey(0), *data).compile() for s
in
s_array]

H_errors = np.zeros((len(s_array), m, 2))
log_z_errors = np.zeros((len(s_array), m, 2))
dt = np.zeros((len(s_array), m, 2))

for s_idx in range(len(s_array)):
s = s_array[s_idx]
for i in range(m):
key = jax.random.PRNGKey(i * 42)
baseline_model = run_model_baseline_aot_array[s_idx]
gg_model = run_model_gg_aot_array[s_idx]
t0 = time.time()
H, H_true, H_error, log_Z_error = jax.block_until_ready(baseline_model(key, *data))
t1 = time.time()
dt[s_idx, i, 0] = t1 - t0
H_errors[s_idx, i, 0] = H_error
log_z_errors[s_idx, i, 0] = log_Z_error
print(f"Baseline: i={i} k=0 s={s} H={H} H_true={H_true} H_error={H_error} log_Z_error={log_Z_error}")
t0 = time.time()
H, H_true, H_error, log_Z_error = jax.block_until_ready(gg_model(key, *data))
t1 = time.time()
dt[s_idx, i, 1] = t1 - t0
H_errors[s_idx, i, 1] = H_error
log_z_errors[s_idx, i, 1] = log_Z_error
print(f"GG: i={i} k=0 s={s} H={H} H_true={H_true} H_error={H_error} log_Z_error={log_Z_error}")

fig, axs = plt.subplots(3, 1, figsize=(10, 15), sharex=True)
# Row 1
H_error_mean = np.mean(H_errors, axis=1) # [s, 2]
H_error_std = np.std(H_errors, axis=1) # [s, 2]
axs[0].plot(s_array, H_error_mean[:, 0], label="Baseline", c='b')
axs[0].plot(s_array, H_error_mean[:, 1], label="Gradient Guided", c='r')
axs[0].fill_between(s_array, H_error_mean[:, 0] - H_error_std[:, 0], H_error_mean[:, 0] + H_error_std[:, 0],
color='b', alpha=0.2)
axs[0].fill_between(s_array, H_error_mean[:, 1] - H_error_std[:, 1], H_error_mean[:, 1] + H_error_std[:, 1],
color='r', alpha=0.2)
axs[0].set_ylabel("H error")
axs[0].legend()

# Row 2
logZ_error_mean = np.mean(log_z_errors, axis=1) # [s, 2]
logZ_error_std = np.std(log_z_errors, axis=1) # [s, 2]
axs[1].plot(s_array, logZ_error_mean[:, 0], label="Baseline", c='b')
axs[1].plot(s_array, logZ_error_mean[:, 1], label="Gradient Guided", c='r')
axs[1].fill_between(s_array, logZ_error_mean[:, 0] - logZ_error_std[:, 0],
logZ_error_mean[:, 0] + logZ_error_std[:, 0], color='b', alpha=0.2)
axs[1].fill_between(s_array, logZ_error_mean[:, 1] - logZ_error_std[:, 1],
logZ_error_mean[:, 1] + logZ_error_std[:, 1], color='r', alpha=0.2)
axs[1].set_ylabel("logZ error")
axs[1].legend()

# Row 3
dt_mean = np.mean(dt, axis=1) # [s, 2]
dt_std = np.std(dt, axis=1) # [s, 2]
axs[2].plot(s_array, dt_mean[:, 0], label="Baseline", c='b')
axs[2].plot(s_array, dt_mean[:, 1], label="Gradient Guided", c='r')
axs[2].fill_between(s_array, dt_mean[:, 0] - dt_std[:, 0], dt_mean[:, 0] + dt_std[:, 0], color='b', alpha=0.2)
axs[2].fill_between(s_array, dt_mean[:, 1] - dt_std[:, 1], dt_mean[:, 1] + dt_std[:, 1], color='r', alpha=0.2)
axs[2].set_ylabel("Time taken")
axs[2].legend()
axs[2].set_xlabel(r"number of slices")

axs[0].set_title(f"Gradient guided vs baseline, D={d}, v{jaxns_version}")

plt.savefig(f"Gradient_guided_vs_baseline_D{d}_v{jaxns_version}.png")

plt.show()


if __name__ == '__main__':
main()
428 changes: 428 additions & 0 deletions docs/examples/efficient_parameter_estimation.ipynb

Large diffs are not rendered by default.

334 changes: 334 additions & 0 deletions docs/examples/gradient_guided.ipynb

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions src/jaxns/internals/pytree_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import jax
import jax.numpy as jnp


def tree_dot(x, y):
dots = jax.tree.leaves(jax.tree.map(jnp.vdot, x, y))
return sum(dots[1:], start=dots[0])


def tree_norm(x):
norm2 = tree_dot(x, x)
if jnp.issubdtype(norm2.dtype, jnp.complexfloating):
return jnp.sqrt(norm2.real)
return jnp.sqrt(norm2)


def tree_mul(x, y):
return jax.tree.map(jax.lax.mul, x, y)


def tree_sub(x, y):
return jax.tree.map(jax.lax.sub, x, y)


def tree_div(x, y):
return jax.tree.map(jax.lax.div, x, y)
10 changes: 10 additions & 0 deletions src/jaxns/internals/random.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

import jax
from jax import random, numpy as jnp
from jax.scipy import special

Expand Down Expand Up @@ -72,3 +73,12 @@ def resample_indicies(key: PRNGKey, log_weights: Optional[FloatArray] = None, S:
g = -random.gumbel(key, shape=(num_total,))
idx = jnp.argsort(g)[:S]
return idx


def sample_uniformly_masked(key, v, select_mask, num_samples: int, squeeze: bool = False):
# If no satisfied samples, then chooses randomly from them. Should never happen, but good to know.
log_weights = jnp.where(select_mask, 0., -jnp.inf)
sample_idxs = resample_indicies(key, log_weights=log_weights, S=num_samples, replace=True)
if squeeze:
sample_idxs = jnp.squeeze(sample_idxs)
return jax.tree.map(lambda x: x[sample_idxs], v)
3 changes: 2 additions & 1 deletion src/jaxns/nested_samplers/common/initialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,6 @@ def create_init_termination_register() -> TerminationRegister:
plateau=jnp.asarray(False, jnp.bool_),
no_seed_points=jnp.asarray(False, jnp.bool_),
relative_spread=jnp.asarray(jnp.inf, mp_policy.measure_dtype),
absolute_spread=jnp.asarray(jnp.inf, mp_policy.measure_dtype)
absolute_spread=jnp.asarray(jnp.inf, mp_policy.measure_dtype),
peak_log_XL=jnp.asarray(-jnp.inf, mp_policy.measure_dtype)
)
8 changes: 8 additions & 0 deletions src/jaxns/nested_samplers/common/termination.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def determine_termination(
8-bit -> 256: relative spread of live points < rtol
9-bit -> 512: absolute spread of live points < atol
10-bit -> 1024: no seed points left
11-bit -> 2048: XL < max(XL) * peak_XL_frac
Multiple flags are summed together
Expand Down Expand Up @@ -136,4 +137,11 @@ def _set_done_bit(bit_done, bit_reason, done, termination_reason):
done, termination_reason = _set_done_bit(termination_register.no_seed_points, 10,
done=done, termination_reason=termination_reason)

if term_cond.peak_XL_frac is not None:
log_XL = termination_register.evidence_calc.log_X_mean + termination_register.evidence_calc.log_L
peak_log_XL = termination_register.peak_log_XL
XL_reduction_reached = log_XL < peak_log_XL + jnp.log(term_cond.peak_XL_frac)
done, termination_reason = _set_done_bit(XL_reduction_reached, 11,
done=done, termination_reason=termination_reason)

return done, termination_reason
2 changes: 2 additions & 0 deletions src/jaxns/nested_samplers/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class TerminationCondition(NamedTuple):
efficiency_threshold: Optional[FloatArray] = None
rtol: Optional[FloatArray] = None
atol: Optional[FloatArray] = None
peak_XL_frac: Optional[FloatArray] = None

def __and__(self, other):
return TerminationConditionConjunction(conds=[self, other])
Expand Down Expand Up @@ -134,6 +135,7 @@ class TerminationRegister(NamedTuple):
no_seed_points: BoolArray
relative_spread: FloatArray
absolute_spread: FloatArray
peak_log_XL: FloatArray


class NestedSamplerState(NamedTuple):
Expand Down
Loading

0 comments on commit 47798ea

Please sign in to comment.