diff --git a/causalpy/experiments/prepostnegd.py b/causalpy/experiments/prepostnegd.py
index a187b7aa..32c1ceb1 100644
--- a/causalpy/experiments/prepostnegd.py
+++ b/causalpy/experiments/prepostnegd.py
@@ -82,7 +82,7 @@ class PrePostNEGD(BaseExperiment):
     Intercept     -0.5, 94% HDI [-1, 0.2]
     C(group)[T.1] 2, 94% HDI [2, 2]
     pre           1, 94% HDI [1, 1]
-    sigma         0.5, 94% HDI [0.5, 0.6]
+    y_hat_sigma   0.5, 94% HDI [0.5, 0.6]
     """

     supports_ols = False
diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py
index 5564ce71..2596ca4c 100644
--- a/causalpy/pymc_models.py
+++ b/causalpy/pymc_models.py
@@ -23,6 +23,7 @@
 import xarray as xr
 from arviz import r2_score
 from patsy import dmatrix
+from pymc_extras.prior import Prior

 from causalpy.utils import round_num

@@ -90,7 +91,87 @@ class PyMCModel(pm.Model):
     Inference data...
     """

-    def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
+    default_priors = {}
+
+    def priors_from_data(self, X, y) -> Dict[str, Any]:
+        """
+        Generate priors dynamically based on the input data.
+
+        This method allows models to set sensible priors that adapt to the scale
+        and characteristics of the actual data being analyzed. It's called during
+        the `fit()` method before model building, allowing data-driven prior
+        specification that can improve model performance and convergence.
+
+        The priors returned by this method are merged with any user-specified
+        priors (passed via the `priors` parameter in `__init__`), with
+        user-specified priors taking precedence in case of conflicts.
+
+        Parameters
+        ----------
+        X : xarray.DataArray
+            Input features/covariates with dimensions ["obs_ind", "coeffs"].
+            Used to understand the scale and structure of predictors.
+        y : xarray.DataArray
+            Target variable with dimensions ["obs_ind", "treated_units"].
+            Used to understand the scale and structure of the outcome.
+
+        Returns
+        -------
+        Dict[str, Prior]
+            Dictionary mapping parameter names to Prior objects. The keys should
+            match parameter names used in the model's `build_model()` method.
+
+        Notes
+        -----
+        The base implementation returns an empty dictionary, meaning no
+        data-driven priors are set by default. Subclasses should override
+        this method to implement data-adaptive prior specification.
+
+        **Priority Order for Priors:**
+        1. User-specified priors (passed to `__init__`)
+        2. Data-driven priors (from this method)
+        3. Default priors (from `default_priors` property)
+
+        Examples
+        --------
+        A typical implementation might scale priors based on data variance:
+
+        >>> def priors_from_data(self, X, y):
+        ...     y_std = float(y.std())
+        ...     return {
+        ...         "sigma": Prior("HalfNormal", sigma=y_std, dims="treated_units"),
+        ...         "beta": Prior(
+        ...             "Normal",
+        ...             mu=0,
+        ...             sigma=2 * y_std,
+        ...             dims=["treated_units", "coeffs"],
+        ...         ),
+        ...     }
+
+        Or set shape parameters based on data dimensions:
+
+        >>> def priors_from_data(self, X, y):
+        ...     n_predictors = X.shape[1]
+        ...     return {
+        ...         "beta": Prior(
+        ...             "Dirichlet",
+        ...             a=np.ones(n_predictors),
+        ...             dims=["treated_units", "coeffs"],
+        ...         )
+        ...     }
+
+        See Also
+        --------
+        WeightedSumFitter.priors_from_data : Example implementation that sets
+            Dirichlet prior shape based on number of control units.
+        """
+        return {}
+
+    def __init__(
+        self,
+        sample_kwargs: Optional[Dict[str, Any]] = None,
+        priors: dict[str, Any] | None = None,
+    ):
         """
         :param sample_kwargs: A dictionary of kwargs that get unpacked and passed
             to the :func:`pymc.sample` function. Defaults to an empty dictionary.
@@ -99,9 +180,13 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
         self.idata = None
         self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {}
+        self.priors = {**self.default_priors, **(priors or {})}
+
     def build_model(self, X, y, coords) -> None:
         """Build the model, must be implemented by subclass."""
-        raise NotImplementedError("This method must be implemented by a subclass")
+        raise NotImplementedError(
+            "This method must be implemented by a subclass"
+        )  # pragma: no cover

     def _data_setter(self, X: xr.DataArray) -> None:
         """
@@ -144,6 +229,10 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
         # sample_posterior_predictive() if provided in sample_kwargs.
         random_seed = self.sample_kwargs.get("random_seed", None)

+        # Merge priors with precedence: user-specified > data-driven > defaults
+        # Data-driven priors are computed first, then user-specified priors override them
+        self.priors = {**self.priors_from_data(X, y), **self.priors}
+
         self.build_model(X, y, coords)
         with self:
             self.idata = pm.sample(**self.sample_kwargs)
@@ -239,26 +328,36 @@ def print_coefficients_for_unit(
         ) -> None:
             """Print coefficients for a single unit"""
             # Determine the width of the longest label
-            max_label_length = max(len(name) for name in labels + ["sigma"])
+            max_label_length = max(len(name) for name in labels + ["y_hat_sigma"])

             for name in labels:
                 coeff_samples = unit_coeffs.sel(coeffs=name)
                 print_row(max_label_length, name, coeff_samples, round_to)

             # Add coefficient for measurement std
-            print_row(max_label_length, "sigma", unit_sigma, round_to)
+            print_row(max_label_length, "y_hat_sigma", unit_sigma, round_to)

         print("Model coefficients:")
         coeffs = az.extract(self.idata.posterior, var_names="beta")
-        # Always has treated_units dimension - no branching needed!
+        # Check if sigma or y_hat_sigma variable exists
+        sigma_var_name = None
+        if "sigma" in self.idata.posterior:
+            sigma_var_name = "sigma"
+        elif "y_hat_sigma" in self.idata.posterior:
+            sigma_var_name = "y_hat_sigma"
+        else:
+            raise ValueError(
+                "Neither 'sigma' nor 'y_hat_sigma' found in posterior"
+            )  # pragma: no cover
+
         treated_units = coeffs.coords["treated_units"].values
         for unit in treated_units:
             if len(treated_units) > 1:
                 print(f"\nTreated unit: {unit}")
             unit_coeffs = coeffs.sel(treated_units=unit)
-            unit_sigma = az.extract(self.idata.posterior, var_names="sigma").sel(
+            unit_sigma = az.extract(self.idata.posterior, var_names=sigma_var_name).sel(
                 treated_units=unit
             )
             print_coefficients_for_unit(unit_coeffs, unit_sigma, labels, round_to or 2)
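
Note: a minimal sketch (not part of the patch) of how the two merge steps above compose
for the shipped models, where `default_priors` and `priors_from_data()` do not share keys;
the string values are placeholders for pymc_extras Prior objects:

    default_priors = {"y_hat": "default y_hat prior"}
    user_priors = {}  # nothing passed to __init__
    data_driven = {"beta": "Dirichlet prior from priors_from_data()"}

    priors = {**default_priors, **user_priors}  # merged in __init__
    priors = {**data_driven, **priors}          # merged again in fit()
    assert priors == {
        "y_hat": "default y_hat prior",
        "beta": "Dirichlet prior from priors_from_data()",
    }
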
""" # noqa: W605 + default_priors = { + "beta": Prior("Normal", mu=0, sigma=50, dims=["treated_units", "coeffs"]), + "y_hat": Prior( + "Normal", + sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]), + dims=["obs_ind", "treated_units"], + ), + } + def build_model(self, X, y, coords): """ Defines the PyMC model @@ -314,12 +422,11 @@ def build_model(self, X, y, coords): self.add_coords(coords) X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) y = pm.Data("y", y, dims=["obs_ind", "treated_units"]) - beta = pm.Normal("beta", 0, 50, dims=["treated_units", "coeffs"]) - sigma = pm.HalfNormal("sigma", 1, dims="treated_units") + beta = self.priors["beta"].create_variable("beta") mu = pm.Deterministic( "mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"] ) - pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"]) + self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) class WeightedSumFitter(PyMCModel): @@ -362,23 +469,56 @@ class WeightedSumFitter(PyMCModel): Inference data... """ # noqa: W605 + default_priors = { + "y_hat": Prior( + "Normal", + sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]), + dims=["obs_ind", "treated_units"], + ), + } + + def priors_from_data(self, X, y) -> Dict[str, Any]: + """ + Set Dirichlet prior for weights based on number of control units. + + For synthetic control models, this method sets the shape parameter of the + Dirichlet prior on the control unit weights (`beta`) to be uniform across + all available control units. This ensures that all control units have + equal prior probability of contributing to the synthetic control. + + Parameters + ---------- + X : xarray.DataArray + Control unit data with shape (n_obs, n_control_units). + y : xarray.DataArray + Treated unit outcome data. + + Returns + ------- + Dict[str, Prior] + Dictionary containing: + - "beta": Dirichlet prior with shape=(1,...,1) for n_control_units + """ + n_predictors = X.shape[1] + return { + "beta": Prior( + "Dirichlet", a=np.ones(n_predictors), dims=["treated_units", "coeffs"] + ), + } + def build_model(self, X, y, coords): """ Defines the PyMC model """ with self: self.add_coords(coords) - n_predictors = X.sizes["coeffs"] X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) y = pm.Data("y", y, dims=["obs_ind", "treated_units"]) - beta = pm.Dirichlet( - "beta", a=np.ones(n_predictors), dims=["treated_units", "coeffs"] - ) - sigma = pm.HalfNormal("sigma", 1, dims="treated_units") + beta = self.priors["beta"].create_variable("beta") mu = pm.Deterministic( "mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"] ) - pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"]) + self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) class InstrumentalVariableRegression(PyMCModel): @@ -568,21 +708,18 @@ class PropensityScore(PyMCModel): Inference... 
""" # noqa: W605 - def build_model(self, X, t, coords, prior, noncentred): + default_priors = { + "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"), + } + + def build_model(self, X, t, coords, prior=None, noncentred=True): "Defines the PyMC propensity model" with self: self.add_coords(coords) X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"]) t_data = pm.Data("t", t.flatten(), dims="obs_ind") - if noncentred: - mu_beta, sigma_beta = prior["b"] - beta_std = pm.Normal("beta_std", 0, 1, dims="coeffs") - b = pm.Deterministic( - "beta_", mu_beta + sigma_beta * beta_std, dims="coeffs" - ) - else: - b = pm.Normal("b", mu=prior["b"][0], sigma=prior["b"][1], dims="coeffs") - mu = pm.math.dot(X_data, b) + b = self.priors["b"].create_variable("b") + mu = pt.dot(X_data, b) p = pm.Deterministic("p", pm.math.invlogit(mu)) pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind") diff --git a/causalpy/tests/test_pymc_models.py b/causalpy/tests/test_pymc_models.py index 22f3a045..62bdaf54 100644 --- a/causalpy/tests/test_pymc_models.py +++ b/causalpy/tests/test_pymc_models.py @@ -17,9 +17,10 @@ import pymc as pm import pytest import xarray as xr +from pymc_extras.prior import Prior import causalpy as cp -from causalpy.pymc_models import PyMCModel, WeightedSumFitter +from causalpy.pymc_models import LinearRegression, PyMCModel, WeightedSumFitter sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2} @@ -45,7 +46,7 @@ def build_model(self, X, y, coords): X_ = pm.Data(name="X", value=X, dims=["obs_ind", "coeffs"]) y_ = pm.Data(name="y", value=y, dims=["obs_ind", "treated_units"]) beta = pm.Normal("beta", mu=0, sigma=1, dims=["treated_units", "coeffs"]) - sigma = pm.HalfNormal("sigma", sigma=1, dims="treated_units") + sigma = pm.HalfNormal("y_hat_sigma", sigma=1, dims="treated_units") mu = pm.Deterministic( "mu", pm.math.dot(X_, beta.T), dims=["obs_ind", "treated_units"] ) @@ -159,7 +160,7 @@ def test_fit_predict(self, coords, rng, mock_pymc_sample) -> None: 2, 2 * 2, ) # (treated_units, coeffs, sample) - assert az.extract(data=model.idata, var_names=["sigma"]).shape == ( + assert az.extract(data=model.idata, var_names=["y_hat_sigma"]).shape == ( 1, 2 * 2, ) # (treated_units, sample) @@ -402,7 +403,7 @@ def test_multi_unit_coefficients(self, synthetic_control_data): # Extract coefficients beta = az.extract(wsf.idata.posterior, var_names="beta") - sigma = az.extract(wsf.idata.posterior, var_names="sigma") + sigma = az.extract(wsf.idata.posterior, var_names="y_hat_sigma") # Check beta dimensions: should be (sample, treated_units, coeffs) assert "treated_units" in beta.dims @@ -461,7 +462,7 @@ def test_print_coefficients_multi_unit(self, synthetic_control_data, capsys): assert control in output # Check that sigma is printed for each unit - assert output.count("sigma") == len(treated_units) + assert output.count("y_hat_sigma") == len(treated_units) def test_scoring_multi_unit(self, synthetic_control_data): """Test that scoring works with multiple treated units.""" @@ -592,3 +593,249 @@ def test_r2_scores_differ_across_units(self, rng): f"R² standard deviation is too low ({r2_std}), suggesting insufficient variation " "between treated units. This might indicate a scoring implementation issue." 
diff --git a/causalpy/tests/test_pymc_models.py b/causalpy/tests/test_pymc_models.py
index 22f3a045..62bdaf54 100644
--- a/causalpy/tests/test_pymc_models.py
+++ b/causalpy/tests/test_pymc_models.py
@@ -17,9 +17,10 @@
 import pymc as pm
 import pytest
 import xarray as xr
+from pymc_extras.prior import Prior

 import causalpy as cp
-from causalpy.pymc_models import PyMCModel, WeightedSumFitter
+from causalpy.pymc_models import LinearRegression, PyMCModel, WeightedSumFitter

 sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
@@ -45,7 +46,7 @@ def build_model(self, X, y, coords):
             X_ = pm.Data(name="X", value=X, dims=["obs_ind", "coeffs"])
             y_ = pm.Data(name="y", value=y, dims=["obs_ind", "treated_units"])
             beta = pm.Normal("beta", mu=0, sigma=1, dims=["treated_units", "coeffs"])
-            sigma = pm.HalfNormal("sigma", sigma=1, dims="treated_units")
+            sigma = pm.HalfNormal("y_hat_sigma", sigma=1, dims="treated_units")
             mu = pm.Deterministic(
                 "mu", pm.math.dot(X_, beta.T), dims=["obs_ind", "treated_units"]
             )
@@ -159,7 +160,7 @@ def test_fit_predict(self, coords, rng, mock_pymc_sample) -> None:
             2,
             2 * 2,
         )  # (treated_units, coeffs, sample)
-        assert az.extract(data=model.idata, var_names=["sigma"]).shape == (
+        assert az.extract(data=model.idata, var_names=["y_hat_sigma"]).shape == (
             1,
             2 * 2,
         )  # (treated_units, sample)
@@ -402,7 +403,7 @@ def test_multi_unit_coefficients(self, synthetic_control_data):
         # Extract coefficients
         beta = az.extract(wsf.idata.posterior, var_names="beta")
-        sigma = az.extract(wsf.idata.posterior, var_names="sigma")
+        sigma = az.extract(wsf.idata.posterior, var_names="y_hat_sigma")

         # Check beta dimensions: should be (sample, treated_units, coeffs)
         assert "treated_units" in beta.dims
@@ -461,7 +462,7 @@ def test_print_coefficients_multi_unit(self, synthetic_control_data, capsys):
             assert control in output

         # Check that sigma is printed for each unit
-        assert output.count("sigma") == len(treated_units)
+        assert output.count("y_hat_sigma") == len(treated_units)

     def test_scoring_multi_unit(self, synthetic_control_data):
         """Test that scoring works with multiple treated units."""
@@ -592,3 +593,249 @@ def test_r2_scores_differ_across_units(self, rng):
             f"R² standard deviation is too low ({r2_std}), suggesting insufficient variation "
             "between treated units. This might indicate a scoring implementation issue."
         )
+
+
+@pytest.fixture(scope="module")
+def prior_test_data():
+    """Generate test data for Prior integration tests (shared across all tests)."""
+    rng = np.random.default_rng(42)
+    X = xr.DataArray(
+        rng.normal(loc=0, scale=1, size=(20, 2)),
+        dims=["obs_ind", "coeffs"],
+        coords={"obs_ind": np.arange(20), "coeffs": ["x1", "x2"]},
+    )
+    y = xr.DataArray(
+        rng.normal(loc=0, scale=1, size=(20, 1)),
+        dims=["obs_ind", "treated_units"],
+        coords={"obs_ind": np.arange(20), "treated_units": ["unit_0"]},
+    )
+    coords = {
+        "obs_ind": np.arange(20),
+        "coeffs": ["x1", "x2"],
+        "treated_units": ["unit_0"],
+    }
+    return X, y, coords
+
+
+class TestPriorIntegration:
+    """
+    Test suite for Prior class integration with PyMC models.
+    Tests the precedence system, data-driven priors, and Prior class usage.
+    """
+
+    def test_default_priors_property(self):
+        """Test that default_priors property returns correct Prior objects."""
+        model = LinearRegression()
+        defaults = model.default_priors
+
+        # Check that defaults is a dictionary with expected keys
+        assert isinstance(defaults, dict)
+        assert "beta" in defaults
+        assert "y_hat" in defaults
+
+        # Check that values are Prior objects
+        assert isinstance(defaults["beta"], Prior)
+        assert isinstance(defaults["y_hat"], Prior)
+
+        # Check Prior configuration using correct API
+        beta_prior = defaults["beta"]
+        assert beta_prior.distribution == "Normal"
+        assert beta_prior.parameters["mu"] == 0
+        assert beta_prior.parameters["sigma"] == 50
+
+    def test_priors_from_data_base_implementation(self, prior_test_data):
+        """Test that base PyMCModel.priors_from_data returns empty dict."""
+        X, y, coords = prior_test_data
+        model = PyMCModel()
+        data_priors = model.priors_from_data(X, y)
+        assert isinstance(data_priors, dict)
+        assert len(data_priors) == 0
+
+    def test_weighted_sum_fitter_priors_from_data(self, prior_test_data):
+        """Test WeightedSumFitter data-driven Dirichlet prior generation."""
+        X, y, coords = prior_test_data
+        model = WeightedSumFitter()
+        data_priors = model.priors_from_data(X, y)
+
+        # Should return beta prior based on X shape
+        assert "beta" in data_priors
+        beta_prior = data_priors["beta"]
+
+        # Check it's a Dirichlet prior using correct API
+        assert isinstance(beta_prior, Prior)
+        assert beta_prior.distribution == "Dirichlet"
+
+        # Check shape matches number of predictors
+        assert len(beta_prior.parameters["a"]) == X.shape[1]  # 2 predictors
+        assert np.allclose(beta_prior.parameters["a"], np.ones(2))
+
+    def test_prior_precedence_system(self, prior_test_data):
+        """Test that user priors override data-driven priors override defaults."""
+        X, y, coords = prior_test_data
+        # Create custom user prior
+        user_beta_prior = Prior(
+            "Normal", mu=100, sigma=10, dims=("treated_units", "coeffs")
+        )
+
+        model = LinearRegression(priors={"beta": user_beta_prior})
+
+        # Before fit, should have user prior + defaults
+        assert model.priors["beta"] == user_beta_prior
+        assert "y_hat" in model.priors  # From defaults
+
+        # After calling priors_from_data, user prior should remain
+        data_priors = model.priors_from_data(X, y)
+        merged_priors = {**data_priors, **model.priors}
+
+        # User prior should override any data-driven prior
+        assert merged_priors["beta"] == user_beta_prior
+
+    def test_prior_precedence_integration_in_fit(self, prior_test_data):
+        """Test the complete prior precedence system during fit()."""
+        X, y, coords = prior_test_data
+        # Create model with custom user prior
+        custom_prior = Prior("Normal", mu=5, sigma=2, dims=("treated_units", "coeffs"))
+        model = LinearRegression(
+            priors={"beta": custom_prior},
+            sample_kwargs={"tune": 5, "draws": 5, "chains": 1, "progressbar": False},
+        )
+
+        # Fit the model
+        model.fit(X, y, coords=coords)
+
+        # Check that the model was built with the custom prior
+        # We can verify this by checking the model context
+        assert model.idata is not None
+        assert "beta" in model.idata.posterior
+
+    def test_prior_dimensions_consistency(self):
+        """Test that Prior dimensions are consistent with model expectations."""
+        model = LinearRegression()
+
+        # Check default priors have correct dimensions (tuples, not lists)
+        beta_prior = model.default_priors["beta"]
+        assert beta_prior.dims == ("treated_units", "coeffs")
+
+        y_hat_prior = model.default_priors["y_hat"]
+        assert y_hat_prior.dims == ("obs_ind", "treated_units")
+
+        # Check that sigma component has correct dims
+        sigma_prior = y_hat_prior.parameters["sigma"]
+        assert isinstance(sigma_prior, Prior)
+        assert sigma_prior.dims == ("treated_units",)
+
+    def test_custom_prior_with_build_model(self, prior_test_data):
+        """Test that custom priors work correctly in build_model."""
+        # Create a custom Prior with different parameters
+        custom_beta = Prior(
+            "Normal",
+            mu=0,
+            sigma=10,  # Different from default (50)
+            dims=("treated_units", "coeffs"),
+        )
+        custom_sigma = Prior(
+            "HalfNormal",
+            sigma=2,  # Different from default (1)
+            dims=("treated_units",),
+        )
+        custom_y_hat = Prior(
+            "Normal", sigma=custom_sigma, dims=("obs_ind", "treated_units")
+        )
+
+        model = LinearRegression(priors={"beta": custom_beta, "y_hat": custom_y_hat})
+
+        # Build the model to ensure priors work
+        X, y, coords = prior_test_data
+        model.build_model(X, y, coords)
+
+        # Check that variables were created in the model context
+        with model:
+            assert "beta" in model.named_vars
+            assert "y_hat" in model.named_vars
+
+    def test_prior_create_variable_integration(self, prior_test_data):
+        """Test that Prior.create_variable works in model context."""
+        X, y, coords = prior_test_data
+        model = LinearRegression()
+        model.build_model(X, y, coords)
+
+        # Verify that Prior.create_variable was called successfully
+        # by checking the created variables exist and have expected names
+        with model:
+            beta_var = model.named_vars["beta"]
+            # Check that the variable exists and is a PyMC variable
+            assert beta_var is not None
+            assert hasattr(beta_var, "name")
+            assert beta_var.name == "beta"
+
+    def test_weighted_sum_fitter_dirichlet_prior_shape(self, prior_test_data):
+        """Test that WeightedSumFitter creates correct Dirichlet shape."""
+        _, y, _ = prior_test_data
+        rng = np.random.default_rng(42)
+        # Test with different numbers of control units
+        for n_controls in [3, 5, 10]:
+            X = xr.DataArray(
+                rng.normal(size=(20, n_controls)),
+                dims=["obs_ind", "coeffs"],
+                coords={
+                    "obs_ind": np.arange(20),
+                    "coeffs": [f"control_{i}" for i in range(n_controls)],
+                },
+            )
+
+            model = WeightedSumFitter()
+            data_priors = model.priors_from_data(X, y)
+
+            beta_prior = data_priors["beta"]
+            assert len(beta_prior.parameters["a"]) == n_controls
+            assert np.allclose(beta_prior.parameters["a"], np.ones(n_controls))
+
+    def test_prior_none_handling(self):
+        """Test that models handle None priors parameter correctly."""
+        model = LinearRegression(priors=None)
+
+        # Should still have default priors
+        assert len(model.priors) > 0
+        assert "beta" in model.priors
+        assert "y_hat" in model.priors
+
+    def test_empty_priors_dict(self):
+        """Test that models handle empty priors dict correctly."""
+        model = LinearRegression(priors={})
+
+        # Should still have default priors
+        assert len(model.priors) > 0
+        assert "beta" in model.priors
+        assert "y_hat" in model.priors
+
+    def test_priors_from_data_called_during_fit(self, prior_test_data):
+        """Test that priors_from_data is called and integrated during fit."""
+
+        # Create a mock model that tracks priors_from_data calls
+        class TrackingWeightedSumFitter(WeightedSumFitter):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.priors_from_data_called = False
+                self.priors_from_data_args = None
+
+            def priors_from_data(self, X, y):
+                self.priors_from_data_called = True
+                self.priors_from_data_args = (X, y)
+                return super().priors_from_data(X, y)
+
+        model = TrackingWeightedSumFitter(
+            sample_kwargs={"tune": 2, "draws": 2, "chains": 1, "progressbar": False}
+        )
+
+        # Fit the model
+        X, y, coords = prior_test_data
+        model.fit(X, y, coords=coords)
+
+        # Verify priors_from_data was called with correct arguments
+        assert model.priors_from_data_called
+        assert model.priors_from_data_args is not None
+
+        # Verify the model has the Dirichlet prior after fitting
+        assert "beta" in model.priors
+        beta_prior = model.priors["beta"]
+        assert beta_prior.distribution == "Dirichlet"
diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg
index 4704ef6c..8734d55d 100644
--- a/docs/source/_static/interrogate_badge.svg
+++ b/docs/source/_static/interrogate_badge.svg
@@ -1,5 +1,5 @@
-    interrogate: 95.5%
+    interrogate: 95.8%
@@ -12,8 +12,8 @@
     interrogate
     interrogate
-    95.5%
-    95.5%
+    95.8%
+    95.8%
diff --git a/environment.yml b/environment.yml
index 02b7f920..a838de19 100644
--- a/environment.yml
+++ b/environment.yml
@@ -15,3 +15,4 @@ dependencies:
   - seaborn>=0.11.2
   - statsmodels
   - xarray>=v2022.11.0
+  - pymc-extras>=0.3.0
diff --git a/pyproject.toml b/pyproject.toml
index 7892f04e..9dc0e453 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ dependencies = [
     "seaborn>=0.11.2",
     "statsmodels",
     "xarray>=v2022.11.0",
+    "pymc-extras>=0.3.0",
 ]

 # List additional groups of dependencies here (e.g. development dependencies). Users