Commit

auto identify the effect modifier columns for `effect` method for EconML estimators (#1061)

* auto identify the effect modifier columns

Signed-off-by: Amit Sharma <amit_sharma@live.com>

* fixed formatting errors

Signed-off-by: Amit Sharma <amit_sharma@live.com>

---------

Signed-off-by: Amit Sharma <amit_sharma@live.com>
amit-sharma authored Nov 27, 2023
1 parent 7c015b7 commit 4fd0a92
Showing 2 changed files with 59 additions and 9 deletions.
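
A minimal sketch of the workflow this change enables, mirroring the new SLearner test added below (a sketch only; assumes dowhy, econml, and scikit-learn are installed):

# Illustrative sketch of the behavior this commit enables; the dataset
# parameters and method names mirror the new test in this diff.
from sklearn.ensemble import GradientBoostingRegressor

from dowhy import CausalModel, datasets

data = datasets.linear_dataset(
    10,
    num_common_causes=4,
    num_samples=1000,
    num_instruments=2,
    num_effect_modifiers=2,
    num_treatments=1,
    treatment_is_binary=True,
)
model = CausalModel(
    data=data["df"],
    treatment=data["treatment_name"],
    outcome=data["outcome_name"],
    effect_modifiers=data["effect_modifier_names"],
    graph=data["gml_graph"],
)
estimand = model.identify_effect(proceed_when_unidentifiable=True)
estimate = model.estimate_effect(
    estimand,
    method_name="backdoor.econml.metalearners.SLearner",
    target_units="ate",
    method_params={"init_params": {"overall_model": GradientBoostingRegressor()}, "fit_params": {}},
)
# Re-use the fitted estimator on held-out rows: the full data frame can be
# passed as target_units, because the wrapper now filters it down to the
# effect-modifier columns before calling the underlying EconML estimator.
new_estimate = model.estimate_effect(
    estimand,
    method_name="backdoor.econml.metalearners.SLearner",
    fit_estimator=False,
    target_units=data["df"].sample(frac=0.1),
)
print(new_estimate.cate_estimates[:5])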
11 changes: 6 additions & 5 deletions dowhy/causal_estimators/econml.py
@@ -245,7 +245,6 @@ def estimate_effect(
         # Changing shape to a list for a singleton value
         # Note that self._control_value is assumed to be a singleton value
         self._treatment_value = parse_state(self._treatment_value)
-
         est = self.effect(X_test)
         ate = np.mean(est, axis=0)  # one value per treatment value

@@ -305,7 +304,6 @@ def apply_multitreatment(self, df: pd.DataFrame, fun: Callable, *args, **kwargs)
             filtered_df = None
         else:
             filtered_df = df.values
-
         for tv in self._treatment_value:
             ests.append(
                 fun(
@@ -331,7 +329,8 @@ def effect(self, df: pd.DataFrame, *args, **kwargs) -> np.ndarray:
         def effect_fun(filtered_df, T0, T1, *args, **kwargs):
             return self.estimator.effect(filtered_df, T0=T0, T1=T1, *args, **kwargs)

-        return self.apply_multitreatment(df, effect_fun, *args, **kwargs)
+        Xdf = df[self._effect_modifier_names] if df is not None else df
+        return self.apply_multitreatment(Xdf, effect_fun, *args, **kwargs)

     def effect_interval(self, df: pd.DataFrame, *args, **kwargs) -> np.ndarray:
         """
@@ -346,7 +345,8 @@ def effect_interval_fun(filtered_df, T0, T1, *args, **kwargs):
                 filtered_df, T0=T0, T1=T1, alpha=1 - self.confidence_level, *args, **kwargs
             )

-        return self.apply_multitreatment(df, effect_interval_fun, *args, **kwargs)
+        Xdf = df[self._effect_modifier_names] if df is not None else df
+        return self.apply_multitreatment(Xdf, effect_interval_fun, *args, **kwargs)

     def effect_inference(self, df: pd.DataFrame, *args, **kwargs):
         """
@@ -359,7 +359,8 @@ def effect_inference(self, df: pd.DataFrame, *args, **kwargs):
         def effect_inference_fun(filtered_df, T0, T1, *args, **kwargs):
             return self.estimator.effect_inference(filtered_df, T0=T0, T1=T1, *args, **kwargs)

-        return self.apply_multitreatment(df, effect_inference_fun, *args, **kwargs)
+        Xdf = df[self._effect_modifier_names] if df is not None else df
+        return self.apply_multitreatment(Xdf, effect_inference_fun, *args, **kwargs)

     def effect_tt(self, df: pd.DataFrame, treatment_value, *args, **kwargs):
         """
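All three methods above apply the same guard before delegating to apply_multitreatment. A standalone sketch of that pattern, for illustration only (select_effect_modifiers is a hypothetical helper, not DoWhy API):

import pandas as pd

def select_effect_modifiers(df, effect_modifier_names):
    # Keep only the effect-modifier columns; pass None through unchanged,
    # since apply_multitreatment already treats a missing frame as "no X".
    return df[effect_modifier_names] if df is not None else df

full_df = pd.DataFrame({"X0": [1.0, 2.0], "W0": [0.3, 0.7], "v0": [0, 1]})
Xdf = select_effect_modifiers(full_df, ["X0"])
assert list(Xdf.columns) == ["X0"]  # extra columns are dropped
assert select_effect_modifiers(None, ["X0"]) is None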
57 changes: 53 additions & 4 deletions tests/causal_estimators/test_econml_estimator.py
@@ -28,7 +28,7 @@ def test_backdoor_estimators(self):
         data = datasets.linear_dataset(
             10,
             num_common_causes=4,
-            num_samples=10000,
+            num_samples=1000,
             num_instruments=2,
             num_effect_modifiers=2,
             num_treatments=1,
@@ -59,18 +59,25 @@
                 "fit_params": {},
             },
         )
+        # Checking that the CATE estimates are not identical
+        dml_cate_estimates_f = dml_estimate.cate_estimates.flatten()
+        assert pytest.approx(dml_cate_estimates_f[0], 0.01) != dml_cate_estimates_f[1]
         # Test ContinuousTreatmentOrthoForest
         orthoforest_estimate = model.estimate_effect(
             identified_estimand,
             method_name="backdoor.econml.orf.DMLOrthoForest",
             target_units=lambda df: df["X0"] > 2,
             method_params={"init_params": {"n_trees": 10}, "fit_params": {}},
         )
+        # Checking that the CATE estimates are not identical
+        orthoforest_cate_estimates_f = orthoforest_estimate.cate_estimates.flatten()
+        assert pytest.approx(orthoforest_cate_estimates_f[0], 0.01) != orthoforest_cate_estimates_f[1]
+
         # Test LinearDRLearner
         data_binary = datasets.linear_dataset(
             10,
             num_common_causes=4,
-            num_samples=10000,
+            num_samples=1000,
             num_instruments=2,
             num_effect_modifiers=2,
             treatment_is_binary=True,
@@ -94,14 +101,56 @@
                 "fit_params": {},
             },
         )
+        # Checking that the CATE estimates are not identical
+        drlearner_cate_estimates_f = drlearner_estimate.cate_estimates.flatten()
+        assert pytest.approx(drlearner_cate_estimates_f[0], 0.01) != drlearner_cate_estimates_f[1]
+
+    def test_metalearners(self):
+        data = datasets.linear_dataset(
+            10,
+            num_common_causes=4,
+            num_samples=1000,
+            num_instruments=2,
+            num_effect_modifiers=2,
+            num_treatments=1,
+            treatment_is_binary=True,
+        )
+        df = data["df"]
+        model = CausalModel(
+            data=data["df"],
+            treatment=data["treatment_name"],
+            outcome=data["outcome_name"],
+            effect_modifiers=data["effect_modifier_names"],
+            graph=data["gml_graph"],
+        )
+        identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
+        # Test SLearner
+        sl_estimate = model.estimate_effect(
+            identified_estimand,
+            method_name="backdoor.econml.metalearners.SLearner",
+            target_units="ate",
+            method_params={"init_params": {"overall_model": GradientBoostingRegressor()}, "fit_params": {}},
+        )
+        # Checking that the CATE estimates are not identical
+        sl_cate_estimates_f = sl_estimate.cate_estimates.flatten()
+        assert pytest.approx(sl_cate_estimates_f[0], 0.01) != sl_cate_estimates_f[1]
+
+        # Predict on new data without refitting the estimator
+        sl_estimate_test = model.estimate_effect(
+            identified_estimand,
+            method_name="backdoor.econml.metalearners.SLearner",
+            fit_estimator=False,
+            target_units=data["df"].sample(frac=0.1),
+        )
+        sl_cate_estimates_test_f = sl_estimate_test.cate_estimates.flatten()
+        assert pytest.approx(sl_cate_estimates_test_f[0], 0.01) != sl_cate_estimates_test_f[1]

     def test_iv_estimators(self):
         keras = pytest.importorskip("keras")
         # Setup data
         data = datasets.linear_dataset(
             10,
             num_common_causes=4,
-            num_samples=10000,
+            num_samples=1000,
             num_instruments=2,
             num_effect_modifiers=2,
             num_treatments=1,
@@ -164,7 +213,7 @@
         data = datasets.linear_dataset(
             10,
             num_common_causes=4,
-            num_samples=10000,
+            num_samples=1000,
             num_instruments=1,
             num_effect_modifiers=2,
             num_treatments=1,
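A note on the assertion pattern used throughout these tests: pytest.approx(a, rel) compares with a relative tolerance, so asserting != checks that two flattened CATE values differ by more than roughly 1%, i.e., that the estimator returned genuinely heterogeneous effects rather than one constant broadcast across units. A small self-contained illustration with made-up numbers:

import pytest

cate_estimates = [9.8, 12.4]  # hypothetical flattened CATE values
# True because 12.4 lies outside the ±1% relative band around 9.8,
# so the two conditional effects genuinely differ.
assert pytest.approx(cate_estimates[0], 0.01) != cate_estimates[1]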
