Skip to content

Commit 6283c76

Browse files
authored
Merge pull request #286 from pymc-labs/random_seed
Make results fully reproducible
2 parents 4772b34 + ee8515e commit 6283c76

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

causalpy/pymc_experiments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,8 @@ class PrePostFit(ExperimentalDesign):
143143
Formula: actual ~ 0 + a + g
144144
Model coefficients:
145145
a 0.6, 94% HDI [0.6, 0.6]
146-
g 0.3, 94% HDI [0.3, 0.3]
147-
sigma 0.7, 94% HDI [0.6, 0.9]
146+
g 0.4, 94% HDI [0.4, 0.4]
147+
sigma 0.8, 94% HDI [0.6, 0.9]
148148
"""
149149

150150
def __init__(

causalpy/pymc_models.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,22 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
9696
"""Draw samples from posterior, prior predictive, and posterior predictive
9797
distributions, placing them in the model's idata attribute.
9898
"""
99+
100+
# Ensure random_seed is used in sample_prior_predictive() and
101+
# sample_posterior_predictive() if provided in sample_kwargs.
102+
if "random_seed" in self.sample_kwargs:
103+
random_seed = self.sample_kwargs["random_seed"]
104+
else:
105+
random_seed = None
106+
99107
self.build_model(X, y, coords)
100108
with self.model:
101109
self.idata = pm.sample(**self.sample_kwargs)
102-
self.idata.extend(pm.sample_prior_predictive())
110+
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
103111
self.idata.extend(
104-
pm.sample_posterior_predictive(self.idata, progressbar=False)
112+
pm.sample_posterior_predictive(
113+
self.idata, progressbar=False, random_seed=random_seed
114+
)
105115
)
106116
return self.idata
107117

causalpy/tests/test_pymc_models.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,42 @@ def test_idata_property():
123123
)
124124
assert hasattr(result, "idata")
125125
assert isinstance(result.idata, az.InferenceData)
126+
127+
128+
seeds = [1234, 42, 123456789]
129+
130+
131+
@pytest.mark.parametrize("seed", seeds)
132+
def test_result_reproducibility(seed):
133+
"""Test that we can reproduce the results from the model. We could in theory test
134+
this with all the model and experiment types, but what is being targeted is
135+
the ModelBuilder.fit method, so we should be safe testing with just one model. Here
136+
we use the DifferenceInDifferences experiment class."""
137+
# Load the data
138+
df = cp.load_data("did")
139+
# Set a random seed
140+
sample_kwargs["random_seed"] = seed
141+
# Calculate the result twice
142+
result1 = cp.pymc_experiments.DifferenceInDifferences(
143+
df,
144+
formula="y ~ 1 + group + t + group:post_treatment",
145+
time_variable_name="t",
146+
group_variable_name="group",
147+
treated=1,
148+
untreated=0,
149+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
150+
)
151+
result2 = cp.pymc_experiments.DifferenceInDifferences(
152+
df,
153+
formula="y ~ 1 + group + t + group:post_treatment",
154+
time_variable_name="t",
155+
group_variable_name="group",
156+
treated=1,
157+
untreated=0,
158+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
159+
)
160+
assert np.all(result1.idata.posterior.mu == result2.idata.posterior.mu)
161+
assert np.all(result1.idata.prior.mu == result2.idata.prior.mu)
162+
assert np.all(
163+
result1.idata.prior_predictive.y_hat == result2.idata.prior_predictive.y_hat
164+
)

0 commit comments

Comments
 (0)