Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions econml/dml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def __init__(self, model_final):
def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None,
sample_weight=None, freq_weight=None, sample_var=None, groups=None):
Y_res, T_res = nuisances
self._model_final.fit(X, T, T_res, Y_res, sample_weight=sample_weight,
freq_weight=freq_weight, sample_var=sample_var)
self._model_final.fit(X, T, T_res, Y_res, **(filter_none_kwargs(sample_weight=sample_weight,
freq_weight=freq_weight, sample_var=sample_var, groups=groups)))
return self

def predict(self, X=None):
Expand Down
28 changes: 20 additions & 8 deletions econml/inference/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,22 +465,28 @@ class StatsModelsInference(LinearModelFinalInference):
----------
cov_type : str, default 'HC1'
The type of covariance estimation method to use. Supported values are 'nonrobust',
'HC0', 'HC1'.
'HC0', 'HC1', 'clustered'.
cov_options : dict, optional
Additional options for covariance estimation. For clustered covariance, supports:
- 'group_correction': bool, default True. Whether to apply N_G/(N_G-1) correction.
- 'df_correction': bool, default True. Whether to apply (N-1)/(N-K) correction.
"""

def __init__(self, cov_type='HC1'):
if cov_type not in ['nonrobust', 'HC0', 'HC1']:
def __init__(self, cov_type='HC1', cov_options=None):
if cov_type not in ['nonrobust', 'HC0', 'HC1', 'clustered']:
raise ValueError("Unsupported cov_type; "
"must be one of 'nonrobust', "
"'HC0', 'HC1'")
"'HC0', 'HC1', 'clustered'")

self.cov_type = cov_type
self.cov_options = cov_options if cov_options is not None else {}

def prefit(self, estimator, *args, **kwargs):
super().prefit(estimator, *args, **kwargs)
assert not (self.model_final.fit_intercept), ("Inference can only be performed on models linear in "
"their features, but here fit_intercept is True")
self.model_final.cov_type = self.cov_type
self.model_final.cov_options = self.cov_options


class GenericModelFinalInferenceDiscrete(Inference):
Expand Down Expand Up @@ -660,21 +666,27 @@ class StatsModelsInferenceDiscrete(LinearModelFinalInferenceDiscrete):
----------
cov_type : str, default 'HC1'
The type of covariance estimation method to use. Supported values are 'nonrobust',
'HC0', 'HC1'.
'HC0', 'HC1', 'clustered'.
cov_options : dict, optional
Additional options for covariance estimation. For clustered covariance, supports:
- 'group_correction': bool, default True. Whether to apply N_G/(N_G-1) correction.
- 'df_correction': bool, default True. Whether to apply (N-1)/(N-K) correction.
"""

def __init__(self, cov_type='HC1'):
if cov_type not in ['nonrobust', 'HC0', 'HC1']:
def __init__(self, cov_type='HC1', cov_options=None):
if cov_type not in ['nonrobust', 'HC0', 'HC1', 'clustered']:
raise ValueError("Unsupported cov_type; "
"must be one of 'nonrobust', "
"'HC0', 'HC1'")
"'HC0', 'HC1', 'clustered'")

self.cov_type = cov_type
self.cov_options = cov_options if cov_options is not None else {}

def prefit(self, estimator, *args, **kwargs):
super().prefit(estimator, *args, **kwargs)
# need to set the fit args before the estimator is fit
self.model_final.cov_type = self.cov_type
self.model_final.cov_options = self.cov_options


class InferenceResults(metaclass=abc.ABCMeta):
Expand Down
10 changes: 7 additions & 3 deletions econml/iv/dml/_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None,
XT_res = self._combine(X, T_res)
XZ_res = self._combine(X, Z_res)
filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight,
freq_weight=freq_weight, sample_var=sample_var)
freq_weight=freq_weight, sample_var=sample_var, groups=groups)

self._model_final.fit(XZ_res, XT_res, Y_res, **filtered_kwargs)

Expand Down Expand Up @@ -376,14 +376,18 @@ def __init__(self, *,
mc_iters=None,
mc_agg='mean',
random_state=None,
allow_missing=False):
allow_missing=False,
cov_type="HC0",
cov_options=None):
self.model_y_xw = clone(model_y_xw, safe=False)
self.model_t_xw = clone(model_t_xw, safe=False)
self.model_t_xwz = clone(model_t_xwz, safe=False)
self.model_z_xw = clone(model_z_xw, safe=False)
self.projection = projection
self.featurizer = clone(featurizer, safe=False)
self.fit_cate_intercept = fit_cate_intercept
self.cov_type = cov_type
self.cov_options = cov_options if cov_options is not None else {}

super().__init__(discrete_outcome=discrete_outcome,
discrete_instrument=discrete_instrument,
Expand All @@ -403,7 +407,7 @@ def _gen_featurizer(self):
return clone(self.featurizer, safe=False)

def _gen_model_final(self):
return StatsModels2SLS(cov_type="HC0")
return StatsModels2SLS(cov_type=self.cov_type, cov_options=self.cov_options)

def _gen_ortho_learner_model_final(self):
return _OrthoIVModelFinal(self._gen_model_final(), self._gen_featurizer(), self.fit_cate_intercept)
Expand Down
Loading