Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #214

Merged
merged 13 commits into from
Dec 7, 2024
7 changes: 2 additions & 5 deletions .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
pip install -r requirements.txt
pip install -r requirements-tests.txt
pip install .
pip install .[tests]
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand All @@ -39,4 +36,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
pytest -s
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ before importing JAXNS.

# Change Log

7 Dec, 2024 -- JAXNS 2.6.7 released. Fix pip dependencies install.

13 Nov, 2024 -- JAXNS 2.6.6 released. Minor improvements to plotting.

9 Nov, 2024 -- JAXNS 2.6.5 released. Added gradient guided nested sampling. Removed `num_parallel_workers` in favour
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
project = "jaxns"
copyright = "2024, Joshua G. Albert"
author = "Joshua G. Albert"
release = "2.6.6"
release = "2.6.7"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
7 changes: 2 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "jaxns"
version = "2.6.6"
version = "2.6.7"
description = "Nested Sampling in JAX"
readme = "README.md"
requires-python = ">=3.9"
Expand All @@ -19,10 +19,7 @@ classifiers = [
"Operating System :: OS Independent"
]
urls = { "Homepage" = "https://github.com/joshuaalbert/jaxns" }

[project.optional-dependencies]
# Define the extras here; they will be loaded dynamically from setup.py
notebooks = [] # Placeholders; extras will load from setup.py
dynamic = ["dependencies", "optional-dependencies"]

[tool.setuptools]
include-package-data = true
Expand Down
3 changes: 2 additions & 1 deletion requirements-tests.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
scikit-learn
networkx
psutil
pytest
pytest
flake8
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ def load_requirements(file_name):
install_requires=load_requirements("requirements.txt"),
extras_require={
"examples": load_requirements("requirements-examples.txt"),
},
tests_require=load_requirements("requirements-tests.txt"),
"tests": load_requirements("requirements-tests.txt"),
}
)
18 changes: 16 additions & 2 deletions src/jaxns/framework/special_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"Poisson",
"UnnormalisedDirichlet",
"Empirical",
"TruncationWrapper"
"TruncationWrapper",
"ExplicitDensityPrior",
]


Expand Down Expand Up @@ -72,6 +73,7 @@ def _quantile(self, U):
sample = jnp.less(U, probs)
return sample.astype(self.dtype)


class Beta(SpecialPrior):
def __init__(self, *, concentration0=None, concentration1=None, name: Optional[str] = None):
super(Beta, self).__init__(name=name)
Expand Down Expand Up @@ -443,7 +445,8 @@ class Empirical(SpecialPrior):
Represents the empirical distribution of a set of 1D samples, with arbitrary batch dimension.
"""

def __init__(self, *, samples: jax.Array, resolution: int = 100, name: Optional[str] = None):
def __init__(self, *, samples: jax.Array, support_min: Optional[FloatArray] = None,
support_max: Optional[FloatArray] = None, resolution: int = 100, name: Optional[str] = None):
super(Empirical, self).__init__(name=name)
if len(np.shape(samples)) < 1:
raise ValueError("Samples must have at least one dimension")
Expand All @@ -452,6 +455,17 @@ def __init__(self, *, samples: jax.Array, resolution: int = 100, name: Optional[
if resolution < 1:
raise ValueError("Resolution must be at least 1")
samples = jnp.asarray(samples)
# Add 1 point for each support endpoint
endpoints = []
if support_min is not None:
endpoints.append(support_min)
if support_max is not None:
endpoints.append(support_max)
if len(endpoints) > 0:
samples = jnp.concatenate([samples, jnp.asarray(endpoints)])

resolution = min(resolution, len(samples) - 1)

self._q = jnp.linspace(0., 100., resolution + 1)
self._percentiles = jnp.reshape(jnp.percentile(samples, self._q, axis=-1), (resolution + 1, -1))

Expand Down
8 changes: 4 additions & 4 deletions src/jaxns/framework/tests/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,15 @@ def test_forced_identifiability():


def test_empirical():
samples = jax.random.normal(jax.random.PRNGKey(42), shape=(5, 2000), dtype=mp_policy.measure_dtype)
samples = jax.random.normal(jax.random.PRNGKey(42), shape=(2000,), dtype=mp_policy.measure_dtype)
prior = Empirical(samples=samples, resolution=100, name='x')
assert prior._percentiles.shape == (101, 5)
assert prior._percentiles.shape == (101, 1)

x = prior.forward(jnp.ones(prior.base_shape, mp_policy.measure_dtype))
assert x.shape == (5,)
assert x.shape == ()
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
x = prior.forward(jnp.zeros(prior.base_shape, mp_policy.measure_dtype))
assert x.shape == (5,)
assert x.shape == ()
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))

x = prior.forward(0.5 * jnp.ones(prior.base_shape, mp_policy.measure_dtype))
Expand Down
6 changes: 3 additions & 3 deletions src/jaxns/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class NestedSampler:
max_samples: Optional[Union[int, float]] = None
num_live_points: Optional[int] = None
num_slices: Optional[int] = None
s: Optional[int] = None
s: Optional[Union[int, float]] = None
k: Optional[int] = None
c: Optional[int] = None
devices: Optional[List[xla_client.Device]] = None
Expand All @@ -70,9 +70,9 @@ def __post_init__(self):
# Determine number of slices per acceptance
if self.num_slices is None:
if self.difficult_model:
self.s = 10 if self.s is None else int(self.s)
self.s = 10 if self.s is None else float(self.s)
else:
self.s = 5 if self.s is None else int(self.s)
self.s = 5 if self.s is None else float(self.s)
if self.s <= 0:
raise ValueError(f"Expected s > 0, got s={self.s}")
self.num_slices = self.model.U_ndims * self.s
Expand Down
Loading