diff --git a/causalml/metrics/sensitivity.py b/causalml/metrics/sensitivity.py index c25ecbcc..8a7686ed 100644 --- a/causalml/metrics/sensitivity.py +++ b/causalml/metrics/sensitivity.py @@ -147,19 +147,14 @@ def get_ate_ci(self, X, p, treatment, y): learner = self.learner from ..inference.meta.tlearner import BaseTLearner - if isinstance(learner, BaseTLearner): - ate, ate_lower, ate_upper = learner.estimate_ate( - X=X, treatment=treatment, y=y + try: + ate, ate_lower, ate_upper = self.learner.estimate_ate( + X=X, p=p, treatment=treatment, y=y, return_ci=True + ) + except TypeError: + ate, ate_lower, ate_upper = self.learner.estimate_ate( + X=X, p=p, treatment=treatment, y=y ) - else: - try: - ate, ate_lower, ate_upper = learner.estimate_ate( - X=X, p=p, treatment=treatment, y=y - ) - except TypeError: - ate, ate_lower, ate_upper = learner.estimate_ate( - X=X, treatment=treatment, y=y, return_ci=True - ) return ate[0], ate_lower[0], ate_upper[0] @staticmethod