From b1990bb8aa95cd41246db9aaaac4cf59836ef4db Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 9 Feb 2026 15:16:27 +0100 Subject: [PATCH] Simplify Prior tests --- pymc_extras/prior.py | 22 +- pyproject.toml | 1 - tests/test_prior.py | 1002 ++++++++++++++++++++---------------------- 3 files changed, 482 insertions(+), 543 deletions(-) diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index d2385d518..b857a987f 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -1234,22 +1234,14 @@ class VariableNotFound(Exception): def _remove_random_variable(var: pt.TensorVariable) -> None: - if var.name is None: - raise ValueError("This isn't removable") - - name: str = var.name - + # This is brittle, as it doesn't rely on any official model API. + # Fix this by allowing `Prior.create_dist` instead model = pm.modelcontext(None) - for idx, free_rv in enumerate(model.free_RVs): - if var == free_rv: - index_to_remove = idx - break - else: - raise VariableNotFound(f"Variable {var.name!r} not found") - - var.name = None - model.free_RVs.pop(index_to_remove) - model.named_vars.pop(name) + model.rvs_to_initial_values.pop(var) + model.rvs_to_transforms.pop(var) + model.rvs_to_values.pop(var) + model.free_RVs.remove(var) + model.named_vars.pop(var.name) @dataclass diff --git a/pyproject.toml b/pyproject.toml index 211698f9e..d14aa33b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ complete = [ ] dev = [ "pytest>=6.0", - "pytest-mock", "dask[all]<2025.1.1", "blackjax>=0.12", "statsmodels", diff --git a/tests/test_prior.py b/tests/test_prior.py index 0ff142718..055845a52 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -1,5 +1,4 @@ from copy import deepcopy -from typing import NamedTuple import numpy as np import pymc as pm @@ -92,6 +91,25 @@ def test_missing_transform() -> None: Prior("Normal", transform="foo_bar") +def test_getattr() -> None: + assert pr.Normal() == Prior("Normal") + + +def test_import_directly() -> None: + try: + from pymc_extras.prior import Normal + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + assert Normal() == Prior("Normal") + + +def test_import_incorrect_directly() -> None: + match = "PyMC doesn't have a distribution of name 'SomeIncorrectDistribution'" + with pytest.raises(UnsupportedDistributionError, match=match): + from pymc_extras.prior import SomeIncorrectDistribution # noqa: F401 + + def test_get_item() -> None: var = Prior("Normal", mu=0, sigma=1) @@ -240,21 +258,20 @@ def test_create_variable_multiple_times() -> None: assert fast_eval(model[f"{prefix}{suffix}"]).shape == dim -@pytest.fixture -def large_var() -> Prior: - mu = Prior( +def test_create_variable() -> None: + large_var = Prior( "Normal", - mu=Prior("Normal", mu=1), - sigma=Prior("HalfNormal"), - dims="channel", - centered=False, + mu=Prior( + "Normal", + mu=Prior("Normal", mu=1), + sigma=Prior("HalfNormal"), + dims="channel", + centered=False, + ), + sigma=Prior("HalfNormal", sigma=Prior("HalfNormal"), dims="geo"), + dims=("geo", "channel"), ) - sigma = Prior("HalfNormal", sigma=Prior("HalfNormal"), dims="geo") - return Prior("Normal", mu=mu, sigma=sigma, dims=("geo", "channel")) - - -def test_create_variable(large_var) -> None: coords = { "channel": ["a", "b", "c"], "geo": ["x", "y"], @@ -303,9 +320,21 @@ def test_transform() -> None: assert fast_eval(model[var_name]).shape == dim -def test_to_dict(large_var) -> None: - data = large_var.to_dict() +def test_to_dict() -> None: + large_var = Prior( + "Normal", + mu=Prior( + "Normal", + mu=Prior("Normal", mu=1), + sigma=Prior("HalfNormal"), + dims="channel", + centered=False, + ), + sigma=Prior("HalfNormal", sigma=Prior("HalfNormal"), dims="geo"), + dims=("geo", "channel"), + ) + data = large_var.to_dict() assert data == { "dist": "Normal", "kwargs": { @@ -338,6 +367,8 @@ def test_to_dict(large_var) -> None: "dims": ("geo", "channel"), } + assert Prior.from_dict(data) == large_var + def test_to_dict_numpy() -> None: var = Prior("Normal", mu=np.array([0, 10, 20]), dims="channel") @@ -350,10 +381,6 @@ def test_to_dict_numpy() -> None: } -def test_dict_round_trip(large_var) -> None: - assert Prior.from_dict(large_var.to_dict()) == large_var - - def test_constrain_with_transform_error() -> None: var = Prior("Normal", transform="sigmoid") @@ -361,16 +388,12 @@ def test_constrain_with_transform_error() -> None: var.constrain(lower=0, upper=1) -def test_constrain(mocker) -> None: +def test_constrain() -> None: var = Prior("Normal") - mocker.patch( - "preliz.maxent", - return_value=mocker.Mock(params_dict={"mu": 5, "sigma": 2}), - ) - - new_var = var.constrain(lower=0, upper=1) - assert new_var == Prior("Normal", mu=5, sigma=2) + new_var = var.constrain(lower=0, upper=1, mass=0.9545) + np.testing.assert_allclose(new_var.parameters["mu"], 0.5, rtol=1e-4) + np.testing.assert_allclose(new_var.parameters["sigma"], 0.25, rtol=1e-4) def test_dims_change() -> None: @@ -400,9 +423,9 @@ def test_deepcopy() -> None: assert new_priors["alpha"].dims == () -@pytest.fixture -def mmm_default_model_config(): - return { +def test_backwards_compat() -> None: + """Make sure functionality is compatible with use in PyMC-marketing, where Prior objects originated from.""" + mmm_default_model_config = { "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, "likelihood": { "dist": "Normal", @@ -422,8 +445,6 @@ def mmm_default_model_config(): }, } - -def test_backwards_compat(mmm_default_model_config) -> None: result = {param: Prior.from_dict(value) for param, value in mmm_default_model_config.items()} assert result == { "intercept": Prior("Normal", mu=0, sigma=2), @@ -660,546 +681,473 @@ def test_zsn_non_centered() -> None: pytest.fail(f"Unexpected exception: {e}") -class Arbitrary: - def __init__(self, dims: str | tuple[str, ...]) -> None: - self.dims = dims - - def create_variable(self, name: str): - return pm.Normal(name, dims=self.dims) - - -class ArbitraryWithoutName: - def __init__(self, dims: str | tuple[str, ...]) -> None: - self.dims = dims +class TestCustomClass: + def test_sample_prior_arbitrary_no_name(self) -> None: + class ArbitraryWithoutName: + def __init__(self, dims: str | tuple[str, ...]) -> None: + self.dims = dims - def create_variable(self, name: str): - with pm.Model(name=name): - location = pm.Normal("location", dims=self.dims) - scale = pm.HalfNormal("scale", dims=self.dims) + def create_variable(self, name: str): + with pm.Model(name=name): + location = pm.Normal("location", dims=self.dims) + scale = pm.HalfNormal("scale", dims=self.dims) - return pm.Normal("standard_normal") * scale + location + return pm.Normal("standard_normal") * scale + location + var = ArbitraryWithoutName(dims="channel") -def test_sample_prior_arbitrary() -> None: - var = Arbitrary(dims="channel") - - prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25) - - assert isinstance(prior, xr.Dataset) - - -def test_sample_prior_arbitrary_no_name() -> None: - var = ArbitraryWithoutName(dims="channel") - - prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25) - - assert isinstance(prior, xr.Dataset) - assert "variable" not in prior - - prior_with = sample_prior( - var, - coords={"channel": ["A", "B", "C"]}, - draws=25, - wrap=True, - ) - - assert isinstance(prior_with, xr.Dataset) - assert "variable" in prior_with - - -def test_create_prior_with_arbitrary() -> None: - dist = Prior( - "Normal", - mu=Arbitrary(dims=("channel",)), - sigma=1, - dims=("channel", "geo"), - ) + prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25) - coords = { - "channel": ["C1", "C2", "C3"], - "geo": ["G1", "G2"], - } - with pm.Model(coords=coords) as model: - dist.create_variable("var") + assert isinstance(prior, xr.Dataset) + assert "variable" not in prior - assert "var_mu" in model - var_mu = model["var_mu"] + prior_with = sample_prior( + var, + coords={"channel": ["A", "B", "C"]}, + draws=25, + wrap=True, + ) - assert fast_eval(var_mu).shape == (len(coords["channel"]),) + assert isinstance(prior_with, xr.Dataset) + assert "variable" in prior_with + def test_arbitrary_class(self) -> None: + class Arbitrary: + def __init__(self, dims: str | tuple[str, ...]) -> None: + self.dims = dims -def test_censored_is_variable_factory() -> None: - normal = Prior("Normal") - censored_normal = Censored(normal, lower=0) + def create_variable(self, name: str): + return pm.Normal(name, dims=self.dims) - assert isinstance(censored_normal, VariableFactory) + prior = Prior( + "Normal", + mu=Arbitrary(dims=("channel",)), + sigma=1, + dims=("channel", "geo"), + ) + coords = { + "channel": ["C1", "C2", "C3"], + "geo": ["G1", "G2"], + } + with pm.Model(coords=coords) as model: + prior.create_variable("var") -@pytest.mark.parametrize( - "dims, expected_dims", - [ - ("channel", ("channel",)), - (("channel", "geo"), ("channel", "geo")), - ], - ids=["string", "tuple"], -) -def test_censored_dims_from_distribution(dims, expected_dims) -> None: - normal = Prior("Normal", dims=dims) - censored_normal = Censored(normal, lower=0) + assert "var_mu" in model + var_mu = model["var_mu"] + assert fast_eval(var_mu).shape == (len(coords["channel"]),) - assert censored_normal.dims == expected_dims + def test_arbitrary_class_sample_prior(self) -> None: + class Arbitrary: + def __init__(self, dims: str | tuple[str, ...]) -> None: + self.dims = dims + def create_variable(self, name: str): + return pm.Normal(name, dims=self.dims) -def test_censored_variables_created() -> None: - normal = Prior("Normal", mu=Prior("Normal"), dims="dim") - censored_normal = Censored(normal, lower=0) + var = Arbitrary(dims="channel") + prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25) + assert isinstance(prior, xr.Dataset) - coords = {"dim": range(3)} - with pm.Model(coords=coords) as model: - censored_normal.create_variable("var") + def test_arbitrary_serialization(self) -> None: + class ArbitrarySerializable: + def __init__(self, dims: str | tuple[str, ...]) -> None: + self.dims = dims - var_names = ["var", "var_mu"] - assert set(var.name for var in model.unobserved_RVs) == set(var_names) - dims = [(3,), ()] - for var_name, dim in zip(var_names, dims, strict=False): - assert fast_eval(model[var_name]).shape == dim + def create_variable(self, name: str): + return pm.Normal(name, dims=self.dims) + def to_dict(self): + return {"dims": self.dims} -def test_censored_sample_prior() -> None: - normal = Prior("Normal", dims="channel") - censored_normal = Censored(normal, lower=0) + arbitrary_serialized_data = {"dims": ("channel",)} - coords = {"channel": ["A", "B", "C"]} - prior = censored_normal.sample_prior(coords=coords, draws=25) + register_deserialization( + lambda data: isinstance(data, dict) and data.keys() == {"dims"}, + lambda data: ArbitrarySerializable(**data), + ) - assert isinstance(prior, xr.Dataset) - assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} + dist = Prior( + "Normal", + mu=ArbitrarySerializable(dims=("channel",)), + sigma=1, + dims=("channel", "geo"), + ) + data = { + "dist": "Normal", + "kwargs": { + "mu": arbitrary_serialized_data, + "sigma": 1, + }, + "dims": ("channel", "geo"), + } -def test_censored_to_graph() -> None: - normal = Prior("Normal", dims="channel") - censored_normal = Censored(normal, lower=0) + assert dist.to_dict() == data - G = censored_normal.to_graph() - assert isinstance(G, Digraph) + dist_again = deserialize(data) + assert isinstance(dist_again["mu"], ArbitrarySerializable) + assert dist_again["mu"].dims == ("channel",) + DESERIALIZERS.pop() -def test_censored_likelihood_variable() -> None: - normal = Prior("Normal", sigma=Prior("HalfNormal"), dims="channel") - censored_normal = Censored(normal, lower=0) - coords = {"channel": range(3)} - with pm.Model(coords=coords) as model: - mu = pm.Normal("mu") - variable = censored_normal.create_likelihood_variable( - name="likelihood", - mu=mu, - observed=[1, 2, 3], - ) +class TestScaled: + def test_scaled_initializes_correctly(self) -> None: + """Test that the Scaled class initializes correctly.""" + normal = Prior("Normal", mu=0, sigma=1) + scaled = Scaled(normal, factor=2.0) - assert isinstance(variable, pt.TensorVariable) - assert model.observed_RVs == [variable] - assert "likelihood_sigma" in model + assert scaled.dist == normal + assert scaled.factor == 2.0 + def test_scaled_dims_property(self) -> None: + """Test that the dims property returns the dimensions of the underlying distribution.""" + normal = Prior("Normal", mu=0, sigma=1, dims="channel") + scaled = Scaled(normal, factor=2.0) -def test_censored_likelihood_unsupported_distribution() -> None: - cauchy = Prior("Cauchy") - censored_cauchy = Censored(cauchy, lower=0) + assert scaled.dims == ("channel",) - with pm.Model(): - mu = pm.Normal("mu") - with pytest.raises(UnsupportedDistributionError): - censored_cauchy.create_likelihood_variable( + # Test with multiple dimensions + normal.dims = ("channel", "geo") + assert scaled.dims == ("channel", "geo") + + def test_scaled_create_variable(self) -> None: + """Test that the create_variable method properly scales the variable.""" + normal = Prior("Normal", mu=0, sigma=1) + scaled = Scaled(normal, factor=2.0) + + with pm.Model() as model: + scaled_var = scaled.create_variable("scaled_var") + + # Check that both the scaled and unscaled variables exist + assert "scaled_var" in model + assert "scaled_var_unscaled" in model + + # The deterministic node should be the scaled variable + assert model["scaled_var"] == scaled_var + + def test_scaled_creates_correct_dimensions(self) -> None: + """Test that the scaled variable has the correct dimensions.""" + normal = Prior("Normal", dims="channel") + scaled = Scaled(normal, factor=2.0) + + coords = {"channel": ["A", "B", "C"]} + with pm.Model(coords=coords): + scaled_var = scaled.create_variable("scaled_var") + + # Check that the scaled variable has the correct dimensions + assert fast_eval(scaled_var).shape == (3,) + + def test_scaled_applies_factor(self) -> None: + """Test that the scaling factor is correctly applied.""" + normal = Prior("Normal", mu=0, sigma=1) + factor = 3.5 + scaled = Scaled(normal, factor=factor) + + # Sample from prior to verify scaling + prior = sample_prior(scaled, draws=10, name="scaled_var") + df_prior = prior.to_dataframe() + + # Check that scaled values are original values times the factor + unscaled_values = df_prior["scaled_var_unscaled"].to_numpy() + scaled_values = df_prior["scaled_var"].to_numpy() + np.testing.assert_allclose(scaled_values, unscaled_values * factor) + + def test_scaled_with_tensor_factor(self) -> None: + """Test that the Scaled class works with a tensor factor.""" + normal = Prior("Normal", mu=0, sigma=1) + factor = pt.as_tensor_variable(2.5) + scaled = Scaled(normal, factor=factor) + + # Sample from prior to verify tensor scaling + prior = sample_prior(scaled, draws=10, name="scaled_var") + df_prior = prior.to_dataframe() + + # Check that scaled values are original values times the factor + unscaled_values = df_prior["scaled_var_unscaled"].to_numpy() + scaled_values = df_prior["scaled_var"].to_numpy() + np.testing.assert_allclose(scaled_values, unscaled_values * 2.5) + + def test_scaled_with_hierarchical_prior(self) -> None: + """Test that the Scaled class works with hierarchical priors.""" + normal = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel") + scaled = Scaled(normal, factor=2.0) + + coords = {"channel": ["A", "B", "C"]} + with pm.Model(coords=coords) as model: + scaled.create_variable("scaled_var") + + # Check that all necessary variables were created + assert "scaled_var" in model + assert "scaled_var_unscaled" in model + assert "scaled_var_unscaled_mu" in model + assert "scaled_var_unscaled_sigma" in model + + def test_scaled_sample_prior(self) -> None: + """Test that sample_prior works with the Scaled class.""" + normal = Prior("Normal", dims="channel") + scaled = Scaled(normal, factor=2.0) + + coords = {"channel": ["A", "B", "C"]} + prior = sample_prior(scaled, coords=coords, draws=25, name="scaled_var") + + assert isinstance(prior, xr.Dataset) + assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} + assert "scaled_var" in prior + assert "scaled_var_unscaled" in prior + + +class TestCensored: + def test_censored_is_variable_factory( + self, + ) -> None: + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + assert isinstance(censored_normal, VariableFactory) + + def test_deserialize_censored(self) -> None: + data = { + "class": "Censored", + "data": { + "dist": { + "dist": "Normal", + }, + "lower": 0, + "upper": float("inf"), + }, + } + + instance = deserialize(data) + assert isinstance(instance, Censored) + assert isinstance(instance.distribution, Prior) + assert instance.lower == 0 + assert instance.upper == float("inf") + + @pytest.mark.parametrize( + "dims, expected_dims", + [ + ("channel", ("channel",)), + (("channel", "geo"), ("channel", "geo")), + ], + ids=["string", "tuple"], + ) + def test_censored_dims_from_distribution(self, dims, expected_dims) -> None: + normal = Prior("Normal", dims=dims) + censored_normal = Censored(normal, lower=0) + + assert censored_normal.dims == expected_dims + + def test_censored_variables_created( + self, + ) -> None: + normal = Prior("Normal", mu=Prior("Normal"), dims="dim") + censored_normal = Censored(normal, lower=0) + + coords = {"dim": range(3)} + with pm.Model(coords=coords) as model: + censored_normal.create_variable("var") + + var_names = ["var", "var_mu"] + assert set(var.name for var in model.unobserved_RVs) == set(var_names) + dims = [(3,), ()] + for var_name, dim in zip(var_names, dims, strict=False): + assert fast_eval(model[var_name]).shape == dim + + def test_censored_sample_prior( + self, + ) -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + + coords = {"channel": ["A", "B", "C"]} + prior = censored_normal.sample_prior(coords=coords, draws=25) + + assert isinstance(prior, xr.Dataset) + assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} + + def test_censored_to_graph( + self, + ) -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + + G = censored_normal.to_graph() + assert isinstance(G, Digraph) + + def test_censored_likelihood_variable( + self, + ) -> None: + normal = Prior("Normal", sigma=Prior("HalfNormal"), dims="channel") + censored_normal = Censored(normal, lower=0) + + coords = {"channel": range(3)} + with pm.Model(coords=coords) as model: + mu = pm.Normal("mu") + variable = censored_normal.create_likelihood_variable( name="likelihood", mu=mu, - observed=1, + observed=[1, 2, 3], ) - -def test_censored_likelihood_already_has_mu() -> None: - normal = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal")) - censored_normal = Censored(normal, lower=0) - - with pm.Model(): - mu = pm.Normal("mu") - with pytest.raises(MuAlreadyExistsError): - censored_normal.create_likelihood_variable( - name="likelihood", + assert isinstance(variable, pt.TensorVariable) + assert model.observed_RVs == [variable] + assert "likelihood_sigma" in model + + def test_censored_likelihood_unsupported_distribution( + self, + ) -> None: + cauchy = Prior("Cauchy") + censored_cauchy = Censored(cauchy, lower=0) + + with pm.Model(): + mu = pm.Normal("mu") + with pytest.raises(UnsupportedDistributionError): + censored_cauchy.create_likelihood_variable( + name="likelihood", + mu=mu, + observed=1, + ) + + def test_censored_likelihood_already_has_mu( + self, + ) -> None: + normal = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal")) + censored_normal = Censored(normal, lower=0) + + with pm.Model(): + mu = pm.Normal("mu") + with pytest.raises(MuAlreadyExistsError): + censored_normal.create_likelihood_variable( + name="likelihood", + mu=mu, + observed=1, + ) + + def test_censored_to_dict( + self, + ) -> None: + normal = Prior("Normal", mu=0, sigma=1, dims="channel") + censored_normal = Censored(normal, lower=0) + + data = censored_normal.to_dict() + assert data == { + "class": "Censored", + "data": {"dist": normal.to_dict(), "lower": 0, "upper": float("inf")}, + } + + @pytest.mark.parametrize( + "mu", + [ + 0, + np.arange(10), + ], + ids=["scalar", "vector"], + ) + def test_censored_logp(self, mu) -> None: + n_points = 10 + observed = np.zeros(n_points) + coords = {"idx": range(n_points)} + with pm.Model(coords=coords) as model: + normal = Prior("Normal", dims="idx") + Censored(normal, lower=0).create_likelihood_variable( + "censored_normal", + observed=observed, mu=mu, - observed=1, ) - - -def test_censored_to_dict() -> None: - normal = Prior("Normal", mu=0, sigma=1, dims="channel") - censored_normal = Censored(normal, lower=0) - - data = censored_normal.to_dict() - assert data == { - "class": "Censored", - "data": {"dist": normal.to_dict(), "lower": 0, "upper": float("inf")}, - } - - -def test_deserialize_censored() -> None: - data = { - "class": "Censored", - "data": { - "dist": { - "dist": "Normal", + logp = model.compile_logp() + + with pm.Model() as expected_model: + pm.Censored( + "censored_normal", + pm.Normal.dist(mu=mu, sigma=1, shape=n_points), + lower=0, + upper=np.inf, + observed=observed, + ) + expected_logp = expected_model.compile_logp() + + point = {} + np.testing.assert_allclose(logp(point), expected_logp(point)) + + def test_censored_with_tensor_variable(self) -> None: + normal = Prior("Normal", dims="channel") + lower = pt.as_tensor_variable([0, 1, 2]) + censored_normal = Censored(normal, lower=lower) + + assert censored_normal.to_dict() == { + "class": "Censored", + "data": { + "dist": normal.to_dict(), + "lower": [0, 1, 2], + "upper": float("inf"), }, - "lower": 0, - "upper": float("inf"), - }, - } - - instance = deserialize(data) - assert isinstance(instance, Censored) - assert isinstance(instance.distribution, Prior) - assert instance.lower == 0 - assert instance.upper == float("inf") - - -class ArbitrarySerializable(Arbitrary): - def to_dict(self): - return {"dims": self.dims} - - -@pytest.fixture -def arbitrary_serialized_data() -> dict: - return {"dims": ("channel",)} - - -def test_create_prior_with_arbitrary_serializable(arbitrary_serialized_data) -> None: - dist = Prior( - "Normal", - mu=ArbitrarySerializable(dims=("channel",)), - sigma=1, - dims=("channel", "geo"), - ) - - assert dist.to_dict() == { - "dist": "Normal", - "kwargs": { - "mu": arbitrary_serialized_data, - "sigma": 1, - }, - "dims": ("channel", "geo"), - } - - -@pytest.fixture -def register_arbitrary_deserialization(): - register_deserialization( - lambda data: isinstance(data, dict) and data.keys() == {"dims"}, - lambda data: ArbitrarySerializable(**data), - ) - - yield - - DESERIALIZERS.pop() - - -def test_deserialize_arbitrary_within_prior( - arbitrary_serialized_data, - register_arbitrary_deserialization, -) -> None: - data = { - "dist": "Normal", - "kwargs": { - "mu": arbitrary_serialized_data, - "sigma": 1, - }, - "dims": ("channel", "geo"), - } - - dist = deserialize(data) - assert isinstance(dist["mu"], ArbitrarySerializable) - assert dist["mu"].dims == ("channel",) - - -def test_censored_with_tensor_variable() -> None: - normal = Prior("Normal", dims="channel") - lower = pt.as_tensor_variable([0, 1, 2]) - censored_normal = Censored(normal, lower=lower) - - assert censored_normal.to_dict() == { - "class": "Censored", - "data": { - "dist": normal.to_dict(), - "lower": [0, 1, 2], - "upper": float("inf"), - }, - } - - -def test_censored_dims_setter() -> None: - normal = Prior("Normal", dims="channel") - censored_normal = Censored(normal, lower=0) - censored_normal.dims = "date" - assert normal.dims == ("date",) - - -class ModelData(NamedTuple): - mu: float - observed: list[float] - - -@pytest.fixture(scope="session") -def model_data() -> ModelData: - return ModelData(mu=0, observed=[0, 1, 2, 3, 4]) - - -@pytest.fixture(scope="session") -def normal_model_with_censored_API(model_data) -> pm.Model: - coords = {"idx": range(len(model_data.observed))} - with pm.Model(coords=coords) as model: - sigma = Prior("HalfNormal") - normal = Prior("Normal", sigma=sigma, dims="idx") - Censored(normal, lower=0).create_likelihood_variable( - "censored_normal", - mu=model_data.mu, - observed=model_data.observed, - ) - - return model - - -@pytest.fixture(scope="session") -def normal_model_with_censored_logp(normal_model_with_censored_API): - return normal_model_with_censored_API.compile_logp() - - -@pytest.fixture(scope="session") -def expected_normal_model(model_data) -> pm.Model: - n_points = len(model_data.observed) - with pm.Model() as expected_model: - sigma = pm.HalfNormal("censored_normal_sigma") - normal = pm.Normal.dist(mu=model_data.mu, sigma=sigma, shape=n_points) - pm.Censored( - "censored_normal", - normal, - lower=0, - upper=np.inf, - observed=model_data.observed, - ) - - return expected_model - - -@pytest.fixture(scope="session") -def expected_normal_model_logp(expected_normal_model): - return expected_normal_model.compile_logp() - - -@pytest.mark.parametrize("sigma_log__", [-10, -5, -2.5, 0, 2.5, 5, 10]) -def test_censored_normal_logp( - sigma_log__, - normal_model_with_censored_logp, - expected_normal_model_logp, -) -> None: - points = {"censored_normal_sigma_log__": sigma_log__} - normal_model_logp = normal_model_with_censored_logp(points) - expected_model_logp = expected_normal_model_logp(points) - np.testing.assert_allclose(normal_model_logp, expected_model_logp) - - -@pytest.mark.parametrize( - "mu", - [ - 0, - np.arange(10), - ], - ids=["scalar", "vector"], -) -def test_censored_logp(mu) -> None: - n_points = 10 - observed = np.zeros(n_points) - coords = {"idx": range(n_points)} - with pm.Model(coords=coords) as model: - normal = Prior("Normal", dims="idx") - Censored(normal, lower=0).create_likelihood_variable( - "censored_normal", - observed=observed, - mu=mu, - ) - logp = model.compile_logp() - - with pm.Model() as expected_model: - pm.Censored( - "censored_normal", - pm.Normal.dist(mu=mu, sigma=1, shape=n_points), - lower=0, - upper=np.inf, - observed=observed, - ) - expected_logp = expected_model.compile_logp() - - point = {} - np.testing.assert_allclose(logp(point), expected_logp(point)) - - -def test_scaled_initializes_correctly() -> None: - """Test that the Scaled class initializes correctly.""" - normal = Prior("Normal", mu=0, sigma=1) - scaled = Scaled(normal, factor=2.0) - - assert scaled.dist == normal - assert scaled.factor == 2.0 - - -def test_scaled_dims_property() -> None: - """Test that the dims property returns the dimensions of the underlying distribution.""" - normal = Prior("Normal", mu=0, sigma=1, dims="channel") - scaled = Scaled(normal, factor=2.0) - - assert scaled.dims == ("channel",) - - # Test with multiple dimensions - normal.dims = ("channel", "geo") - assert scaled.dims == ("channel", "geo") - - -def test_scaled_create_variable() -> None: - """Test that the create_variable method properly scales the variable.""" - normal = Prior("Normal", mu=0, sigma=1) - scaled = Scaled(normal, factor=2.0) - - with pm.Model() as model: - scaled_var = scaled.create_variable("scaled_var") - - # Check that both the scaled and unscaled variables exist - assert "scaled_var" in model - assert "scaled_var_unscaled" in model - - # The deterministic node should be the scaled variable - assert model["scaled_var"] == scaled_var - - -def test_scaled_creates_correct_dimensions() -> None: - """Test that the scaled variable has the correct dimensions.""" - normal = Prior("Normal", dims="channel") - scaled = Scaled(normal, factor=2.0) - - coords = {"channel": ["A", "B", "C"]} - with pm.Model(coords=coords): - scaled_var = scaled.create_variable("scaled_var") - - # Check that the scaled variable has the correct dimensions - assert fast_eval(scaled_var).shape == (3,) - - -def test_scaled_applies_factor() -> None: - """Test that the scaling factor is correctly applied.""" - normal = Prior("Normal", mu=0, sigma=1) - factor = 3.5 - scaled = Scaled(normal, factor=factor) - - # Sample from prior to verify scaling - prior = sample_prior(scaled, draws=10, name="scaled_var") - df_prior = prior.to_dataframe() - - # Check that scaled values are original values times the factor - unscaled_values = df_prior["scaled_var_unscaled"].to_numpy() - scaled_values = df_prior["scaled_var"].to_numpy() - np.testing.assert_allclose(scaled_values, unscaled_values * factor) - - -def test_scaled_with_tensor_factor() -> None: - """Test that the Scaled class works with a tensor factor.""" - normal = Prior("Normal", mu=0, sigma=1) - factor = pt.as_tensor_variable(2.5) - scaled = Scaled(normal, factor=factor) - - # Sample from prior to verify tensor scaling - prior = sample_prior(scaled, draws=10, name="scaled_var") - df_prior = prior.to_dataframe() - - # Check that scaled values are original values times the factor - unscaled_values = df_prior["scaled_var_unscaled"].to_numpy() - scaled_values = df_prior["scaled_var"].to_numpy() - np.testing.assert_allclose(scaled_values, unscaled_values * 2.5) - - -def test_scaled_with_hierarchical_prior() -> None: - """Test that the Scaled class works with hierarchical priors.""" - normal = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel") - scaled = Scaled(normal, factor=2.0) - - coords = {"channel": ["A", "B", "C"]} - with pm.Model(coords=coords) as model: - scaled.create_variable("scaled_var") - - # Check that all necessary variables were created - assert "scaled_var" in model - assert "scaled_var_unscaled" in model - assert "scaled_var_unscaled_mu" in model - assert "scaled_var_unscaled_sigma" in model - - -def test_scaled_sample_prior() -> None: - """Test that sample_prior works with the Scaled class.""" - normal = Prior("Normal", dims="channel") - scaled = Scaled(normal, factor=2.0) - - coords = {"channel": ["A", "B", "C"]} - prior = sample_prior(scaled, coords=coords, draws=25, name="scaled_var") - - assert isinstance(prior, xr.Dataset) - assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} - assert "scaled_var" in prior - assert "scaled_var_unscaled" in prior - - -def test_getattr() -> None: - assert pr.Normal() == Prior("Normal") - - -def test_import_directly() -> None: - try: - from pymc_extras.prior import Normal - except Exception as e: - pytest.fail(f"Unexpected exception: {e}") - - assert Normal() == Prior("Normal") - - -def test_import_incorrect_directly() -> None: - match = "PyMC doesn't have a distribution of name 'SomeIncorrectDistribution'" - with pytest.raises(UnsupportedDistributionError, match=match): - from pymc_extras.prior import SomeIncorrectDistribution # noqa: F401 - + } + + def test_censored_dims_setter(self) -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + censored_normal.dims = "date" + assert normal.dims == ("date",) + + def test_censored_normal(self) -> None: + coords = {"idx": range(5)} + observed = np.arange(5, dtype=float) + mu = np.pi + + with pm.Model(coords=coords) as model: + sigma = Prior("HalfNormal") + normal = Prior("Normal", sigma=sigma, dims="idx") + Censored(normal, lower=0).create_likelihood_variable( + "censored_normal", + mu=mu, + observed=observed, + ) -@pytest.fixture -def alternative_prior_deserialize(): - def is_type(data): - return isinstance(data, dict) and "distribution" in data + with pm.Model(coords=coords) as expected_model: + sigma = pm.HalfNormal("censored_normal_sigma") + normal = pm.Normal.dist(mu=mu, sigma=sigma) + pm.Censored( + "censored_normal", + normal, + lower=0, + upper=np.inf, + observed=observed, + dims="idx", + ) - def deserialize(data): - return Prior(**data) + # This doesn't work because of no OpFromGraph equality impl + # assert equivalent_models(model, expected_model) - register_deserialization(is_type=is_type, deserialize=deserialize) + ip = model.initial_point() + np.testing.assert_allclose(model.compile_logp()(ip), expected_model.compile_logp()(ip)) - yield + def test_censored_with_alternative_class(self) -> None: + def is_type(data): + return isinstance(data, dict) and "distribution" in data - DESERIALIZERS.pop() + def deserialize_func(data): + return Prior(**data) + register_deserialization(is_type=is_type, deserialize=deserialize_func) -def test_censored_with_alternative(alternative_prior_deserialize) -> None: - data = { - "class": "Censored", - "data": { - "dist": { - "distribution": "Normal", + data = { + "class": "Censored", + "data": { + "dist": { + "distribution": "Normal", + }, + "lower": 0, + "upper": 10, }, - "lower": 0, - "upper": 10, - }, - } + } + + instance = deserialize(data) - instance = deserialize(data) + assert isinstance(instance, Censored) + assert instance.lower == 0 + assert instance.upper == 10 + assert instance.distribution == Prior("Normal") - assert isinstance(instance, Censored) - assert instance.lower == 0 - assert instance.upper == 10 - assert instance.distribution == Prior("Normal") + DESERIALIZERS.pop()