Skip to content

Commit 9e2bdb2

Browse files
committed
Bayesian power analysis pymc-labs#276
1 parent 83cb28c commit 9e2bdb2

File tree

4 files changed

+1723
-10
lines changed

4 files changed

+1723
-10
lines changed

causalpy/pymc_experiments.py

Lines changed: 282 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"""
1313

1414
import warnings # noqa: I001
15-
from typing import Union
15+
from typing import Union, Dict
1616

1717
import arviz as az
1818
import matplotlib.pyplot as plt
@@ -26,7 +26,7 @@
2626
from causalpy.custom_exceptions import BadIndexException
2727
from causalpy.custom_exceptions import DataException, FormulaException
2828
from causalpy.plot_utils import plot_xY
29-
from causalpy.utils import _is_variable_dummy_coded
29+
from causalpy.utils import _is_variable_dummy_coded, compute_bayesian_tail_probability
3030

3131
LEGEND_FONT_SIZE = 12
3232
az.style.use("arviz-darkgrid")
@@ -330,15 +330,290 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
330330

331331
return fig, ax
332332

333-
def summary(self) -> None:
333+
def _summary_intervention(self, alpha: float = 0.05, **kwargs) -> pd.DataFrame:
334+
"""
335+
Calculate and summarize the intervention analysis results in a DataFrame format.
336+
337+
This function performs cumulative and mean calculations on the posterior predictive distributions,
338+
computes Bayesian tail probabilities, posterior estimations, causal effects, and confidence intervals.
339+
It optionally applies corrections to the cumulative and mean calculations.
340+
341+
Parameters:
342+
- alpha (float, optional): The significance level for confidence interval calculations. Default is 0.05.
343+
- kwargs (Dict[str, Any], optional): Additional keyword arguments.
344+
- "correction" (bool or Dict[str, float]): If True, applies predefined corrections to cumulative and mean results.
345+
If a dictionary, the corrections for 'cumulative' and 'mean' should be provided. Default is False.
346+
347+
Returns:
348+
- pd.DataFrame: A DataFrame where each row represents different statistical measures such as
349+
Bayesian tail probability, posterior estimation, causal effect, and confidence intervals for cumulative and mean results.
350+
"""
351+
correction = kwargs.get("correction", False)
352+
353+
results = {}
354+
ci = (alpha * 100) / 2
355+
356+
# Cumulative calculations
357+
cumulative_results = self.post_y.sum()
358+
_mu_samples_cumulative = (
359+
self.post_pred["posterior_predictive"]
360+
.mu.stack(sample=("chain", "draw"))
361+
.sum("obs_ind")
362+
)
363+
364+
# Mean calculations
365+
mean_results = self.post_y.mean()
366+
_mu_samples_mean = (
367+
self.post_pred["posterior_predictive"]
368+
.mu.stack(sample=("chain", "draw"))
369+
.mean("obs_ind")
370+
)
371+
372+
if not isinstance(correction, bool):
373+
_mu_samples_cumulative += correction["cumulative"]
374+
_mu_samples_mean += correction["mean"]
375+
376+
# Bayesian Tail Probability
377+
results["bayesian_tail_probability"] = {
378+
"cumulative": compute_bayesian_tail_probability(
379+
posterior=_mu_samples_cumulative, x=cumulative_results
380+
),
381+
"mean": compute_bayesian_tail_probability(
382+
posterior=_mu_samples_mean, x=mean_results
383+
),
384+
}
385+
386+
# Posterior Mean
387+
results["posterior_estimation"] = {
388+
"cumulative": np.mean(_mu_samples_cumulative.values),
389+
"mean": np.mean(_mu_samples_mean.values),
390+
}
391+
392+
results["results"] = {"cumulative": cumulative_results, "mean": mean_results}
393+
394+
# Causal Effect
395+
results["causal_effect"] = {
396+
"cumulative": cumulative_results
397+
- results["posterior_estimation"]["cumulative"],
398+
"mean": mean_results - results["posterior_estimation"]["mean"],
399+
}
400+
401+
# Confidence Intervals
402+
results["ci"] = {
403+
"cumulative": [
404+
np.percentile(_mu_samples_cumulative, ci),
405+
np.percentile(_mu_samples_cumulative, 100 - ci),
406+
],
407+
"mean": [
408+
np.percentile(_mu_samples_mean, ci),
409+
np.percentile(_mu_samples_mean, 100 - ci),
410+
],
411+
}
412+
413+
# Convert to DataFrame
414+
results_df = pd.DataFrame(results)
415+
416+
return results_df
417+
418+
def summary(self, version="coefficients", **kwargs) -> Union[None, pd.DataFrame]:
334419
"""
335420
Print text output summarising the results
336421
"""
422+
if version == "coefficients":
423+
print(f"{self.expt_type:=^80}")
424+
print(f"Formula: {self.formula}")
425+
# TODO: extra experiment specific outputs here
426+
self.print_coefficients()
427+
elif version == "intervention":
428+
return self._summary_intervention(**kwargs)
429+
430+
def _power_estimation(self, alpha: float = 0.05, correction: bool = False) -> Dict:
431+
"""
432+
Estimate the statistical power of an intervention based on cumulative and mean results.
337433
338-
print(f"{self.expt_type:=^80}")
339-
print(f"Formula: {self.formula}")
340-
# TODO: extra experiment specific outputs here
341-
self.print_coefficients()
434+
This function calculates posterior estimates, systematic differences, confidence intervals, and
435+
minimum detectable effects (MDE) for both cumulative and mean measures. It can apply corrections to
436+
account for systematic differences in the data.
437+
438+
Parameters:
439+
- alpha (float, optional): The significance level for confidence interval calculations. Default is 0.05.
440+
- correction (bool, optional): If True, applies corrections to account for systematic differences in
441+
cumulative and mean calculations. Default is False.
442+
443+
Returns:
444+
- Dict: A dictionary containing key statistical measures such as posterior estimation,
445+
systematic differences, confidence intervals, and posterior MDE for both cumulative and mean results.
446+
"""
447+
results = {}
448+
ci = (alpha * 100) / 2
449+
450+
# Cumulative calculations
451+
cumulative_results = self.post_y.sum()
452+
_mu_samples_cumulative = (
453+
self.post_pred["posterior_predictive"]
454+
.mu.stack(sample=("chain", "draw"))
455+
.sum("obs_ind")
456+
)
457+
458+
# Mean calculations
459+
mean_results = self.post_y.mean()
460+
_mu_samples_mean = (
461+
self.post_pred["posterior_predictive"]
462+
.mu.stack(sample=("chain", "draw"))
463+
.mean("obs_ind")
464+
)
465+
466+
# Posterior Mean
467+
results["posterior_estimation"] = {
468+
"cumulative": np.mean(_mu_samples_cumulative.values),
469+
"mean": np.mean(_mu_samples_mean.values),
470+
}
471+
472+
results["results"] = {"cumulative": cumulative_results, "mean": mean_results}
473+
474+
results["_systematic_differences"] = {
475+
"cumulative": results["results"]["cumulative"]
476+
- results["posterior_estimation"]["cumulative"],
477+
"mean": results["results"]["mean"]
478+
- results["posterior_estimation"]["mean"],
479+
}
480+
481+
if correction:
482+
_mu_samples_cumulative += results["_systematic_differences"]["cumulative"]
483+
_mu_samples_mean += results["_systematic_differences"]["mean"]
484+
485+
results["ci"] = {
486+
"cumulative": [
487+
np.percentile(_mu_samples_cumulative, ci),
488+
np.percentile(_mu_samples_cumulative, 100 - ci),
489+
],
490+
"mean": [
491+
np.percentile(_mu_samples_mean, ci),
492+
np.percentile(_mu_samples_mean, 100 - ci),
493+
],
494+
}
495+
496+
cumulative_upper_mde = (
497+
results["ci"]["cumulative"][1]
498+
- results["posterior_estimation"]["cumulative"]
499+
)
500+
cumulative_lower_mde = (
501+
results["posterior_estimation"]["cumulative"]
502+
- results["ci"]["cumulative"][0]
503+
)
504+
505+
mean_upper_mde = (
506+
results["ci"]["mean"][1] - results["posterior_estimation"]["mean"]
507+
)
508+
mean_lower_mde = (
509+
results["posterior_estimation"]["mean"] - results["ci"]["mean"][0]
510+
)
511+
512+
results["posterior_mde"] = {
513+
"cumulative": (cumulative_upper_mde + cumulative_lower_mde) / 2,
514+
"mean": (mean_upper_mde + mean_lower_mde) / 2,
515+
}
516+
return results
517+
518+
def power_summary(
519+
self, alpha: float = 0.05, correction: bool = False
520+
) -> pd.DataFrame:
521+
"""
522+
Summarize the power estimation results in a DataFrame format.
523+
524+
This function calls '_power_estimation' to perform power estimation calculations and then
525+
converts the resulting dictionary into a pandas DataFrame for easier analysis and visualization.
526+
527+
Parameters:
528+
- alpha (float, optional): The significance level for confidence interval calculations used in power estimation. Default is 0.05.
529+
- correction (bool, optional): Indicates whether to apply corrections in the power estimation process. Default is False.
530+
531+
Returns:
532+
- pd.DataFrame: A DataFrame representing the power estimation results, including posterior estimations,
533+
systematic differences, confidence intervals, and posterior MDE for cumulative and mean results.
534+
"""
535+
return pd.DataFrame(self._power_estimation(alpha=alpha, correction=correction))
536+
537+
def power_plot(self, alpha: float = 0.05, correction: bool = False) -> plt.Figure:
538+
"""
539+
Generate and return a figure containing plots that visualize power estimation results.
540+
541+
This function creates a two-panel plot (for mean and cumulative measures) to visualize the posterior distributions
542+
along with the confidence intervals, real mean, and posterior mean values. It allows for adjustments based on
543+
systematic differences if the correction is applied.
544+
545+
Parameters:
546+
- alpha (float, optional): The significance level for confidence interval calculations used in power estimation. Default is 0.05.
547+
- correction (bool, optional): Indicates whether to apply corrections for systematic differences in the plotting process. Default is False.
548+
549+
Returns:
550+
- plt.Figure: A matplotlib figure object containing the plots.
551+
"""
552+
_estimates = self._power_estimation(alpha=alpha, correction=correction)
553+
554+
fig, axs = plt.subplots(1, 2, figsize=(20, 6)) # Two subplots side by side
555+
556+
# Adjustments for Mean and Cumulative plots
557+
for i, key in enumerate(["mean", "cumulative"]):
558+
_mu_samples = self.post_pred["posterior_predictive"].mu.stack(
559+
sample=("chain", "draw")
560+
)
561+
if key == "mean":
562+
_mu_samples = _mu_samples.mean("obs_ind")
563+
elif key == "cumulative":
564+
_mu_samples = _mu_samples.sum("obs_ind")
565+
566+
if correction:
567+
_mu_samples += _estimates["_systematic_differences"][key]
568+
569+
# Histogram and KDE
570+
sns.histplot(
571+
_mu_samples,
572+
bins=30,
573+
kde=True,
574+
ax=axs[i],
575+
color="C0",
576+
stat="density",
577+
alpha=0.6,
578+
)
579+
kde_x, kde_y = (
580+
sns.kdeplot(_mu_samples, color="C1", fill=True, ax=axs[i])
581+
.get_lines()[0]
582+
.get_data()
583+
)
584+
585+
# Adjust y-limits based on max density
586+
max_density = max(kde_y)
587+
axs[i].set_ylim(0, max_density + 0.05 * max_density) # Adding 5% buffer
588+
589+
# Fill between for the percentile interval
590+
axs[i].fill_betweenx(
591+
y=np.linspace(0, max_density + 0.05 * max_density, 100),
592+
x1=_estimates["ci"][key][0],
593+
x2=_estimates["ci"][key][1],
594+
color="C0",
595+
alpha=0.3,
596+
label="C.I",
597+
)
598+
599+
# Vertical lines for the means
600+
axs[i].axvline(
601+
_estimates["results"][key], color="C3", linestyle="-", label="Real Mean"
602+
)
603+
if not correction:
604+
axs[i].axvline(
605+
_estimates["posterior_estimation"][key],
606+
color="C4",
607+
linestyle="--",
608+
label="Posterior Mean",
609+
)
610+
611+
axs[i].set_title(f"Posterior of mu ({key.capitalize()})")
612+
axs[i].set_xlabel("mu")
613+
axs[i].set_ylabel("Density")
614+
axs[i].legend()
615+
616+
return fig
342617

343618

344619
class InterruptedTimeSeries(PrePostFit):

causalpy/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""
22
Utility functions
33
"""
4+
import numpy as np
45
import pandas as pd
6+
from scipy.stats import norm
57

