From f20e647baec5706b3a992c2e95c6ec4f0dfd726c Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Sat, 7 Dec 2024 01:40:20 +0100 Subject: [PATCH 1/9] * update default s to be fracational --- src/jaxns/public.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jaxns/public.py b/src/jaxns/public.py index 215ba10..b5a08fd 100644 --- a/src/jaxns/public.py +++ b/src/jaxns/public.py @@ -55,7 +55,7 @@ class NestedSampler: max_samples: Optional[Union[int, float]] = None num_live_points: Optional[int] = None num_slices: Optional[int] = None - s: Optional[int] = None + s: Optional[Union[int, float]] = None k: Optional[int] = None c: Optional[int] = None devices: Optional[List[xla_client.Device]] = None @@ -70,9 +70,9 @@ def __post_init__(self): # Determine number of slices per acceptance if self.num_slices is None: if self.difficult_model: - self.s = 10 if self.s is None else int(self.s) + self.s = 10 if self.s is None else float(self.s) else: - self.s = 5 if self.s is None else int(self.s) + self.s = 5 if self.s is None else float(self.s) if self.s <= 0: raise ValueError(f"Expected s > 0, got s={self.s}") self.num_slices = self.model.U_ndims * self.s From aaf4aa795903457374de93fe8aab03624eac5686 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Sat, 7 Dec 2024 01:40:45 +0100 Subject: [PATCH 2/9] * fix #212 --- pyproject.toml | 4 +++- setup.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2ee32b0..0de3275 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,10 +19,12 @@ classifiers = [ "Operating System :: OS Independent" ] urls = { "Homepage" = "https://github.com/joshuaalbert/jaxns" } +dynamic = ["dependencies"] [project.optional-dependencies] # Define the extras here; they will be loaded dynamically from setup.py -notebooks = [] # Placeholders; extras will load from setup.py +examples = [] # Placeholders; extras will load from setup.py +tests = [] # Placeholders; extras will load from setup.py [tool.setuptools] include-package-data = true diff --git a/setup.py b/setup.py index 7d40cf7..d0757d2 100755 --- a/setup.py +++ b/setup.py @@ -12,6 +12,6 @@ def load_requirements(file_name): install_requires=load_requirements("requirements.txt"), extras_require={ "examples": load_requirements("requirements-examples.txt"), - }, - tests_require=load_requirements("requirements-tests.txt"), + "tests": load_requirements("requirements-tests.txt"), + } ) From 1840973973477bc2bf9cd839ce12d68941124fdc Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Sat, 7 Dec 2024 01:41:13 +0100 Subject: [PATCH 3/9] * add endpoints to empirical prior --- src/jaxns/framework/special_priors.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/jaxns/framework/special_priors.py b/src/jaxns/framework/special_priors.py index 2f0104d..a3b4287 100644 --- a/src/jaxns/framework/special_priors.py +++ b/src/jaxns/framework/special_priors.py @@ -25,7 +25,8 @@ "Poisson", "UnnormalisedDirichlet", "Empirical", - "TruncationWrapper" + "TruncationWrapper", + "ExplicitDensityPrior", ] @@ -72,6 +73,7 @@ def _quantile(self, U): sample = jnp.less(U, probs) return sample.astype(self.dtype) + class Beta(SpecialPrior): def __init__(self, *, concentration0=None, concentration1=None, name: Optional[str] = None): super(Beta, self).__init__(name=name) @@ -443,7 +445,8 @@ class Empirical(SpecialPrior): Represents the empirical distribution of a set of 1D samples, with arbitrary batch dimension. """ - def __init__(self, *, samples: jax.Array, resolution: int = 100, name: Optional[str] = None): + def __init__(self, *, samples: jax.Array, support_min: FloatArray | None = None, + support_max: FloatArray | None = None, resolution: int = 100, name: Optional[str] = None): super(Empirical, self).__init__(name=name) if len(np.shape(samples)) < 1: raise ValueError("Samples must have at least one dimension") @@ -452,6 +455,17 @@ def __init__(self, *, samples: jax.Array, resolution: int = 100, name: Optional[ if resolution < 1: raise ValueError("Resolution must be at least 1") samples = jnp.asarray(samples) + # Add 1 point for each support endpoint + endpoints = [] + if support_min is not None: + endpoints.append(support_min) + if support_max is not None: + endpoints.append(support_max) + if len(endpoints) > 0: + samples = jnp.concatenate([samples, jnp.asarray(endpoints)]) + + resolution = min(resolution, len(samples) - 1) + self._q = jnp.linspace(0., 100., resolution + 1) self._percentiles = jnp.reshape(jnp.percentile(samples, self._q, axis=-1), (resolution + 1, -1)) From b9d674132e3e34e8e585c4c1a8d064d267a6aa34 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Sat, 7 Dec 2024 01:45:51 +0100 Subject: [PATCH 4/9] * Bump 2.6.7 --- README.md | 2 ++ docs/conf.py | 2 +- pyproject.toml | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index dfec6dd..2307e0a 100644 --- a/README.md +++ b/README.md @@ -359,6 +359,8 @@ before importing JAXNS. # Change Log +7 Dec, 2024 -- JAXNS 2.6.7 released. Fix pip dependencies install. + 13 Nov, 2024 -- JAXNS 2.6.6 released. Minor improvements to plotting. 9 Nov, 2024 -- JAXNS 2.6.5 released. Added gradient guided nested sampling. Removed `num_parallel_workers` in favour diff --git a/docs/conf.py b/docs/conf.py index 4080030..63036e0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,7 +12,7 @@ project = "jaxns" copyright = "2024, Joshua G. Albert" author = "Joshua G. Albert" -release = "2.6.6" +release = "2.6.7" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 0de3275..5c68a7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" [project] name = "jaxns" -version = "2.6.6" +version = "2.6.7" description = "Nested Sampling in JAX" readme = "README.md" requires-python = ">=3.9" From 9176f2eb8e9b7df3e9c09ff212602e2159e07e20 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Sat, 7 Dec 2024 01:47:46 +0100 Subject: [PATCH 5/9] * fix annotations --- src/jaxns/framework/special_priors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jaxns/framework/special_priors.py b/src/jaxns/framework/special_priors.py index a3b4287..223b08a 100644 --- a/src/jaxns/framework/special_priors.py +++ b/src/jaxns/framework/special_priors.py @@ -445,8 +445,8 @@ class Empirical(SpecialPrior): Represents the empirical distribution of a set of 1D samples, with arbitrary batch dimension. """ - def __init__(self, *, samples: jax.Array, support_min: FloatArray | None = None, - support_max: FloatArray | None = None, resolution: int = 100, name: Optional[str] = None): + def __init__(self, *, samples: jax.Array, support_min: Optional[FloatArray] = None, + support_max: Optional[FloatArray] = None, resolution: int = 100, name: Optional[str] = None): super(Empirical, self).__init__(name=name) if len(np.shape(samples)) < 1: raise ValueError("Samples must have at least one dimension") From 09fef5c9259d06d1af25fccea4fbfae1fec3c700 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Sat, 7 Dec 2024 01:51:31 +0100 Subject: [PATCH 6/9] * fix test for empirical as it's now 1D --- src/jaxns/framework/tests/test_prior.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jaxns/framework/tests/test_prior.py b/src/jaxns/framework/tests/test_prior.py index c509d7d..e3beea6 100644 --- a/src/jaxns/framework/tests/test_prior.py +++ b/src/jaxns/framework/tests/test_prior.py @@ -321,15 +321,15 @@ def test_forced_identifiability(): def test_empirical(): - samples = jax.random.normal(jax.random.PRNGKey(42), shape=(5, 2000), dtype=mp_policy.measure_dtype) + samples = jax.random.normal(jax.random.PRNGKey(42), shape=(2000,), dtype=mp_policy.measure_dtype) prior = Empirical(samples=samples, resolution=100, name='x') - assert prior._percentiles.shape == (101, 5) + assert prior._percentiles.shape == (101, 1) x = prior.forward(jnp.ones(prior.base_shape, mp_policy.measure_dtype)) - assert x.shape == (5,) + assert x.shape == () assert jnp.all(jnp.bitwise_not(jnp.isnan(x))) x = prior.forward(jnp.zeros(prior.base_shape, mp_policy.measure_dtype)) - assert x.shape == (5,) + assert x.shape == () assert jnp.all(jnp.bitwise_not(jnp.isnan(x))) x = prior.forward(0.5 * jnp.ones(prior.base_shape, mp_policy.measure_dtype)) From 08908169371fb898c5472b04fff7b068a806b4b1 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Sat, 7 Dec 2024 14:09:36 +0100 Subject: [PATCH 7/9] * update tests --- .github/workflows/unittests.yml | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 2b69c21..48cf0c5 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -27,10 +27,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest - pip install -r requirements.txt - pip install -r requirements-tests.txt - pip install . + pip install .[tests] - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -39,4 +36,4 @@ jobs: flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | - pytest + pytest -s From 322a189b333b3c583a3d38c8e316cced9f98f6de Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Sat, 7 Dec 2024 14:11:47 +0100 Subject: [PATCH 8/9] * update flake8 --- requirements-tests.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-tests.txt b/requirements-tests.txt index f9de252..b232a85 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -1,4 +1,5 @@ scikit-learn networkx psutil -pytest \ No newline at end of file +pytest +flake8 \ No newline at end of file From ad83526fdd1e0cc23ef853453c6aa88385cb8490 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Sat, 7 Dec 2024 14:20:11 +0100 Subject: [PATCH 9/9] * make optional reqs dynamic --- pyproject.toml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5c68a7c..15ad741 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,12 +19,7 @@ classifiers = [ "Operating System :: OS Independent" ] urls = { "Homepage" = "https://github.com/joshuaalbert/jaxns" } -dynamic = ["dependencies"] - -[project.optional-dependencies] -# Define the extras here; they will be loaded dynamically from setup.py -examples = [] # Placeholders; extras will load from setup.py -tests = [] # Placeholders; extras will load from setup.py +dynamic = ["dependencies", "optional-dependencies"] [tool.setuptools] include-package-data = true