Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
drbenvincent committed Jan 5, 2024
1 parent a21f659 commit db29b5b
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions causalpy/tests/test_pymc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,38 @@ def test_idata_property():
)
assert hasattr(result, "idata")
assert isinstance(result.idata, az.InferenceData)


def test_result_reproducibility():
"""Test that we can reproduce the results from the model. We could in theory test
this with all the model and experiment types, but what is being targetted is
the ModelBuilder.fit method, so we should be safe testing with just one model. Here
we use the DifferenceInDifferences experiment class."""
# Load the data
df = cp.load_data("did")
# Set a random seed
sample_kwargs["random_seed"] = 42
# Calculate the result twice
result1 = cp.pymc_experiments.DifferenceInDifferences(
df,
formula="y ~ 1 + group + t + group:post_treatment",
time_variable_name="t",
group_variable_name="group",
treated=1,
untreated=0,
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
)
result2 = cp.pymc_experiments.DifferenceInDifferences(
df,
formula="y ~ 1 + group + t + group:post_treatment",
time_variable_name="t",
group_variable_name="group",
treated=1,
untreated=0,
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
)
assert np.all(result1.idata.posterior.mu == result2.idata.posterior.mu)
assert np.all(result1.idata.prior.mu == result2.idata.prior.mu)
assert np.all(
result1.idata.prior_predictive.y_hat == result2.idata.prior_predictive.y_hat
)

0 comments on commit db29b5b

Please sign in to comment.