68

79
def _is_variable_dummy_coded(series: pd.Series) -> bool:
@@ -13,3 +15,29 @@ def _is_variable_dummy_coded(series: pd.Series) -> bool:
1315
def _series_has_2_levels(series: pd.Series) -> bool:
1416
"""Check that the variable in the provided Series has 2 levels"""
1517
return len(pd.Categorical(series).categories) == 2
18+
19+
20+
def compute_bayesian_tail_probability(posterior, x) -> float:
21+
"""
22+
Calculate the probability of a given value being in a distribution defined by the posterior,
23+
24+
Args:
25+
- data: a list or array-like object containing the data to define the distribution
26+
- x: a numeric value for which to calculate the probability of being in the distribution
27+
28+
Returns:
29+
- prob: a numeric value representing the probability of x being in the distribution
30+
"""
31+
lower_bound, upper_bound = min(posterior), max(posterior)
32+
mean, std = np.mean(posterior), np.std(posterior)
33+
34+
cdf_lower = norm.cdf(lower_bound, mean, std)
35+
cdf_upper = 1 - norm.cdf(upper_bound, mean, std)
36+
cdf_x = norm.cdf(x, mean, std)
37+
38+
if cdf_x <= 0.5:
39+
probability = 2 * (cdf_x - cdf_lower) / (1 - cdf_lower - cdf_upper)
40+
else:
41+
probability = 2 * (1 - cdf_x + cdf_lower) / (1 - cdf_lower - cdf_upper)
42+
43+
return abs(round(probability, 2))

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)