Skip to content

Commit

Permalink
* bump to 2.6.6
Browse files Browse the repository at this point in the history
* improve plotting saving
* improve gradient guiding efficiency.
  • Loading branch information
Joshuaalbert committed Nov 13, 2024
1 parent 47798ea commit 9098db8
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 19 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ before importing JAXNS.

# Change Log

13 Nov, 2024 -- JAXNS 2.6.6 released. Minor improvements to plotting.

9 Nov, 2024 -- JAXNS 2.6.5 released. Added gradient guided nested sampling. Removed `num_parallel_workers` in favour
`devices`.

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 8 additions & 6 deletions benchmarks/gh136/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,23 +79,25 @@ def get_data(ndims):

def main():
jaxns_version = pkg_resources.get_distribution("jaxns").version
m = 3
d = 32
m = 90
d = 16

data = get_data(d)

# Row 1: Plot logZ error for gradient guided vs baseline for different s, with errorbars
# Row 2: Plot H error for gradient guided vs baseline for different s, with errorbars
# Row 3: Plot time taken for gradient guided vs baseline for different s, with errorbars

s_array = [10, 20, 30, 40, 80, 120]
s_array = [0.5, 1, 2, 3, 4, 5]

run_model_baseline_aot_array = [
jax.jit(build_run_model(num_slices=s, gradient_guided=False, ndims=d)).lower(jax.random.PRNGKey(0), *data).compile() for
jax.jit(build_run_model(num_slices=int(s * d), gradient_guided=False, ndims=d)).lower(jax.random.PRNGKey(0),
*data).compile() for
s in
s_array]
run_model_gg_aot_array = [
jax.jit(build_run_model(num_slices=s, gradient_guided=True, ndims=d)).lower(jax.random.PRNGKey(0), *data).compile() for s
jax.jit(build_run_model(num_slices=int(s * d), gradient_guided=True, ndims=d)).lower(jax.random.PRNGKey(0),
*data).compile() for s
in
s_array]

Expand Down Expand Up @@ -158,7 +160,7 @@ def main():
axs[2].fill_between(s_array, dt_mean[:, 1] - dt_std[:, 1], dt_mean[:, 1] + dt_std[:, 1], color='r', alpha=0.2)
axs[2].set_ylabel("Time taken")
axs[2].legend()
axs[2].set_xlabel(r"number of slices")
axs[2].set_xlabel(r"s, slices per dim")

axs[0].set_title(f"Gradient guided vs baseline, D={d}, v{jaxns_version}")

Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
project = "jaxns"
copyright = "2024, Joshua G. Albert"
author = "Joshua G. Albert"
release = "2.6.5"
release = "2.6.6"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "jaxns"
version = "2.6.5"
version = "2.6.6"
description = "Nested Sampling in JAX"
readme = "README.md"
requires-python = ">=3.9"
Expand Down
8 changes: 6 additions & 2 deletions src/jaxns/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def plot_diagnostics(results: NestedSamplerResults, save_name=None):
axs[5].set_xlabel(r'$- \log X$')
if save_name is not None:
fig.savefig(save_name, bbox_inches='tight', dpi=300, pad_inches=0.0)
plt.show()
plt.close(fig)
else:
plt.show()


def plot_cornerplot(results: NestedSamplerResults, variables: Optional[List[str]] = None,
Expand Down Expand Up @@ -279,7 +281,9 @@ def plot_cornerplot(results: NestedSamplerResults, variables: Optional[List[str]
# Save the figure
if save_name is not None:
fig.savefig(save_name, bbox_inches='tight', dpi=300, pad_inches=0.0)
plt.show()
plt.close(fig)
else:
plt.show()


def weighted_percentile(samples: np.ndarray, log_weights: np.ndarray,
Expand Down
10 changes: 2 additions & 8 deletions src/jaxns/samplers/uni_slice_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,21 +258,15 @@ def body(carry: Carry) -> Carry:
grad = grad_fn(carry.point_U)
num_likelihood_evaluations += jnp.ones_like(num_likelihood_evaluations)
grad_norm = jnp.linalg.norm(grad)
grad_mask = jnp.bitwise_or(jnp.equal(grad_norm, jnp.zeros_like(grad_norm)), ~jnp.isfinite(grad_norm))
grad_mask = jnp.bitwise_or(grad_norm < jnp.asarray(1e-10, grad_norm.dtype), ~jnp.isfinite(grad_norm))
grad = grad / grad_norm

reflect_direction = direction - 2 * tree_dot(direction, grad) * grad
reflect_direction /= jnp.linalg.norm(reflect_direction)

random_direction = _sample_direction(after_key1, direction.size)

choose_dir = jax.random.randint(after_key2, shape=(), minval=0, maxval=2)
direction = jnp.where(
choose_dir == 0,
reflect_direction,
random_direction
)
direction = jnp.where(grad_mask, random_direction, direction)
direction = jnp.where(grad_mask, random_direction, reflect_direction)
else:
# Randomly choose a new direction
direction = _sample_direction(after_key, direction.size)
Expand Down
4 changes: 3 additions & 1 deletion src/jaxns/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,9 @@ def summary(results: NestedSamplerResults, with_parametrised: bool = False, f_ob
main_s = []

def _print(s):
print(s)
if f_obj is None:
# It goes to file instead
print(s)
main_s.append(s)

def _round(v, uncert_v):
Expand Down

0 comments on commit 9098db8

Please sign in to comment.