Skip to content

Commit

Permalink
Merge pull request #272 from pymc-labs/round_to
Browse files Browse the repository at this point in the history
User specified number of significant figures for numbers in plots
  • Loading branch information
drbenvincent authored Dec 22, 2023
2 parents fc28a3b + 198bde6 commit 70de921
Show file tree
Hide file tree
Showing 14 changed files with 569 additions and 250 deletions.
94 changes: 64 additions & 30 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
from patsy import build_design_matrices, dmatrices
from sklearn.linear_model import LinearRegression as sk_lin_reg

from causalpy.custom_exceptions import BadIndexException
from causalpy.custom_exceptions import DataException, FormulaException
from causalpy.custom_exceptions import (
BadIndexException, # NOQA
DataException,
FormulaException,
)
from causalpy.plot_utils import plot_xY
from causalpy.utils import _is_variable_dummy_coded
from causalpy.utils import _is_variable_dummy_coded, round_num

LEGEND_FONT_SIZE = 12
az.style.use("arviz-darkgrid")
Expand Down Expand Up @@ -228,9 +231,12 @@ def _input_validation(self, data, treatment_time):
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
)

def plot(self, counterfactual_label="Counterfactual", **kwargs):
def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
"""
Plot the results
:param round_to:
Number of significant figures used to format the numbers in the plot title. Defaults to 2 when None; the integer part of a number is never truncated.
"""
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))

Expand Down Expand Up @@ -275,8 +281,8 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):

ax[0].set(
title=f"""
Pre-intervention Bayesian $R^2$: {self.score.r2:.3f}
(std = {self.score.r2_std:.3f})
Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)}
(std = {round_num(self.score.r2_std, round_to)})
"""
)

Expand Down Expand Up @@ -416,7 +422,11 @@ class SyntheticControl(PrePostFit):
expt_type = "Synthetic Control"

def plot(self, plot_predictors=False, **kwargs):
"""Plot the results"""
"""Plot the results
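:param round_to:
Number of significant figures used to format the numbers in the plot title (forwarded to ``PrePostFit.plot`` through ``**kwargs``). Defaults to 2 when None; the integer part of a number is never truncated.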
"""
fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
if plot_predictors:
# plot control units as well
Expand Down Expand Up @@ -580,9 +590,11 @@ def _input_validation(self):
coded. Consisting of 0's and 1's only."""
)

def plot(self):
def plot(self, round_to=None):
"""Plot the results.
Creating the combined mean + HDI legend entries is a bit involved.
:param round_to:
Number of significant figures used to format the numbers in the plot title. Defaults to 2 when None; the integer part of a number is never truncated.
"""
fig, ax = plt.subplots()

Expand Down Expand Up @@ -658,7 +670,7 @@ def plot(self):
# formatting
ax.set(
xticks=self.x_pred_treatment[self.time_variable_name].values,
title=self._causal_impact_summary_stat(),
title=self._causal_impact_summary_stat(round_to),
)
ax.legend(
handles=(h_tuple for h_tuple in handles),
Expand Down Expand Up @@ -711,11 +723,14 @@ def _plot_causal_impact_arrow(self, ax):
va="center",
)

def _causal_impact_summary_stat(self) -> str:
def _causal_impact_summary_stat(self, round_to=None) -> str:
"""Computes the mean and 94% credible interval bounds for the causal impact."""
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
ci = "$CI_{94\\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
causal_impact = f"{self.causal_impact.mean():.2f}, "
ci = (
"$CI_{94\\%}$"
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
)
causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
return f"Causal impact = {causal_impact + ci}"

def summary(self) -> None:
Expand Down Expand Up @@ -893,9 +908,12 @@ def _is_treated(self, x):
"""
return np.greater_equal(x, self.treatment_threshold)

def plot(self):
def plot(self, round_to=None):
"""
Plot the results
:param round_to:
Number of significant figures used to format the numbers in the plot title. Defaults to 2 when None; the integer part of a number is never truncated.
"""
fig, ax = plt.subplots()
# Plot raw data
Expand All @@ -918,12 +936,15 @@ def plot(self):
labels = ["Posterior mean"]

# create strings to compose title
title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
title_info = f"{round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
r2 = f"Bayesian $R^2$ on all data = {title_info}"
percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
ci = (
r"$CI_{94\%}$"
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
)
discon = f"""
Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f},
Discontinuity at threshold = {round_num(self.discontinuity_at_threshold.mean(), round_to)},
"""
ax.set(title=r2 + "\n" + discon + ci)
# Intervention line
Expand Down Expand Up @@ -1104,9 +1125,12 @@ def _is_treated(self, x):
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
return np.greater_equal(x, self.kink_point)

def plot(self):
def plot(self, round_to=None):
"""
Plot the results
:param round_to:
Number of significant figures used to format the numbers in the plot title. Defaults to 2 when None; the integer part of a number is never truncated.
"""
fig, ax = plt.subplots()
# Plot raw data
Expand All @@ -1129,12 +1153,15 @@ def plot(self):
labels = ["Posterior mean"]

# create strings to compose title
title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
title_info = f"{round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
r2 = f"Bayesian $R^2$ on all data = {title_info}"
percentiles = self.gradient_change.quantile([0.03, 1 - 0.03]).values
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
ci = (
r"$CI_{94\%}$"
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
)
grad_change = f"""
Change in gradient = {self.gradient_change.mean():.2f},
Change in gradient = {round_num(self.gradient_change.mean(), round_to)},
"""
ax.set(title=r2 + "\n" + grad_change + ci)
# Intervention line
Expand Down Expand Up @@ -1210,9 +1237,9 @@ class PrePostNEGD(ExperimentalDesign):
Formula: post ~ 1 + C(group) + pre
<BLANKLINE>
Results:
Causal impact = 1.8, $CI_{94%}$[1.6, 2.0]
Causal impact = 1.8, $CI_{94%}$[1.7, 2.1]
Model coefficients:
Intercept -0.4, 94% HDI [-1.2, 0.2]
Intercept -0.4, 94% HDI [-1.1, 0.2]
C(group)[T.1] 1.8, 94% HDI [1.6, 2.0]
pre 1.0, 94% HDI [0.9, 1.1]
sigma 0.5, 94% HDI [0.4, 0.5]
Expand Down Expand Up @@ -1292,8 +1319,12 @@ def _input_validation(self) -> None:
"""
)

def plot(self):
"""Plot the results"""
def plot(self, round_to=None):
"""Plot the results
:param round_to:
Number of significant figures used to format the numbers in the posterior plot (forwarded to ``az.plot_posterior``). Defaults to 2 when None.
"""
fig, ax = plt.subplots(
2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}
)
Expand Down Expand Up @@ -1339,18 +1370,21 @@ def plot(self):
)

# Plot estimated causal impact / treatment effect
az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1])
az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1], round_to=round_to)
ax[1].set(title="Estimated treatment effect")
return fig, ax

def _causal_impact_summary_stat(self, round_to=None) -> str:
    """Return a one-line summary of the causal impact.

    The summary holds the mean of ``self.causal_impact`` followed by the
    bounds of its 94% credible interval (the 3% and 97% quantiles).

    :param round_to:
        Number of significant figures passed to ``round_num`` for every
        number in the summary. Optional, so callers that pass nothing
        keep working.
    :returns:
        A string of the form
        ``"Causal impact = <mean>, $CI_{94%}$[<lower>, <upper>]"``.
    """
    percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
    ci = (
        r"$CI_{94%}$"
        + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
    )
    # Format the mean with round_num too, so it honours ``round_to`` exactly
    # like the interval bounds instead of a hard-coded ``.2f``.
    causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
    return f"Causal impact = {causal_impact + ci}"

def summary(self) -> None:
def summary(self, round_to=None) -> None:
"""
Print text output summarising the results
"""
Expand All @@ -1359,7 +1393,7 @@ def summary(self) -> None:
print(f"Formula: {self.formula}")
print("\nResults:")
# TODO: extra experiment specific outputs here
print(self._causal_impact_summary_stat())
print(self._causal_impact_summary_stat(round_to))
self.print_coefficients()

def _get_treatment_effect_coeff(self) -> str:
Expand Down
48 changes: 35 additions & 13 deletions causalpy/skl_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import seaborn as sns
from patsy import build_design_matrices, dmatrices

from causalpy.utils import round_num

LEGEND_FONT_SIZE = 12


Expand Down Expand Up @@ -113,8 +115,12 @@ def __init__(
# cumulative impact post
self.post_impact_cumulative = np.cumsum(self.post_impact)

def plot(self, counterfactual_label="Counterfactual", **kwargs):
"""Plot experiment results"""
def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
"""Plot experiment results
:param round_to:
Number of significant figures used to format the numbers in the plot title. Defaults to 2 when None; the integer part of a number is never truncated.
"""
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))

ax[0].plot(self.datapre.index, self.pre_y, "k.")
Expand All @@ -128,7 +134,9 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
ls=":",
c="k",
)
ax[0].set(title=f"$R^2$ on pre-intervention data = {self.score:.3f}")
ax[0].set(
title=f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}"
)

ax[1].plot(self.datapre.index, self.pre_impact, "k.")
ax[1].plot(
Expand Down Expand Up @@ -258,9 +266,15 @@ class SyntheticControl(PrePostFit):
... )
"""

def plot(self, plot_predictors=False, **kwargs):
"""Plot the results"""
fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
def plot(self, plot_predictors=False, round_to=None, **kwargs):
"""Plot the results
:param round_to:
Number of significant figures used to format the numbers in the plot title (forwarded to ``PrePostFit.plot``). Defaults to 2 when None; the integer part of a number is never truncated.
"""
fig, ax = super().plot(
counterfactual_label="Synthetic control", round_to=round_to, **kwargs
)
if plot_predictors:
# plot control units as well
ax[0].plot(self.datapre.index, self.pre_X, "-", c=[0.8, 0.8, 0.8], zorder=1)
Expand Down Expand Up @@ -397,8 +411,12 @@ def __init__(
# TODO: THIS IS NOT YET CORRECT
self.causal_impact = self.y_pred_treatment[1] - self.y_pred_counterfactual[0]

def plot(self):
"""Plot results"""
def plot(self, round_to=None):
"""Plot results
:param round_to:
Number of significant figures used to format the numbers in the plot title. Defaults to 2 when None; the integer part of a number is never truncated.
"""
fig, ax = plt.subplots()

# Plot raw data
Expand Down Expand Up @@ -462,7 +480,7 @@ def plot(self):
xlim=[-0.05, 1.1],
xticks=[0, 1],
xticklabels=["pre", "post"],
title=f"Causal impact = {self.causal_impact[0]:.2f}",
title=f"Causal impact = {round_num(self.causal_impact[0], round_to)}",
)
ax.legend(fontsize=LEGEND_FONT_SIZE)
return (fig, ax)
Expand Down Expand Up @@ -607,8 +625,12 @@ def _is_treated(self, x):
"""
return np.greater_equal(x, self.treatment_threshold)

def plot(self):
"""Plot results"""
def plot(self, round_to=None):
"""Plot results
:param round_to:
Number of significant figures used to format the numbers in the plot title. Defaults to 2 when None; the integer part of a number is never truncated.
"""
fig, ax = plt.subplots()
# Plot raw data
sns.scatterplot(
Expand All @@ -627,8 +649,8 @@ def plot(self):
label="model fit",
)
# create strings to compose title
r2 = f"$R^2$ on all data = {self.score:.3f}"
discon = f"Discontinuity at threshold = {self.discontinuity_at_threshold:.2f}"
r2 = f"$R^2$ on all data = {round_num(self.score, round_to)}"
discon = f"Discontinuity at threshold = {round_num(self.discontinuity_at_threshold, round_to)}"
ax.set(title=r2 + "\n" + discon)
# Intervention line
ax.axvline(
Expand Down
22 changes: 21 additions & 1 deletion causalpy/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pandas as pd

from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels
from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels, round_num


def test_dummy_coding():
Expand All @@ -24,3 +24,23 @@ def test_2_level_series():
assert _series_has_2_levels(pd.Series(["water", "tea", "coffee"])) is False
assert _series_has_2_levels(pd.Series([0, 1, 0, 1])) is True
assert _series_has_2_levels(pd.Series([0, 1, 0, 2])) is False


def test_round_num():
    """Check the strings ``round_num`` returns for a range of ``round_to`` values."""
    # Each row is (number, round_to, expected string).
    expectations = [
        # A value below 1: ``None`` and 0 both fall back to a short default.
        (0.12345, None, "0.12"),
        (0.12345, 0, "0.1"),
        (0.12345, 1, "0.1"),
        (0.12345, 2, "0.12"),
        (0.12345, 3, "0.123"),
        (0.12345, 4, "0.1235"),
        (0.12345, 5, "0.12345"),
        # Asking for more figures than the number has adds no padding.
        (0.12345, 6, "0.12345"),
        # A value above 1: the integer part is kept even for small ``round_to``.
        (123.456, None, "123"),
        (123.456, 1, "123"),
        (123.456, 2, "123"),
        (123.456, 3, "123"),
        (123.456, 4, "123.5"),
        (123.456, 5, "123.46"),
        (123.456, 6, "123.456"),
        (123.456, 7, "123.456"),
    ]
    for number, round_to, expected in expectations:
        assert round_num(number, round_to) == expected
Loading

0 comments on commit 70de921

Please sign in to comment.