From 4ebe1a704f245cd51e0e4ecfe0d9d5ba0b375d5b Mon Sep 17 00:00:00 2001 From: Rojan Shrestha Date: Tue, 29 Jul 2025 22:06:05 +0545 Subject: [PATCH 1/7] Added post_treatment_variable_name parameter and sklearn model summary for did --- causalpy/experiments/diff_in_diff.py | 65 +++++++++++++++++------ docs/source/_static/interrogate_badge.svg | 8 +-- 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index 04b62370..2b307f00 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -26,7 +26,6 @@ from causalpy.custom_exceptions import ( DataException, - FormulaException, ) from causalpy.plot_utils import plot_xY from causalpy.pymc_models import PyMCModel @@ -84,6 +83,7 @@ def __init__( formula: str, time_variable_name: str, group_variable_name: str, + post_treatment_variable_name: str = "post_treatment", model=None, **kwargs, ) -> None: @@ -95,6 +95,7 @@ def __init__( self.formula = formula self.time_variable_name = time_variable_name self.group_variable_name = group_variable_name + self.post_treatment_variable_name = post_treatment_variable_name self.input_validation() y, X = dmatrices(formula, self.data) @@ -128,6 +129,12 @@ def __init__( } self.model.fit(X=self.X, y=self.y, coords=COORDS) elif isinstance(self.model, RegressorMixin): + # For scikit-learn models, automatically set fit_intercept=False + # This ensures the intercept is included in the coefficients array rather than being a separate intercept_ attribute + # without this, the intercept is not included in the coefficients array hence would be displayed as 0 in the model summary + # TODO: later, this should be handled in ScikitLearnAdaptor itself + if hasattr(self.model, "fit_intercept"): + self.model.fit_intercept = False self.model.fit(X=self.X, y=self.y) else: raise ValueError("Model type not recognized") @@ -173,7 +180,7 @@ def __init__( # just the treated group .query(f"{self.group_variable_name} == 1") # just the treatment period(s) - .query("post_treatment == True") + .query(f"{self.post_treatment_variable_name} == True") # drop the outcome variable .drop(self.outcome_variable_name, axis=1) # We may have multiple units per time point, we only want one time point @@ -189,7 +196,10 @@ def __init__( # INTERVENTION: set the interaction term between the group and the # post_treatment variable to zero. This is the counterfactual. for i, label in enumerate(self.labels): - if "post_treatment" in label and self.group_variable_name in label: + if ( + self.post_treatment_variable_name in label + and self.group_variable_name in label + ): new_x.iloc[:, i] = 0 self.y_pred_counterfactual = self.model.predict(np.asarray(new_x)) @@ -198,16 +208,24 @@ def __init__( # This is the coefficient on the interaction term coeff_names = self.model.idata.posterior.coords["coeffs"].data for i, label in enumerate(coeff_names): - if "post_treatment" in label and self.group_variable_name in label: + if ( + self.post_treatment_variable_name in label + and self.group_variable_name in label + ): self.causal_impact = self.model.idata.posterior["beta"].isel( {"coeffs": i} ) elif isinstance(self.model, RegressorMixin): # This is the coefficient on the interaction term - # TODO: CHECK FOR CORRECTNESS - self.causal_impact = ( - self.y_pred_treatment[1] - self.y_pred_counterfactual[0] - ).item() + # Store the coefficient into dictionary {intercept:value} + coef_map = dict(zip(self.labels, self.model.get_coeffs())) + # Create and find the interaction term based on the values user provided + interaction_term = ( + f"{self.group_variable_name}:{self.post_treatment_variable_name}" + ) + matched_key = next((k for k in coef_map if interaction_term in k), None) + att = coef_map.get(matched_key) + self.causal_impact = att else: raise ValueError("Model type not recognized") @@ -215,15 +233,28 @@ def __init__( def input_validation(self): """Validate the input data and model formula for correctness""" - if "post_treatment" not in self.formula: - raise FormulaException( - "A predictor called `post_treatment` should be in the formula" - ) - - if "post_treatment" not in self.data.columns: - raise DataException( - "Require a boolean column labelling observations which are `treated`" - ) + if ( + self.post_treatment_variable_name not in self.formula + or self.post_treatment_variable_name not in self.data.columns + ): + if self.post_treatment_variable_name == "post_treatment": + # Default case - user didn't specify custom name, so guide them to use "post_treatment" + raise DataException( + "Missing 'post_treatment' in formula or dataset.\n" + "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" + "1) Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment')\n" + "2) and ensure dataset has boolean column 'post_treatment'.\n" + "To use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." + ) + else: + # Custom case - user specified custom name, so remind them what they specified + raise DataException( + f"Missing required variable '{self.post_treatment_variable_name}' in formula or dataset.\n\n" + f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " + f"please ensure:\n" + f"1) formula includes '{self.post_treatment_variable_name}'\n" + f"2) dataset has boolean column named '{self.post_treatment_variable_name}'" + ) if "unit" not in self.data.columns: raise DataException( diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 4704ef6c..3e6a538d 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,10 +1,10 @@ - interrogate: 95.5% + interrogate: 93.6% - + @@ -12,8 +12,8 @@ interrogate interrogate - 95.5% - 95.5% + 93.6% + 93.6% From 7fbb27a3d3f5214768b043f10d1504a92a0edce0 Mon Sep 17 00:00:00 2001 From: Rojan Shrestha Date: Wed, 30 Jul 2025 11:20:51 +0545 Subject: [PATCH 2/7] Refactor DiD validation: segregate FormulaException and DataException --- causalpy/experiments/diff_in_diff.py | 38 +++++++++++++++++++--------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index 2b307f00..a093359a 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -26,6 +26,7 @@ from causalpy.custom_exceptions import ( DataException, + FormulaException, ) from causalpy.plot_utils import plot_xY from causalpy.pymc_models import PyMCModel @@ -233,27 +234,40 @@ def __init__( def input_validation(self): """Validate the input data and model formula for correctness""" - if ( - self.post_treatment_variable_name not in self.formula - or self.post_treatment_variable_name not in self.data.columns - ): + # Check if post_treatment_variable_name is in formula + if self.post_treatment_variable_name not in self.formula: + if self.post_treatment_variable_name == "post_treatment": + # Default case - user didn't specify custom name, so guide them to use "post_treatment" + raise FormulaException( + "Missing 'post_treatment' in formula.\n" + "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" + "Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment').\n" + "Or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." + ) + else: + # Custom case - user specified custom name, so remind them what they specified + raise FormulaException( + f"Missing required variable '{self.post_treatment_variable_name}' in formula.\n\n" + f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " + f"please ensure formula includes '{self.post_treatment_variable_name}'" + ) + + # Check if post_treatment_variable_name is in data columns + if self.post_treatment_variable_name not in self.data.columns: if self.post_treatment_variable_name == "post_treatment": # Default case - user didn't specify custom name, so guide them to use "post_treatment" raise DataException( - "Missing 'post_treatment' in formula or dataset.\n" + "Missing 'post_treatment' column in dataset.\n" "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" - "1) Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment')\n" - "2) and ensure dataset has boolean column 'post_treatment'.\n" - "To use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." + "Ensure dataset has boolean column 'post_treatment'.\n" + "or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." ) else: # Custom case - user specified custom name, so remind them what they specified raise DataException( - f"Missing required variable '{self.post_treatment_variable_name}' in formula or dataset.\n\n" + f"Missing required column '{self.post_treatment_variable_name}' in dataset.\n\n" f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " - f"please ensure:\n" - f"1) formula includes '{self.post_treatment_variable_name}'\n" - f"2) dataset has boolean column named '{self.post_treatment_variable_name}'" + f"please ensure dataset has boolean column named '{self.post_treatment_variable_name}'" ) if "unit" not in self.data.columns: From c232d89411d82fb9883f3c2e134a780431b55f28 Mon Sep 17 00:00:00 2001 From: Rojan Shrestha Date: Tue, 5 Aug 2025 00:31:52 +0545 Subject: [PATCH 3/7] added validations for interactions, test coverage expanded to test interaction terms,more generic messages --- causalpy/experiments/diff_in_diff.py | 96 +++++++++++------ causalpy/tests/test_input_validation.py | 119 +++++++++++++++++++++- docs/source/_static/interrogate_badge.svg | 6 +- 3 files changed, 186 insertions(+), 35 deletions(-) diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index a093359a..132cd2ae 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -15,6 +15,8 @@ Difference in differences """ +import re + import arviz as az import numpy as np import pandas as pd @@ -233,42 +235,21 @@ def __init__( return def input_validation(self): + # Validate formula structure and interaction interaction terms + self._validate_formula_interaction_terms() + """Validate the input data and model formula for correctness""" # Check if post_treatment_variable_name is in formula if self.post_treatment_variable_name not in self.formula: - if self.post_treatment_variable_name == "post_treatment": - # Default case - user didn't specify custom name, so guide them to use "post_treatment" - raise FormulaException( - "Missing 'post_treatment' in formula.\n" - "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" - "Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment').\n" - "Or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." - ) - else: - # Custom case - user specified custom name, so remind them what they specified - raise FormulaException( - f"Missing required variable '{self.post_treatment_variable_name}' in formula.\n\n" - f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " - f"please ensure formula includes '{self.post_treatment_variable_name}'" - ) + raise FormulaException( + f"Missing required variable '{self.post_treatment_variable_name}' in formula" + ) # Check if post_treatment_variable_name is in data columns if self.post_treatment_variable_name not in self.data.columns: - if self.post_treatment_variable_name == "post_treatment": - # Default case - user didn't specify custom name, so guide them to use "post_treatment" - raise DataException( - "Missing 'post_treatment' column in dataset.\n" - "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" - "Ensure dataset has boolean column 'post_treatment'.\n" - "or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." - ) - else: - # Custom case - user specified custom name, so remind them what they specified - raise DataException( - f"Missing required column '{self.post_treatment_variable_name}' in dataset.\n\n" - f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " - f"please ensure dataset has boolean column named '{self.post_treatment_variable_name}'" - ) + raise DataException( + f"Missing required column '{self.post_treatment_variable_name}' in dataset" + ) if "unit" not in self.data.columns: raise DataException( @@ -281,6 +262,61 @@ def input_validation(self): coded. Consisting of 0's and 1's only.""" ) + def _get_interaction_terms(self): + """ + Extract interaction terms from the formula. + Returns a list of interaction terms (those with '*' or ':'). + """ + # Define interaction indicators + INTERACTION_INDICATORS = ["*", ":"] + + # Remove whitespace + formula = self.formula.replace(" ", "") + + # Extract right-hand side of the formula + rhs = formula.split("~")[1] + + # Split terms by '+' or '-' while keeping them intact + terms = re.split(r"(?=[+-])", rhs) + + # Clean up terms and get interaction terms (those with '*' or ':') + interaction_terms = [] + for term in terms: + # Remove leading + or - for processing + clean_term = term.lstrip("+-") + if any(indicator in clean_term for indicator in INTERACTION_INDICATORS): + interaction_terms.append(clean_term) + + return interaction_terms + + def _validate_formula_interaction_terms(self): + """ + Validate that the formula contains at most one interaction term and no three-way or higher-order interactions. + Raises FormulaException if more than one interaction term is found or if any interaction term has more than 2 variables. + """ + # Define interaction indicators + INTERACTION_INDICATORS = ["*", ":"] + + # Get interaction terms + interaction_terms = self._get_interaction_terms() + + # Check for interaction terms with more than 2 variables (more than one '*' or ':') + for term in interaction_terms: + total_indicators = sum( + term.count(indicator) for indicator in INTERACTION_INDICATORS + ) + if ( + total_indicators >= 2 + ): # 3 or more variables (e.g., a*b*c or a:b:c has 2 symbols) + raise FormulaException( + f"Formula contains interaction term with more than 2 variables: {term}. Only two-way interactions are allowed." + ) + + if len(interaction_terms) > 1: + raise FormulaException( + f"Formula contains more than 1 interaction term: {interaction_terms}. Maximum of 1 allowed." + ) + def summary(self, round_to=None) -> None: """Print summary of main results and model coefficients. diff --git a/causalpy/tests/test_input_validation.py b/causalpy/tests/test_input_validation.py index 43fd9208..69ca3753 100644 --- a/causalpy/tests/test_input_validation.py +++ b/causalpy/tests/test_input_validation.py @@ -30,18 +30,29 @@ def test_did_validation_post_treatment_formula(): - """Test that we get a FormulaException if do not include post_treatment in the - formula""" + """Test that we get a FormulaException for invalid formulas and missing post_treatment variables""" df = pd.DataFrame( { "group": [0, 0, 1, 1], "t": [0, 1, 0, 1], "unit": [0, 0, 1, 1], "post_treatment": [0, 1, 0, 1], + "male": [0, 1, 0, 1], # Additional variable for testing "y": [1, 2, 3, 4], } ) + df_with_custom = pd.DataFrame( + { + "group": [0, 0, 1, 1], + "t": [0, 1, 0, 1], + "unit": [0, 0, 1, 1], + "custom_post": [0, 1, 0, 1], # Custom column name + "y": [1, 2, 3, 4], + } + ) + + # Test 1: Missing post_treatment variable in formula with pytest.raises(FormulaException): _ = cp.DifferenceInDifferences( df, @@ -51,6 +62,7 @@ def test_did_validation_post_treatment_formula(): model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), ) + # Test 2: Missing post_treatment variable in formula (duplicate test) with pytest.raises(FormulaException): _ = cp.DifferenceInDifferences( df, @@ -60,6 +72,88 @@ def test_did_validation_post_treatment_formula(): model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), ) + # Test 3: Custom post_treatment_variable_name but formula uses different name + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df_with_custom, + formula="y ~ 1 + group*post_treatment", # Formula uses 'post_treatment' + time_variable_name="t", + group_variable_name="group", + post_treatment_variable_name="custom_post", # But user specifies 'custom_post' + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 4: Default post_treatment_variable_name but formula uses different name + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group*custom_post", # Formula uses 'custom_post' + time_variable_name="t", + group_variable_name="group", + # post_treatment_variable_name defaults to "post_treatment" + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 5: Repeated interaction terms (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group*post_treatment + group*post_treatment", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 6: Three-way interactions using * (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group*post_treatment*male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 7: Three-way interactions using : (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group:post_treatment:male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 8: Multiple different interaction terms using * (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group*post_treatment + group*male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 9: Multiple different interaction terms using : (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group:post_treatment + group:male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 10: Mixed issues - multiple terms + three-way interaction (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group*post_treatment + group:post_treatment:male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + def test_did_validation_post_treatment_data(): """Test that we get a DataException if do not include post_treatment in the data""" @@ -91,6 +185,27 @@ def test_did_validation_post_treatment_data(): model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), ) + # Test 2: Custom post_treatment_variable_name but column doesn't exist in data + df_with_post = pd.DataFrame( + { + "group": [0, 0, 1, 1], + "t": [0, 1, 0, 1], + "unit": [0, 0, 1, 1], + "post_treatment": [0, 1, 0, 1], # Data has 'post_treatment' + "y": [1, 2, 3, 4], + } + ) + + with pytest.raises(DataException): + _ = cp.DifferenceInDifferences( + df_with_post, + formula="y ~ 1 + group*custom_post", # Formula uses 'custom_post' + time_variable_name="t", + group_variable_name="group", + post_treatment_variable_name="custom_post", # User specifies 'custom_post' + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + def test_did_validation_unit_data(): """Test that we get a DataException if do not include unit in the data""" diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 3e6a538d..08c36d5e 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 93.6% + interrogate: 92.6% @@ -12,8 +12,8 @@ interrogate interrogate - 93.6% - 93.6% + 92.6% + 92.6% From e222e9bbe4a7e2c0fd2407d4013d7ca050d77289 Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Thu, 28 Aug 2025 17:17:31 +0100 Subject: [PATCH 4/7] get pre-commit checks to pass --- docs/source/_static/interrogate_badge.svg | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 08c36d5e..aa85b1ad 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,10 +1,10 @@ - interrogate: 92.6% + interrogate: 95.1% - + @@ -12,8 +12,8 @@ interrogate interrogate - 92.6% - 92.6% + 95.1% + 95.1% From aeca7071b24c7f7a2d886f3bf31923cb00f576d9 Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Mon, 27 Oct 2025 16:07:24 +0000 Subject: [PATCH 5/7] Refactor interaction term extraction in DiD and utils Moved the interaction term extraction logic from DifferenceInDifferences to a new get_interaction_terms utility function. Updated relevant imports and tests to use the new function, improving code reuse and maintainability. --- causalpy/experiments/diff_in_diff.py | 40 +++++-------------- causalpy/tests/test_utils.py | 47 ++++++++++++++++++++++- causalpy/utils.py | 47 +++++++++++++++++++++++ docs/source/_static/interrogate_badge.svg | 6 +-- 4 files changed, 105 insertions(+), 35 deletions(-) diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index 132cd2ae..0d421eed 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -15,8 +15,6 @@ Difference in differences """ -import re - import arviz as az import numpy as np import pandas as pd @@ -32,7 +30,12 @@ ) from causalpy.plot_utils import plot_xY from causalpy.pymc_models import PyMCModel -from causalpy.utils import _is_variable_dummy_coded, convert_to_string, round_num +from causalpy.utils import ( + _is_variable_dummy_coded, + convert_to_string, + get_interaction_terms, + round_num, +) from .base import BaseExperiment @@ -54,6 +57,8 @@ class DifferenceInDifferences(BaseExperiment): Name of the data column for the time variable :param group_variable_name: Name of the data column for the group variable + :param post_treatment_variable_name: + Name of the data column indicating post-treatment period (default: "post_treatment") :param model: A PyMC model for difference in differences @@ -262,33 +267,6 @@ def input_validation(self): coded. Consisting of 0's and 1's only.""" ) - def _get_interaction_terms(self): - """ - Extract interaction terms from the formula. - Returns a list of interaction terms (those with '*' or ':'). - """ - # Define interaction indicators - INTERACTION_INDICATORS = ["*", ":"] - - # Remove whitespace - formula = self.formula.replace(" ", "") - - # Extract right-hand side of the formula - rhs = formula.split("~")[1] - - # Split terms by '+' or '-' while keeping them intact - terms = re.split(r"(?=[+-])", rhs) - - # Clean up terms and get interaction terms (those with '*' or ':') - interaction_terms = [] - for term in terms: - # Remove leading + or - for processing - clean_term = term.lstrip("+-") - if any(indicator in clean_term for indicator in INTERACTION_INDICATORS): - interaction_terms.append(clean_term) - - return interaction_terms - def _validate_formula_interaction_terms(self): """ Validate that the formula contains at most one interaction term and no three-way or higher-order interactions. @@ -298,7 +276,7 @@ def _validate_formula_interaction_terms(self): INTERACTION_INDICATORS = ["*", ":"] # Get interaction terms - interaction_terms = self._get_interaction_terms() + interaction_terms = get_interaction_terms(self.formula) # Check for interaction terms with more than 2 variables (more than one '*' or ':') for term in interaction_terms: diff --git a/causalpy/tests/test_utils.py b/causalpy/tests/test_utils.py index 8dc95590..a2dea702 100644 --- a/causalpy/tests/test_utils.py +++ b/causalpy/tests/test_utils.py @@ -17,7 +17,12 @@ import pandas as pd -from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels, round_num +from causalpy.utils import ( + _is_variable_dummy_coded, + _series_has_2_levels, + get_interaction_terms, + round_num, +) def test_dummy_coding(): @@ -57,3 +62,43 @@ def test_round_num(): assert round_num(123.456, 5) == "123.46" assert round_num(123.456, 6) == "123.456" assert round_num(123.456, 7) == "123.456" + + +def test_get_interaction_terms(): + """Test if the function to extract interaction terms from formulas works correctly""" + # No interaction terms + assert get_interaction_terms("y ~ x1 + x2 + x3") == [] + assert get_interaction_terms("y ~ 1 + x1 + x2") == [] + + # Single interaction term with '*' + assert get_interaction_terms("y ~ x1 + x2*x3") == ["x2*x3"] + assert get_interaction_terms("y ~ 1 + group*post_treatment") == [ + "group*post_treatment" + ] + + # Single interaction term with ':' + assert get_interaction_terms("y ~ x1 + x2:x3") == ["x2:x3"] + assert get_interaction_terms("y ~ 1 + group:post_treatment") == [ + "group:post_treatment" + ] + + # Multiple interaction terms + assert get_interaction_terms("y ~ x1*x2 + x3*x4") == ["x1*x2", "x3*x4"] + assert get_interaction_terms("y ~ a:b + c*d") == ["a:b", "c*d"] + + # Three-way interaction + assert get_interaction_terms("y ~ x1*x2*x3") == ["x1*x2*x3"] + assert get_interaction_terms("y ~ a:b:c") == ["a:b:c"] + + # Formula with spaces (should be handled correctly) + assert get_interaction_terms("y ~ x1 + x2 * x3") == ["x2*x3"] + assert get_interaction_terms("y ~ 1 + group * post_treatment") == [ + "group*post_treatment" + ] + + # Mixed main effects and interactions + assert get_interaction_terms("y ~ 1 + x1 + x2 + x1*x2") == ["x1*x2"] + assert get_interaction_terms("y ~ x1 + x2*x3 + x4") == ["x2*x3"] + + # Formula with subtraction (edge case) + assert get_interaction_terms("y ~ x1*x2 - x3") == ["x1*x2"] diff --git a/causalpy/utils.py b/causalpy/utils.py index c64eb109..5b7c601b 100644 --- a/causalpy/utils.py +++ b/causalpy/utils.py @@ -15,6 +15,7 @@ Utility functions """ +import re from typing import Union import numpy as np @@ -84,3 +85,49 @@ def convert_to_string(x: Union[float, xr.DataArray], round_to: int = 2) -> str: raise ValueError( "Type not supported. Please provide a float or an xarray object." ) + + +def get_interaction_terms(formula: str) -> list[str]: + """ + Extract interaction terms from a statistical model formula. + + Parameters + ---------- + formula : str + A statistical model formula string (e.g., "y ~ x1 + x2*x3") + + Returns + ------- + list[str] + A list of interaction terms (those containing '*' or ':') + + Examples + -------- + >>> get_interaction_terms("y ~ 1 + x1 + x2*x3") + ['x2*x3'] + >>> get_interaction_terms("y ~ x1:x2 + x3") + ['x1:x2'] + >>> get_interaction_terms("y ~ x1 + x2 + x3") + [] + """ + # Define interaction indicators + INTERACTION_INDICATORS = ["*", ":"] + + # Remove whitespace + formula_clean = formula.replace(" ", "") + + # Extract right-hand side of the formula + rhs = formula_clean.split("~")[1] + + # Split terms by '+' or '-' while keeping them intact + terms = re.split(r"(?=[+-])", rhs) + + # Clean up terms and get interaction terms (those with '*' or ':') + interaction_terms = [] + for term in terms: + # Remove leading + or - for processing + clean_term = term.lstrip("+-") + if any(indicator in clean_term for indicator in INTERACTION_INDICATORS): + interaction_terms.append(clean_term) + + return interaction_terms diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 392a876b..8734d55d 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 96.2% + interrogate: 95.8% @@ -12,8 +12,8 @@ interrogate interrogate - 96.2% - 96.2% + 95.8% + 95.8% From 7891dd888c76777a1e99eafeffe02bd5979bc168 Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Mon, 27 Oct 2025 16:12:34 +0000 Subject: [PATCH 6/7] update exception message when we detect more than one interaction term --- causalpy/experiments/diff_in_diff.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index 0d421eed..fde96463 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -292,7 +292,7 @@ def _validate_formula_interaction_terms(self): if len(interaction_terms) > 1: raise FormulaException( - f"Formula contains more than 1 interaction term: {interaction_terms}. Maximum of 1 allowed." + f"Formula contains {len(interaction_terms)} interaction terms: {interaction_terms}. Multiple interaction terms are not currently supported." ) def summary(self, round_to=None) -> None: From ea2ed69dbe2ae1b6fa743c306f066beec896e436 Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Tue, 28 Oct 2025 14:10:35 +0000 Subject: [PATCH 7/7] updates to FormulaException wording --- causalpy/experiments/diff_in_diff.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index fde96463..7a453a5d 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -287,12 +287,14 @@ def _validate_formula_interaction_terms(self): total_indicators >= 2 ): # 3 or more variables (e.g., a*b*c or a:b:c has 2 symbols) raise FormulaException( - f"Formula contains interaction term with more than 2 variables: {term}. Only two-way interactions are allowed." + f"Formula contains interaction term with more than 2 variables: {term}. " + "Three-way or higher-order interactions are not supported as they complicate interpretation of the causal effect." ) if len(interaction_terms) > 1: raise FormulaException( - f"Formula contains {len(interaction_terms)} interaction terms: {interaction_terms}. Multiple interaction terms are not currently supported." + f"Formula contains {len(interaction_terms)} interaction terms: {interaction_terms}. " + "Multiple interaction terms are not currently supported as they complicate interpretation of the causal effect." ) def summary(self, round_to=None) -> None: