From 4fd0a92bd2fabbacfe6f225ea9637d3e8f08407e Mon Sep 17 00:00:00 2001
From: Amit Sharma
Date: Mon, 27 Nov 2023 14:40:31 +0530
Subject: [PATCH] auto identify the effect modifier columns for `effect` method
 for EconML estimators (#1061)

* auto identify the effect modifier columns

Signed-off-by: Amit Sharma

* fixed formatting errors

Signed-off-by: Amit Sharma

---------

Signed-off-by: Amit Sharma
---
 dowhy/causal_estimators/econml.py | 11 ++--
 .../test_econml_estimator.py      | 57 +++++++++++++++++--
 2 files changed, 59 insertions(+), 9 deletions(-)

diff --git a/dowhy/causal_estimators/econml.py b/dowhy/causal_estimators/econml.py
index a98c178306..819d3211cb 100755
--- a/dowhy/causal_estimators/econml.py
+++ b/dowhy/causal_estimators/econml.py
@@ -245,7 +245,6 @@ def estimate_effect(
         # Changing shape to a list for a singleton value
         # Note that self._control_value is assumed to be a singleton value
         self._treatment_value = parse_state(self._treatment_value)
-
         est = self.effect(X_test)
         ate = np.mean(est, axis=0)  # one value per treatment value
@@ -305,7 +304,6 @@ def apply_multitreatment(self, df: pd.DataFrame, fun: Callable, *args, **kwargs)
             filtered_df = None
         else:
             filtered_df = df.values
-
         for tv in self._treatment_value:
             ests.append(
                 fun(
@@ -331,7 +329,8 @@ def effect(self, df: pd.DataFrame, *args, **kwargs) -> np.ndarray:
         def effect_fun(filtered_df, T0, T1, *args, **kwargs):
             return self.estimator.effect(filtered_df, T0=T0, T1=T1, *args, **kwargs)
 
-        return self.apply_multitreatment(df, effect_fun, *args, **kwargs)
+        Xdf = df[self._effect_modifier_names] if df is not None else df
+        return self.apply_multitreatment(Xdf, effect_fun, *args, **kwargs)
 
     def effect_interval(self, df: pd.DataFrame, *args, **kwargs) -> np.ndarray:
         """
@@ -346,7 +345,8 @@ def effect_interval_fun(filtered_df, T0, T1, *args, **kwargs):
             return self.estimator.effect_interval(
                 filtered_df, T0=T0, T1=T1, alpha=1 - self.confidence_level, *args, **kwargs
             )
 
-        return self.apply_multitreatment(df, effect_interval_fun, *args, **kwargs)
+        Xdf = df[self._effect_modifier_names] if df is not None else df
+        return self.apply_multitreatment(Xdf, effect_interval_fun, *args, **kwargs)
 
     def effect_inference(self, df: pd.DataFrame, *args, **kwargs):
         """
@@ -359,7 +359,8 @@ def effect_inference_fun(filtered_df, T0, T1, *args, **kwargs):
             return self.estimator.effect_inference(filtered_df, T0=T0, T1=T1, *args, **kwargs)
 
-        return self.apply_multitreatment(df, effect_inference_fun, *args, **kwargs)
+        Xdf = df[self._effect_modifier_names] if df is not None else df
+        return self.apply_multitreatment(Xdf, effect_inference_fun, *args, **kwargs)
 
     def effect_tt(self, df: pd.DataFrame, treatment_value, *args, **kwargs):
         """
diff --git a/tests/causal_estimators/test_econml_estimator.py b/tests/causal_estimators/test_econml_estimator.py
index 97b6c8adf0..335b31555a 100644
--- a/tests/causal_estimators/test_econml_estimator.py
+++ b/tests/causal_estimators/test_econml_estimator.py
@@ -28,7 +28,7 @@ def test_backdoor_estimators(self):
         data = datasets.linear_dataset(
             10,
             num_common_causes=4,
-            num_samples=10000,
+            num_samples=1000,
             num_instruments=2,
             num_effect_modifiers=2,
             num_treatments=1,
@@ -59,6 +59,9 @@
                 "fit_params": {},
             },
         )
+        # Checking that the CATE estimates are not identical
+        dml_cate_estimates_f = dml_estimate.cate_estimates.flatten()
+        assert pytest.approx(dml_cate_estimates_f[0], 0.01) != dml_cate_estimates_f[1]
         # Test ContinuousTreatmentOrthoForest
         orthoforest_estimate = model.estimate_effect(
             identified_estimand,
@@ -66,11 +69,15 @@
             target_units=lambda df: df["X0"] > 2,
             method_params={"init_params": {"n_trees": 10}, "fit_params": {}},
         )
+        # Checking that the CATE estimates are not identical
+        orthoforest_cate_estimates_f = orthoforest_estimate.cate_estimates.flatten()
+        assert pytest.approx(orthoforest_cate_estimates_f[0], 0.01) != orthoforest_cate_estimates_f[1]
+
         # Test LinearDRLearner
         data_binary = datasets.linear_dataset(
             10,
             num_common_causes=4,
-            num_samples=10000,
+            num_samples=1000,
             num_instruments=2,
             num_effect_modifiers=2,
             treatment_is_binary=True,
@@ -94,6 +101,48 @@
                 "fit_params": {},
             },
         )
+        # Checking that the CATE estimates are not identical
+        drlearner_cate_estimates_f = drlearner_estimate.cate_estimates.flatten()
+        assert pytest.approx(drlearner_cate_estimates_f[0], 0.01) != drlearner_cate_estimates_f[1]
+
+    def test_metalearners(self):
+        data = datasets.linear_dataset(
+            10,
+            num_common_causes=4,
+            num_samples=1000,
+            num_instruments=2,
+            num_effect_modifiers=2,
+            num_treatments=1,
+            treatment_is_binary=True,
+        )
+        df = data["df"]
+        model = CausalModel(
+            data=data["df"],
+            treatment=data["treatment_name"],
+            outcome=data["outcome_name"],
+            effect_modifiers=data["effect_modifier_names"],
+            graph=data["gml_graph"],
+        )
+        identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
+        # Test SLearner
+        sl_estimate = model.estimate_effect(
+            identified_estimand,
+            method_name="backdoor.econml.metalearners.SLearner",
+            target_units="ate",
+            method_params={"init_params": {"overall_model": GradientBoostingRegressor()}, "fit_params": {}},
+        )
+        # Checking that the CATE estimates are not identical
+        sl_cate_estimates_f = sl_estimate.cate_estimates.flatten()
+        assert pytest.approx(sl_cate_estimates_f[0], 0.01) != sl_cate_estimates_f[1]
+
+        # Predict on new data
+        sl_estimate_test = model.estimate_effect(
+            identified_estimand,
+            method_name="backdoor.econml.metalearners.SLearner",
+            fit_estimator=False,
+            target_units=data["df"].sample(frac=0.1),
+        )
+        sl_cate_estimates_test_f = sl_estimate_test.cate_estimates.flatten()
+        assert pytest.approx(sl_cate_estimates_test_f[0], 0.01) != sl_cate_estimates_test_f[1]
+
     def test_iv_estimators(self):
         keras = pytest.importorskip("keras")
@@ -101,7 +150,7 @@ def test_iv_estimators(self):
         data = datasets.linear_dataset(
             10,
             num_common_causes=4,
-            num_samples=10000,
+            num_samples=1000,
             num_instruments=2,
             num_effect_modifiers=2,
             num_treatments=1,
@@ -164,7 +213,7 @@ def test_iv_estimators(self):
         data = datasets.linear_dataset(
             10,
             num_common_causes=4,
-            num_samples=10000,
+            num_samples=1000,
             num_instruments=1,
             num_effect_modifiers=2,
             num_treatments=1,
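
For reviewers, a minimal usage sketch (not part of the patch) of the behavior this change enables, mirroring the new `test_metalearners` test above: because `effect()`, `effect_interval()`, and `effect_inference()` now select the effect-modifier columns (`self._effect_modifier_names`) internally, a fitted EconML estimator can be re-applied via `fit_estimator=False` to new rows of the full data frame, extra columns included. Assumes a dowhy build with this patch and scikit-learn installed.

```python
# Sketch of re-scoring a fitted EconML estimator on new data (assumes this patch).
from sklearn.ensemble import GradientBoostingRegressor

from dowhy import CausalModel, datasets

data = datasets.linear_dataset(
    10,
    num_common_causes=4,
    num_samples=1000,
    num_instruments=2,
    num_effect_modifiers=2,
    num_treatments=1,
    treatment_is_binary=True,
)
model = CausalModel(
    data=data["df"],
    treatment=data["treatment_name"],
    outcome=data["outcome_name"],
    effect_modifiers=data["effect_modifier_names"],
    graph=data["gml_graph"],
)
identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)

# Fit an SLearner once.
sl_estimate = model.estimate_effect(
    identified_estimand,
    method_name="backdoor.econml.metalearners.SLearner",
    target_units="ate",
    method_params={"init_params": {"overall_model": GradientBoostingRegressor()}, "fit_params": {}},
)

# Re-use the fitted estimator on a fresh sample of the *full* data frame;
# the non-modifier columns are now dropped automatically before calling EconML.
sl_estimate_new = model.estimate_effect(
    identified_estimand,
    method_name="backdoor.econml.metalearners.SLearner",
    fit_estimator=False,
    target_units=data["df"].sample(frac=0.1),
)
print(sl_estimate_new.cate_estimates[:5])
```

Before this change, callers had to pre-filter the data frame to exactly the effect-modifier columns before `effect()` and its variants would work, which is what the added `Xdf = df[self._effect_modifier_names] if df is not None else df` lines now handle.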