diff --git a/README.md b/README.md index f68b1b6..dfec6dd 100644 --- a/README.md +++ b/README.md @@ -359,6 +359,8 @@ before importing JAXNS. # Change Log +13 Nov, 2024 -- JAXNS 2.6.6 released. Minor improvements to plotting. + 9 Nov, 2024 -- JAXNS 2.6.5 released. Added gradient guided nested sampling. Removed `num_parallel_workers` in favour `devices`. diff --git a/benchmarks/gh136/Gradient_guided_vs_baseline_D16_v2.6.5.png b/benchmarks/gh136/Gradient_guided_vs_baseline_D16_v2.6.5.png new file mode 100644 index 0000000..6163421 Binary files /dev/null and b/benchmarks/gh136/Gradient_guided_vs_baseline_D16_v2.6.5.png differ diff --git a/benchmarks/gh136/Gradient_guided_vs_baseline_D8_v2.6.5.png b/benchmarks/gh136/Gradient_guided_vs_baseline_D8_v2.6.5.png new file mode 100644 index 0000000..993cece Binary files /dev/null and b/benchmarks/gh136/Gradient_guided_vs_baseline_D8_v2.6.5.png differ diff --git a/benchmarks/gh136/main.py b/benchmarks/gh136/main.py index c39c198..30555fb 100644 --- a/benchmarks/gh136/main.py +++ b/benchmarks/gh136/main.py @@ -79,8 +79,8 @@ def get_data(ndims): def main(): jaxns_version = pkg_resources.get_distribution("jaxns").version - m = 3 - d = 32 + m = 90 + d = 16 data = get_data(d) @@ -88,14 +88,16 @@ def main(): # Row 2: Plot H error for gradient guided vs baseline for different s, with errorbars # Row 3: Plot time taken for gradient guided vs baseline for different s, with errorbars - s_array = [10, 20, 30, 40, 80, 120] + s_array = [0.5, 1, 2, 3, 4, 5] run_model_baseline_aot_array = [ - jax.jit(build_run_model(num_slices=s, gradient_guided=False, ndims=d)).lower(jax.random.PRNGKey(0), *data).compile() for + jax.jit(build_run_model(num_slices=int(s * d), gradient_guided=False, ndims=d)).lower(jax.random.PRNGKey(0), + *data).compile() for s in s_array] run_model_gg_aot_array = [ - jax.jit(build_run_model(num_slices=s, gradient_guided=True, ndims=d)).lower(jax.random.PRNGKey(0), *data).compile() for s + jax.jit(build_run_model(num_slices=int(s * d), gradient_guided=True, ndims=d)).lower(jax.random.PRNGKey(0), + *data).compile() for s in s_array] @@ -158,7 +160,7 @@ def main(): axs[2].fill_between(s_array, dt_mean[:, 1] - dt_std[:, 1], dt_mean[:, 1] + dt_std[:, 1], color='r', alpha=0.2) axs[2].set_ylabel("Time taken") axs[2].legend() - axs[2].set_xlabel(r"number of slices") + axs[2].set_xlabel(r"s, slices per dim") axs[0].set_title(f"Gradient guided vs baseline, D={d}, v{jaxns_version}") diff --git a/docs/conf.py b/docs/conf.py index 7245a30..4080030 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,7 +12,7 @@ project = "jaxns" copyright = "2024, Joshua G. Albert" author = "Joshua G. Albert" -release = "2.6.5" +release = "2.6.6" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 10cc738..2ee32b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" [project] name = "jaxns" -version = "2.6.5" +version = "2.6.6" description = "Nested Sampling in JAX" readme = "README.md" requires-python = ">=3.9" diff --git a/src/jaxns/plotting.py b/src/jaxns/plotting.py index 84f9c62..f6bea6f 100644 --- a/src/jaxns/plotting.py +++ b/src/jaxns/plotting.py @@ -89,7 +89,9 @@ def plot_diagnostics(results: NestedSamplerResults, save_name=None): axs[5].set_xlabel(r'$- \log X$') if save_name is not None: fig.savefig(save_name, bbox_inches='tight', dpi=300, pad_inches=0.0) - plt.show() + plt.close(fig) + else: + plt.show() def plot_cornerplot(results: NestedSamplerResults, variables: Optional[List[str]] = None, @@ -279,7 +281,9 @@ def plot_cornerplot(results: NestedSamplerResults, variables: Optional[List[str] # Save the figure if save_name is not None: fig.savefig(save_name, bbox_inches='tight', dpi=300, pad_inches=0.0) - plt.show() + plt.close(fig) + else: + plt.show() def weighted_percentile(samples: np.ndarray, log_weights: np.ndarray, diff --git a/src/jaxns/samplers/uni_slice_sampler.py b/src/jaxns/samplers/uni_slice_sampler.py index 7233de1..30622d5 100644 --- a/src/jaxns/samplers/uni_slice_sampler.py +++ b/src/jaxns/samplers/uni_slice_sampler.py @@ -258,7 +258,7 @@ def body(carry: Carry) -> Carry: grad = grad_fn(carry.point_U) num_likelihood_evaluations += jnp.ones_like(num_likelihood_evaluations) grad_norm = jnp.linalg.norm(grad) - grad_mask = jnp.bitwise_or(jnp.equal(grad_norm, jnp.zeros_like(grad_norm)), ~jnp.isfinite(grad_norm)) + grad_mask = jnp.bitwise_or(grad_norm < jnp.asarray(1e-10, grad_norm.dtype), ~jnp.isfinite(grad_norm)) grad = grad / grad_norm reflect_direction = direction - 2 * tree_dot(direction, grad) * grad @@ -266,13 +266,7 @@ def body(carry: Carry) -> Carry: random_direction = _sample_direction(after_key1, direction.size) - choose_dir = jax.random.randint(after_key2, shape=(), minval=0, maxval=2) - direction = jnp.where( - choose_dir == 0, - reflect_direction, - random_direction - ) - direction = jnp.where(grad_mask, random_direction, direction) + direction = jnp.where(grad_mask, random_direction, reflect_direction) else: # Randomly choose a new direction direction = _sample_direction(after_key, direction.size) diff --git a/src/jaxns/utils.py b/src/jaxns/utils.py index 855418c..535919e 100644 --- a/src/jaxns/utils.py +++ b/src/jaxns/utils.py @@ -321,7 +321,9 @@ def summary(results: NestedSamplerResults, with_parametrised: bool = False, f_ob main_s = [] def _print(s): - print(s) + if f_obj is None: + # It goes to file instead + print(s) main_s.append(s) def _round(v, uncert_v):