Merged
2 changes: 1 addition & 1 deletion causalpy/experiments/prepostnegd.py
@@ -82,7 +82,7 @@ class PrePostNEGD(BaseExperiment):
Intercept -0.5, 94% HDI [-1, 0.2]
C(group)[T.1] 2, 94% HDI [2, 2]
pre 1, 94% HDI [1, 1]
sigma 0.5, 94% HDI [0.5, 0.6]
y_hat_sigma 0.5, 94% HDI [0.5, 0.6]
"""

supports_ols = False
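The only change in this file is to the doctest output: the measurement noise parameter is now reported as `y_hat_sigma` instead of `sigma`. This follows from the likelihood in `pymc_models.py` (below) now being built via `pymc_extras.prior.Prior`, which names nested parameters after the variable that owns them. A minimal sketch of that convention, assuming pymc-extras registers nested parameters as `{variable}_{parameter}`:

```python
import pymc as pm
from pymc_extras.prior import Prior

# Likelihood prior with a nested HalfNormal on sigma, as in the new
# default_priors below; the nested parameter is registered under the
# owning variable's name, i.e. "y_hat_sigma" rather than "sigma".
y_hat = Prior("Normal", sigma=Prior("HalfNormal", sigma=1))
with pm.Model() as model:
    y_hat.create_likelihood_variable("y_hat", mu=0.0, observed=[0.1, -0.2, 0.3])
print([rv.name for rv in model.free_RVs])  # expected: ['y_hat_sigma']
```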
187 changes: 162 additions & 25 deletions causalpy/pymc_models.py
@@ -23,6 +23,7 @@
import xarray as xr
from arviz import r2_score
from patsy import dmatrix
from pymc_extras.prior import Prior

from causalpy.utils import round_num

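The new `pymc_extras.prior.Prior` import is the backbone of this refactor: priors become declarative objects that models materialize inside `build_model`. A minimal sketch of the two entry points used throughout this diff (`create_variable` for parameters, `create_likelihood_variable` for observed nodes), matching how they are called in the code below:

```python
import pymc as pm
from pymc_extras.prior import Prior

beta_prior = Prior("Normal", mu=0, sigma=50)                   # parameter prior
y_prior = Prior("Normal", sigma=Prior("HalfNormal", sigma=1))  # likelihood with nested sigma prior

with pm.Model():
    # Registers a Normal(0, 50) random variable named "beta"
    beta = beta_prior.create_variable("beta")
    # Registers the observed "y_hat" node; its sigma comes from the nested prior
    y_prior.create_likelihood_variable("y_hat", mu=beta, observed=[1.0, 2.0])
```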
@@ -90,7 +91,87 @@ class PyMCModel(pm.Model):
Inference data...
"""

def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
default_priors = {}

def priors_from_data(self, X, y) -> Dict[str, Any]:
"""
Generate priors dynamically based on the input data.

This method allows models to set sensible priors that adapt to the scale
and characteristics of the actual data being analyzed. It's called during
the `fit()` method before model building, allowing data-driven prior
specification that can improve model performance and convergence.

The priors returned by this method are merged with any user-specified
priors (passed via the `priors` parameter in `__init__`), with
user-specified priors taking precedence in case of conflicts.

Parameters
----------
X : xarray.DataArray
Input features/covariates with dimensions ["obs_ind", "coeffs"].
Used to understand the scale and structure of predictors.
y : xarray.DataArray
Target variable with dimensions ["obs_ind", "treated_units"].
Used to understand the scale and structure of the outcome.

Returns
-------
Dict[str, Prior]
Dictionary mapping parameter names to Prior objects. The keys should
match parameter names used in the model's `build_model()` method.

Notes
-----
The base implementation returns an empty dictionary, meaning no
data-driven priors are set by default. Subclasses should override
this method to implement data-adaptive prior specification.

**Priority Order for Priors:**
1. User-specified priors (passed to `__init__`)
2. Default priors (from the `default_priors` attribute, merged in `__init__`)
3. Data-driven priors (from this method), which fill in any keys not already set by the two above

Examples
--------
A typical implementation might scale priors based on data variance:

>>> def priors_from_data(self, X, y):
... y_std = float(y.std())
... return {
... "sigma": Prior("HalfNormal", sigma=y_std, dims="treated_units"),
... "beta": Prior(
... "Normal",
... mu=0,
... sigma=2 * y_std,
... dims=["treated_units", "coeffs"],
... ),
... }

Or set shape parameters based on data dimensions:

>>> def priors_from_data(self, X, y):
... n_predictors = X.shape[1]
... return {
... "beta": Prior(
... "Dirichlet",
... a=np.ones(n_predictors),
... dims=["treated_units", "coeffs"],
... )
... }

See Also
--------
WeightedSumFitter.priors_from_data : Example implementation that sets
Dirichlet prior shape based on number of control units.
"""
return {}

def __init__(
self,
sample_kwargs: Optional[Dict[str, Any]] = None,
priors: dict[str, Any] | None = None,
):
"""
:param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
:func:`pymc.sample` function. Defaults to an empty dictionary.
@@ -99,9 +180,13 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
self.idata = None
self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {}

self.priors = {**self.default_priors, **(priors or {})}

def build_model(self, X, y, coords) -> None:
"""Build the model, must be implemented by subclass."""
raise NotImplementedError("This method must be implemented by a subclass")
raise NotImplementedError(
"This method must be implemented by a subclass"
) # pragma: no cover

def _data_setter(self, X: xr.DataArray) -> None:
"""
@@ -144,6 +229,10 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
# sample_posterior_predictive() if provided in sample_kwargs.
random_seed = self.sample_kwargs.get("random_seed", None)

# Merge in data-driven priors: they fill in any keys not already set in
# __init__, where defaults and user-specified priors (which take precedence) were merged
self.priors = {**self.priors_from_data(X, y), **self.priors}

self.build_model(X, y, coords)
with self:
self.idata = pm.sample(**self.sample_kwargs)
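Taken together with the merge in `__init__` (`{**self.default_priors, **(priors or {})}`), the effective resolution order is: user-specified priors first, then class defaults, then data-driven priors for any keys still unset. A small sketch of that resolution, using plain dicts to stand in for `Prior` objects:

```python
defaults = {"beta": "default-beta", "y_hat": "default-y_hat"}
user = {"beta": "user-beta"}
data_driven = {"beta": "data-beta", "extra": "data-extra"}

priors = {**defaults, **user}        # __init__: user overrides defaults
priors = {**data_driven, **priors}   # fit(): data-driven fills remaining gaps only
assert priors == {"beta": "user-beta", "y_hat": "default-y_hat", "extra": "data-extra"}
```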
@@ -239,26 +328,36 @@ def print_coefficients_for_unit(
) -> None:
"""Print coefficients for a single unit"""
# Determine the width of the longest label
max_label_length = max(len(name) for name in labels + ["sigma"])
max_label_length = max(len(name) for name in labels + ["y_hat_sigma"])

for name in labels:
coeff_samples = unit_coeffs.sel(coeffs=name)
print_row(max_label_length, name, coeff_samples, round_to)

# Add coefficient for measurement std
print_row(max_label_length, "sigma", unit_sigma, round_to)
print_row(max_label_length, "y_hat_sigma", unit_sigma, round_to)

print("Model coefficients:")
coeffs = az.extract(self.idata.posterior, var_names="beta")

# Always has treated_units dimension - no branching needed!
# Check if sigma or y_hat_sigma variable exists
sigma_var_name = None
if "sigma" in self.idata.posterior:
sigma_var_name = "sigma"
elif "y_hat_sigma" in self.idata.posterior:
sigma_var_name = "y_hat_sigma"
else:
raise ValueError(
"Neither 'sigma' nor 'y_hat_sigma' found in posterior"
) # pragma: no cover

treated_units = coeffs.coords["treated_units"].values
for unit in treated_units:
if len(treated_units) > 1:
print(f"\nTreated unit: {unit}")

unit_coeffs = coeffs.sel(treated_units=unit)
unit_sigma = az.extract(self.idata.posterior, var_names="sigma").sel(
unit_sigma = az.extract(self.idata.posterior, var_names=sigma_var_name).sel(
treated_units=unit
)
print_coefficients_for_unit(unit_coeffs, unit_sigma, labels, round_to or 2)
@@ -301,6 +400,15 @@ class LinearRegression(PyMCModel):
Inference data...
""" # noqa: W605

default_priors = {
[Review thread on this line]
Collaborator: need to add a @property decorator here? Or is that remembered from it being done in the PyMCModel base class? I'm getting a Pylance warning: Type "dict[str, Prior]" is not assignable to declared type "property".
Contributor Author: What line of code brings that on? Maybe having a setter will help?
Collaborator: Agree, if you want a property, maybe we can have a setter method? (Not a blocker for now; maybe create an issue?)
"beta": Prior("Normal", mu=0, sigma=50, dims=["treated_units", "coeffs"]),
"y_hat": Prior(
"Normal",
sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
dims=["obs_ind", "treated_units"],
),
}

def build_model(self, X, y, coords):
"""
Defines the PyMC model
@@ -314,12 +422,11 @@ def build_model(self, X, y, coords):
self.add_coords(coords)
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
beta = pm.Normal("beta", 0, 50, dims=["treated_units", "coeffs"])
sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
beta = self.priors["beta"].create_variable("beta")
mu = pm.Deterministic(
"mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
)
pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"])
self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


class WeightedSumFitter(PyMCModel):
@@ -362,23 +469,56 @@ class WeightedSumFitter(PyMCModel):
Inference data...
""" # noqa: W605

default_priors = {
"y_hat": Prior(
"Normal",
sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
dims=["obs_ind", "treated_units"],
),
}

def priors_from_data(self, X, y) -> Dict[str, Any]:
"""
Set Dirichlet prior for weights based on number of control units.

For synthetic control models, this method sets the shape parameter of the
Dirichlet prior on the control unit weights (`beta`) to be uniform across
all available control units. This ensures that all control units have
equal prior probability of contributing to the synthetic control.

Parameters
----------
X : xarray.DataArray
Control unit data with shape (n_obs, n_control_units).
y : xarray.DataArray
Treated unit outcome data.

Returns
-------
Dict[str, Prior]
Dictionary containing:
- "beta": Dirichlet prior with shape=(1,...,1) for n_control_units
"""
n_predictors = X.shape[1]
return {
"beta": Prior(
"Dirichlet", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
),
}

def build_model(self, X, y, coords):
"""
Defines the PyMC model
"""
with self:
self.add_coords(coords)
n_predictors = X.sizes["coeffs"]
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
beta = pm.Dirichlet(
"beta", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
)
sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
beta = self.priors["beta"].create_variable("beta")
mu = pm.Deterministic(
"mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
)
pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"])
self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


class InstrumentalVariableRegression(PyMCModel):
@@ -568,21 +708,18 @@ class PropensityScore(PyMCModel):
Inference...
""" # noqa: W605

def build_model(self, X, t, coords, prior, noncentred):
default_priors = {
"b": Prior("Normal", mu=0, sigma=1, dims="coeffs"),
}

def build_model(self, X, t, coords, prior=None, noncentred=True):
"Defines the PyMC propensity model"
with self:
self.add_coords(coords)
X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
t_data = pm.Data("t", t.flatten(), dims="obs_ind")
if noncentred:
mu_beta, sigma_beta = prior["b"]
beta_std = pm.Normal("beta_std", 0, 1, dims="coeffs")
b = pm.Deterministic(
"beta_", mu_beta + sigma_beta * beta_std, dims="coeffs"
)
else:
b = pm.Normal("b", mu=prior["b"][0], sigma=prior["b"][1], dims="coeffs")
mu = pm.math.dot(X_data, b)
b = self.priors["b"].create_variable("b")
mu = pt.dot(X_data, b)
p = pm.Deterministic("p", pm.math.invlogit(mu))
pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")
