Skip to content

Commit 1765b69

Browse files
committed
added test
Signed-off-by: Amit Sharma <amit_sharma@live.com>
1 parent c5fea7e commit 1765b69

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/test_causal_model.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,38 @@ def test_graph_input_nx(self, beta, num_instruments, num_samples, num_treatments
366366
all_nodes = model._graph.get_all_nodes(include_unobserved=False)
367367
assert "Unobserved Confounders" not in all_nodes
368368

369+
@mark.parametrize(
370+
["beta", "num_effect_modifiers", "num_samples"],
371+
[
372+
(10, 0, 100),
373+
(10, 1, 100),
374+
],
375+
)
376+
def test_cate_estimates_regression(self, beta, num_effect_modifiers, num_samples):
377+
data = dowhy.datasets.linear_dataset(
378+
beta=beta,
379+
num_common_causes=2,
380+
num_samples=num_samples,
381+
num_treatments=1,
382+
treatment_is_binary=True,
383+
num_effect_modifiers=num_effect_modifiers,
384+
)
385+
model = CausalModel(
386+
data=data["df"],
387+
treatment=data["treatment_name"],
388+
outcome=data["outcome_name"],
389+
graph=data["gml_graph"],
390+
test_significance=None,
391+
)
392+
identified_estimand = model.identify_effect()
393+
linear_estimate = model.estimate_effect(
394+
identified_estimand, method_name="backdoor.linear_regression", control_value=0, treatment_value=1
395+
)
396+
if num_effect_modifiers == 0:
397+
assert linear_estimate.conditional_estimates is None
398+
else:
399+
assert linear_estimate.conditional_estimates is not None
400+
369401
@mark.parametrize(
370402
["num_variables", "num_samples"],
371403
[

0 commit comments

Comments
 (0)