From b35001b346a0db15d48241d0c4cc5d1c102f1b89 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 16 Jun 2025 13:18:19 -0400 Subject: [PATCH 01/14] add pymc-extras to environment --- environment.yml | 1 + pyproject.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/environment.yml b/environment.yml index 02b7f920..2bc8ed20 100644 --- a/environment.yml +++ b/environment.yml @@ -15,3 +15,4 @@ dependencies: - seaborn>=0.11.2 - statsmodels - xarray>=v2022.11.0 + - pymc-extras>=0.2.7 diff --git a/pyproject.toml b/pyproject.toml index 29f86277..bcc4bc7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "seaborn>=0.11.2", "statsmodels", "xarray>=v2022.11.0", + "pymc-extras>=0.2.7", ] # List additional groups of dependencies here (e.g. development dependencies). Users From b7300e79db32abca5ef88eb08095f39f764db590 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 16 Jun 2025 13:20:06 -0400 Subject: [PATCH 02/14] add default_priors and support for custom priors --- causalpy/pymc_models.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index ea380c1a..3ed4cac9 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -22,6 +22,7 @@ import pytensor.tensor as pt import xarray as xr from arviz import r2_score +from pymc_extras.prior import Prior from causalpy.utils import round_num @@ -68,7 +69,13 @@ class PyMCModel(pm.Model): Inference data... """ - def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None): + default_priors: dict[str, Any] + + def __init__( + self, + sample_kwargs: Optional[Dict[str, Any]] = None, + priors: dict[str, Any] | None = None, + ): """ :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the :func:`pymc.sample` function. Defaults to an empty dictionary. 
@@ -77,6 +84,8 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None): self.idata = None self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {} + self.priors = {**self.default_priors, **(priors or {})} + def build_model(self, X, y, coords) -> None: """Build the model, must be implemented by subclass.""" raise NotImplementedError("This method must be implemented by a subclass") @@ -237,6 +246,11 @@ class LinearRegression(PyMCModel): Inference data... """ # noqa: W605 + default_priors = { + "beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"), + "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1)), + } + def build_model(self, X, y, coords): """ Defines the PyMC model @@ -245,10 +259,9 @@ def build_model(self, X, y, coords): self.add_coords(coords) X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) y = pm.Data("y", y, dims="obs_ind") - beta = pm.Normal("beta", 0, 50, dims="coeffs") - sigma = pm.HalfNormal("sigma", 1) + beta = self.priors["beta"].create_variable("beta") mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind") - pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind") + self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) class WeightedSumFitter(PyMCModel): @@ -276,6 +289,10 @@ class WeightedSumFitter(PyMCModel): Inference data... 
""" # noqa: W605 + default_priors = { + "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1)), + } + def build_model(self, X, y, coords): """ Defines the PyMC model @@ -286,9 +303,8 @@ def build_model(self, X, y, coords): X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) y = pm.Data("y", y[:, 0], dims="obs_ind") beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs") - sigma = pm.HalfNormal("sigma", 1) mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind") - pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind") + self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) class InstrumentalVariableRegression(PyMCModel): @@ -477,13 +493,17 @@ class PropensityScore(PyMCModel): Inference... """ # noqa: W605 + default_priors = { + "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"), + } + def build_model(self, X, t, coords): "Defines the PyMC propensity model" with self: self.add_coords(coords) X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"]) t_data = pm.Data("t", t.flatten(), dims="obs_ind") - b = pm.Normal("b", mu=0, sigma=1, dims="coeffs") + b = self.priors["b"].create_variable("b") mu = pm.math.dot(X_data, b) p = pm.Deterministic("p", pm.math.invlogit(mu)) pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind") From a60035e61d0b7653515721d83dad958d4d368ee2 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 16 Jun 2025 13:52:22 -0400 Subject: [PATCH 03/14] get pymc_models tests to pass --- causalpy/pymc_models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 3ed4cac9..6f15f0cf 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -69,7 +69,9 @@ class PyMCModel(pm.Model): Inference data... 
""" - default_priors: dict[str, Any] + @property + def default_priors(self): + return {} def __init__( self, @@ -248,7 +250,7 @@ class LinearRegression(PyMCModel): default_priors = { "beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"), - "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1)), + "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"), } def build_model(self, X, y, coords): From 367c9220b835e35b32c1b48c77c4e76f578f91e9 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 16 Jun 2025 17:05:02 -0400 Subject: [PATCH 04/14] add dim to y_hat --- causalpy/pymc_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 6f15f0cf..812e9d70 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -292,7 +292,7 @@ class WeightedSumFitter(PyMCModel): """ # noqa: W605 default_priors = { - "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1)), + "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"), } def build_model(self, X, y, coords): From a9f821c8e84a5edaea2615a0c60582a205056318 Mon Sep 17 00:00:00 2001 From: "Benjamin T. 
Vincent" Date: Fri, 20 Jun 2025 11:31:53 +0100 Subject: [PATCH 05/14] fix for sigma -> y_hat_sigma --- causalpy/pymc_models.py | 6 +++--- docs/source/_static/interrogate_badge.svg | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 812e9d70..f95b6371 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -199,15 +199,15 @@ def print_row( coeffs = az.extract(self.idata.posterior, var_names="beta") # Determine the width of the longest label - max_label_length = max(len(name) for name in labels + ["sigma"]) + max_label_length = max(len(name) for name in labels + ["y_hat_sigma"]) for name in labels: coeff_samples = coeffs.sel(coeffs=name) print_row(max_label_length, name, coeff_samples, round_to) # Add coefficient for measurement std - coeff_samples = az.extract(self.idata.posterior, var_names="sigma") - name = "sigma" + coeff_samples = az.extract(self.idata.posterior, var_names="y_hat_sigma") + name = "y_hat_sigma" print_row(max_label_length, name, coeff_samples, round_to) diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 9975f47a..4a908d60 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 94.9% + interrogate: 94.5% @@ -12,8 +12,8 @@ interrogate interrogate - 94.9% - 94.9% + 94.5% + 94.5% From 91aee009f60506bdb2af55faef256b864d8bb483 Mon Sep 17 00:00:00 2001 From: "Benjamin T. 
Vincent" Date: Fri, 20 Jun 2025 12:02:06 +0100 Subject: [PATCH 06/14] fix failing doctest --- causalpy/experiments/prepostnegd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/causalpy/experiments/prepostnegd.py b/causalpy/experiments/prepostnegd.py index beec847e..3ab18968 100644 --- a/causalpy/experiments/prepostnegd.py +++ b/causalpy/experiments/prepostnegd.py @@ -82,7 +82,7 @@ class PrePostNEGD(BaseExperiment): Intercept -0.5, 94% HDI [-1, 0.2] C(group)[T.1] 2, 94% HDI [2, 2] pre 1, 94% HDI [1, 1] - sigma 0.5, 94% HDI [0.5, 0.6] + y_hat_sigma 0.5, 94% HDI [0.5, 0.6] """ supports_ols = False From dc20e3e67edfd394910312c0edca0a592386c37d Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 12 Jul 2025 17:21:33 -0400 Subject: [PATCH 07/14] add support for priors from data --- causalpy/pymc_models.py | 15 +++++++++++++-- docs/source/_static/interrogate_badge.svg | 6 +++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index f95b6371..4c4ceee4 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -73,6 +73,9 @@ class PyMCModel(pm.Model): def default_priors(self): return {} + def priors_from_data(self, X, y) -> Dict[str, Any]: + return {} + def __init__( self, sample_kwargs: Optional[Dict[str, Any]] = None, @@ -122,6 +125,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None: # sample_posterior_predictive() if provided in sample_kwargs. 
random_seed = self.sample_kwargs.get("random_seed", None) + self.priors = {**self.priors_from_data(X, y), **self.priors} + self.build_model(X, y, coords) with self: self.idata = pm.sample(**self.sample_kwargs) @@ -295,16 +300,22 @@ class WeightedSumFitter(PyMCModel): "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"), } + def priors_from_data(self, X, y) -> Dict[str, Any]: + n_predictors = X.shape[1] + + return { + "beta": Prior("Dirichlet", a=np.ones(n_predictors), dims="coeffs"), + } + def build_model(self, X, y, coords): """ Defines the PyMC model """ with self: self.add_coords(coords) - n_predictors = X.shape[1] X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) y = pm.Data("y", y[:, 0], dims="obs_ind") - beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs") + beta = self.priors["beta"].create_variable("beta") mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind") self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 4a908d60..3e6a538d 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 94.5% + interrogate: 93.6% @@ -12,8 +12,8 @@ interrogate interrogate - 94.5% - 94.5% + 93.6% + 93.6% From 57ba733249463991e7299df726a3a385ebabcd92 Mon Sep 17 00:00:00 2001 From: "Benjamin T. 
Vincent" Date: Sat, 27 Sep 2025 11:00:23 +0100 Subject: [PATCH 08/14] Add regenerated interrogate badge with updated coverage --- docs/source/_static/interrogate_badge.svg | 58 +++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 docs/source/_static/interrogate_badge.svg diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg new file mode 100644 index 00000000..10f9fc4a --- /dev/null +++ b/docs/source/_static/interrogate_badge.svg @@ -0,0 +1,58 @@ + + interrogate: 94.3% + + + + + + + + + + + interrogate + interrogate + 94.3% + 94.3% + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From b57810a6dbca10a7b95df9a0118cbe4b5948fca2 Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Mon, 29 Sep 2025 09:33:56 +0100 Subject: [PATCH 09/14] update pymc-extras version pin in attempt to fix failing remote tests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3975ad92..92c8451b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "seaborn>=0.11.2", "statsmodels", "xarray>=v2022.11.0", - "pymc-extras>=0.2.7", + "pymc-extras>=0.3.0", ] # List additional groups of dependencies here (e.g. development dependencies). Users From 787a10e5d7621e04a7d2826280103f5fb9600aeb Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Mon, 29 Sep 2025 10:12:58 +0100 Subject: [PATCH 10/14] Add pragma no cover to exception branches Added '# pragma: no cover' to NotImplementedError and ValueError branches in PyMCModel to exclude them from test coverage reporting. 
--- causalpy/pymc_models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 78ca6f30..50485b40 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -115,7 +115,9 @@ def __init__( def build_model(self, X, y, coords) -> None: """Build the model, must be implemented by subclass.""" - raise NotImplementedError("This method must be implemented by a subclass") + raise NotImplementedError( + "This method must be implemented by a subclass" + ) # pragma: no cover def _data_setter(self, X: xr.DataArray) -> None: """ @@ -274,7 +276,9 @@ def print_coefficients_for_unit( elif "y_hat_sigma" in self.idata.posterior: sigma_var_name = "y_hat_sigma" else: - raise ValueError("Neither 'sigma' nor 'y_hat_sigma' found in posterior") + raise ValueError( + "Neither 'sigma' nor 'y_hat_sigma' found in posterior" + ) # pragma: no cover treated_units = coeffs.coords["treated_units"].values for unit in treated_units: From bcba49f975f2978ee89e152627093ef2cc9ea95b Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Mon, 29 Sep 2025 10:18:19 +0100 Subject: [PATCH 11/14] update pymc-extras version pin to match that in pyproject.toml --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 2bc8ed20..a838de19 100644 --- a/environment.yml +++ b/environment.yml @@ -15,4 +15,4 @@ dependencies: - seaborn>=0.11.2 - statsmodels - xarray>=v2022.11.0 - - pymc-extras>=0.2.7 + - pymc-extras>=0.3.0 From 06506443765adb5bc5a92523d121befd537d8475 Mon Sep 17 00:00:00 2001 From: "Benjamin T. 
Vincent" Date: Mon, 29 Sep 2025 10:37:40 +0100 Subject: [PATCH 12/14] add docstrings to the priors_from_data methods --- causalpy/pymc_models.py | 94 +++++++++++++++++++++++ docs/source/_static/interrogate_badge.svg | 8 +- 2 files changed, 98 insertions(+), 4 deletions(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 50485b40..4cd03f8b 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -96,6 +96,77 @@ def default_priors(self): return {} def priors_from_data(self, X, y) -> Dict[str, Any]: + """ + Generate priors dynamically based on the input data. + + This method allows models to set sensible priors that adapt to the scale + and characteristics of the actual data being analyzed. It's called during + the `fit()` method before model building, allowing data-driven prior + specification that can improve model performance and convergence. + + The priors returned by this method are merged with any user-specified + priors (passed via the `priors` parameter in `__init__`), with + user-specified priors taking precedence in case of conflicts. + + Parameters + ---------- + X : xarray.DataArray + Input features/covariates with dimensions ["obs_ind", "coeffs"]. + Used to understand the scale and structure of predictors. + y : xarray.DataArray + Target variable with dimensions ["obs_ind", "treated_units"]. + Used to understand the scale and structure of the outcome. + + Returns + ------- + Dict[str, Prior] + Dictionary mapping parameter names to Prior objects. The keys should + match parameter names used in the model's `build_model()` method. + + Notes + ----- + The base implementation returns an empty dictionary, meaning no + data-driven priors are set by default. Subclasses should override + this method to implement data-adaptive prior specification. + + **Priority Order for Priors:** + 1. User-specified priors (passed to `__init__`) + 2. Default priors (from `default_priors` property) + 3. 
Data-driven priors (from this method; used only for keys set by neither of the above) + + Examples + -------- + A typical implementation might scale priors based on data variance: + + >>> def priors_from_data(self, X, y): + ... y_std = float(y.std()) + ... return { + ... "sigma": Prior("HalfNormal", sigma=y_std, dims="treated_units"), + ... "beta": Prior( + ... "Normal", + ... mu=0, + ... sigma=2 * y_std, + ... dims=["treated_units", "coeffs"], + ... ), + ... } + + Or set shape parameters based on data dimensions: + + >>> def priors_from_data(self, X, y): + ... n_predictors = X.shape[1] + ... return { + ... "beta": Prior( + ... "Dirichlet", + ... a=np.ones(n_predictors), + ... dims=["treated_units", "coeffs"], + ... ) + ... } + + See Also + -------- + WeightedSumFitter.priors_from_data : Example implementation that sets + Dirichlet prior shape based on number of control units. + """ return {} def __init__( @@ -160,6 +231,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None: # sample_posterior_predictive() if provided in sample_kwargs. random_seed = self.sample_kwargs.get("random_seed", None) + # Merge priors with precedence: user-specified > defaults > data-driven + # self.priors already holds defaults overlaid with user priors, so data-driven priors only fill keys it lacks self.priors = {**self.priors_from_data(X, y), **self.priors} self.build_model(X, y, coords) @@ -407,6 +480,27 @@ class WeightedSumFitter(PyMCModel): } def priors_from_data(self, X, y) -> Dict[str, Any]: + """ + Set Dirichlet prior for weights based on number of control units. + + For synthetic control models, this method sets the shape parameter of the + Dirichlet prior on the control unit weights (`beta`) to be uniform across + all available control units. This ensures that all control units have + equal prior probability of contributing to the synthetic control. 
+ + Parameters + ---------- + X : xarray.DataArray + Control unit data with shape (n_obs, n_control_units). + y : xarray.DataArray + Treated unit outcome data. 
+ + Returns + ------- + Dict[str, Prior] + Dictionary containing: + - "beta": Dirichlet prior with shape=(1,...,1) for n_control_units + """ n_predictors = X.shape[1] return { "beta": Prior( diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 10f9fc4a..aa85b1ad 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,10 +1,10 @@ - interrogate: 94.3% + interrogate: 95.1% - + @@ -12,8 +12,8 @@ interrogate interrogate - 94.3% - 94.3% + 95.1% + 95.1% From 3c659d3c0e6fa45b6a48ecb8e17ec0b412a95f61 Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Mon, 29 Sep 2025 10:56:24 +0100 Subject: [PATCH 13/14] add tests --- causalpy/tests/test_pymc_models.py | 249 +++++++++++++++++++++- docs/source/_static/interrogate_badge.svg | 6 +- 2 files changed, 251 insertions(+), 4 deletions(-) diff --git a/causalpy/tests/test_pymc_models.py b/causalpy/tests/test_pymc_models.py index e5fc9582..62bdaf54 100644 --- a/causalpy/tests/test_pymc_models.py +++ b/causalpy/tests/test_pymc_models.py @@ -17,9 +17,10 @@ import pymc as pm import pytest import xarray as xr +from pymc_extras.prior import Prior import causalpy as cp -from causalpy.pymc_models import PyMCModel, WeightedSumFitter +from causalpy.pymc_models import LinearRegression, PyMCModel, WeightedSumFitter sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2} @@ -592,3 +593,249 @@ def test_r2_scores_differ_across_units(self, rng): f"R² standard deviation is too low ({r2_std}), suggesting insufficient variation " "between treated units. This might indicate a scoring implementation issue." 
) + + +@pytest.fixture(scope="module") +def prior_test_data(): + """Generate test data for Prior integration tests (shared across all tests).""" + rng = np.random.default_rng(42) + X = xr.DataArray( + rng.normal(loc=0, scale=1, size=(20, 2)), + dims=["obs_ind", "coeffs"], + coords={"obs_ind": np.arange(20), "coeffs": ["x1", "x2"]}, + ) + y = xr.DataArray( + rng.normal(loc=0, scale=1, size=(20, 1)), + dims=["obs_ind", "treated_units"], + coords={"obs_ind": np.arange(20), "treated_units": ["unit_0"]}, + ) + coords = { + "obs_ind": np.arange(20), + "coeffs": ["x1", "x2"], + "treated_units": ["unit_0"], + } + return X, y, coords + + +class TestPriorIntegration: + """ + Test suite for Prior class integration with PyMC models. + Tests the precedence system, data-driven priors, and Prior class usage. + """ + + def test_default_priors_property(self): + """Test that default_priors property returns correct Prior objects.""" + model = LinearRegression() + defaults = model.default_priors + + # Check that defaults is a dictionary with expected keys + assert isinstance(defaults, dict) + assert "beta" in defaults + assert "y_hat" in defaults + + # Check that values are Prior objects + assert isinstance(defaults["beta"], Prior) + assert isinstance(defaults["y_hat"], Prior) + + # Check Prior configuration using correct API + beta_prior = defaults["beta"] + assert beta_prior.distribution == "Normal" + assert beta_prior.parameters["mu"] == 0 + assert beta_prior.parameters["sigma"] == 50 + + def test_priors_from_data_base_implementation(self, prior_test_data): + """Test that base PyMCModel.priors_from_data returns empty dict.""" + X, y, coords = prior_test_data + model = PyMCModel() + data_priors = model.priors_from_data(X, y) + assert isinstance(data_priors, dict) + assert len(data_priors) == 0 + + def test_weighted_sum_fitter_priors_from_data(self, prior_test_data): + """Test WeightedSumFitter data-driven Dirichlet prior generation.""" + X, y, coords = prior_test_data + model = 
WeightedSumFitter() + data_priors = model.priors_from_data(X, y) + + # Should return beta prior based on X shape + assert "beta" in data_priors + beta_prior = data_priors["beta"] + + # Check it's a Dirichlet prior using correct API + assert isinstance(beta_prior, Prior) + assert beta_prior.distribution == "Dirichlet" + + # Check shape matches number of predictors + assert len(beta_prior.parameters["a"]) == X.shape[1] # 2 predictors + assert np.allclose(beta_prior.parameters["a"], np.ones(2)) + + def test_prior_precedence_system(self, prior_test_data): + """Test that user priors override data-driven priors override defaults.""" + X, y, coords = prior_test_data + # Create custom user prior + user_beta_prior = Prior( + "Normal", mu=100, sigma=10, dims=("treated_units", "coeffs") + ) + + model = LinearRegression(priors={"beta": user_beta_prior}) + + # Before fit, should have user prior + defaults + assert model.priors["beta"] == user_beta_prior + assert "y_hat" in model.priors # From defaults + + # After calling priors_from_data, user prior should remain + data_priors = model.priors_from_data(X, y) + merged_priors = {**data_priors, **model.priors} + + # User prior should override any data-driven prior + assert merged_priors["beta"] == user_beta_prior + + def test_prior_precedence_integration_in_fit(self, prior_test_data): + """Test the complete prior precedence system during fit().""" + X, y, coords = prior_test_data + # Create model with custom user prior + custom_prior = Prior("Normal", mu=5, sigma=2, dims=("treated_units", "coeffs")) + model = LinearRegression( + priors={"beta": custom_prior}, + sample_kwargs={"tune": 5, "draws": 5, "chains": 1, "progressbar": False}, + ) + + # Fit the model + model.fit(X, y, coords=coords) + + # Check that the model was built with the custom prior + # We can verify this by checking the model context + assert model.idata is not None + assert "beta" in model.idata.posterior + + def test_prior_dimensions_consistency(self): + """Test 
that Prior dimensions are consistent with model expectations.""" + model = LinearRegression() + + # Check default priors have correct dimensions (tuples, not lists) + beta_prior = model.default_priors["beta"] + assert beta_prior.dims == ("treated_units", "coeffs") + + y_hat_prior = model.default_priors["y_hat"] + assert y_hat_prior.dims == ("obs_ind", "treated_units") + + # Check that sigma component has correct dims + sigma_prior = y_hat_prior.parameters["sigma"] + assert isinstance(sigma_prior, Prior) + assert sigma_prior.dims == ("treated_units",) + + def test_custom_prior_with_build_model(self, prior_test_data): + """Test that custom priors work correctly in build_model.""" + # Create a custom Prior with different parameters + custom_beta = Prior( + "Normal", + mu=0, + sigma=10, # Different from default (50) + dims=("treated_units", "coeffs"), + ) + custom_sigma = Prior( + "HalfNormal", + sigma=2, # Different from default (1) + dims=("treated_units",), + ) + custom_y_hat = Prior( + "Normal", sigma=custom_sigma, dims=("obs_ind", "treated_units") + ) + + model = LinearRegression(priors={"beta": custom_beta, "y_hat": custom_y_hat}) + + # Build the model to ensure priors work + X, y, coords = prior_test_data + model.build_model(X, y, coords) + + # Check that variables were created in the model context + with model: + assert "beta" in model.named_vars + assert "y_hat" in model.named_vars + + def test_prior_create_variable_integration(self, prior_test_data): + """Test that Prior.create_variable works in model context.""" + X, y, coords = prior_test_data + model = LinearRegression() + model.build_model(X, y, coords) + + # Verify that Prior.create_variable was called successfully + # by checking the created variables exist and have expected names + with model: + beta_var = model.named_vars["beta"] + # Check that the variable exists and is a PyMC variable + assert beta_var is not None + assert hasattr(beta_var, "name") + assert beta_var.name == "beta" + + def 
test_weighted_sum_fitter_dirichlet_prior_shape(self, prior_test_data): + """Test that WeightedSumFitter creates correct Dirichlet shape.""" + _, y, _ = prior_test_data + rng = np.random.default_rng(42) + # Test with different numbers of control units + for n_controls in [3, 5, 10]: + X = xr.DataArray( + rng.normal(size=(20, n_controls)), + dims=["obs_ind", "coeffs"], + coords={ + "obs_ind": np.arange(20), + "coeffs": [f"control_{i}" for i in range(n_controls)], + }, + ) + + model = WeightedSumFitter() + data_priors = model.priors_from_data(X, y) + + beta_prior = data_priors["beta"] + assert len(beta_prior.parameters["a"]) == n_controls + assert np.allclose(beta_prior.parameters["a"], np.ones(n_controls)) + + def test_prior_none_handling(self): + """Test that models handle None priors parameter correctly.""" + model = LinearRegression(priors=None) + + # Should still have default priors + assert len(model.priors) > 0 + assert "beta" in model.priors + assert "y_hat" in model.priors + + def test_empty_priors_dict(self): + """Test that models handle empty priors dict correctly.""" + model = LinearRegression(priors={}) + + # Should still have default priors + assert len(model.priors) > 0 + assert "beta" in model.priors + assert "y_hat" in model.priors + + def test_priors_from_data_called_during_fit(self, prior_test_data): + """Test that priors_from_data is called and integrated during fit.""" + + # Create a mock model that tracks priors_from_data calls + class TrackingWeightedSumFitter(WeightedSumFitter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.priors_from_data_called = False + self.priors_from_data_args = None + + def priors_from_data(self, X, y): + self.priors_from_data_called = True + self.priors_from_data_args = (X, y) + return super().priors_from_data(X, y) + + model = TrackingWeightedSumFitter( + sample_kwargs={"tune": 2, "draws": 2, "chains": 1, "progressbar": False} + ) + + # Fit the model + X, y, coords = prior_test_data 
+ model.fit(X, y, coords=coords) + + # Verify priors_from_data was called with correct arguments + assert model.priors_from_data_called + assert model.priors_from_data_args is not None + + # Verify the model has the Dirichlet prior after fitting + assert "beta" in model.priors + beta_prior = model.priors["beta"] + assert beta_prior.distribution == "Dirichlet" diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index aa85b1ad..d2d886ad 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 95.1% + interrogate: 95.4% @@ -12,8 +12,8 @@ interrogate interrogate - 95.1% - 95.1% + 95.4% + 95.4% From 4be4cdd4eb885fb135d75657ccfcf7283c2eb0bf Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Fri, 3 Oct 2025 09:47:43 +0100 Subject: [PATCH 14/14] Convert default_priors property to class attribute --- causalpy/pymc_models.py | 4 +--- docs/source/_static/interrogate_badge.svg | 6 +++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 4cd03f8b..2596ca4c 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -91,9 +91,7 @@ class PyMCModel(pm.Model): Inference data... """ - @property - def default_priors(self): - return {} + default_priors = {} def priors_from_data(self, X, y) -> Dict[str, Any]: """ diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index d2d886ad..8734d55d 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 95.4% + interrogate: 95.8% @@ -12,8 +12,8 @@ interrogate interrogate - 95.4% - 95.4% + 95.8% + 95.8%