@@ -366,6 +366,38 @@ def test_graph_input_nx(self, beta, num_instruments, num_samples, num_treatments
366
366
all_nodes = model ._graph .get_all_nodes (include_unobserved = False )
367
367
assert "Unobserved Confounders" not in all_nodes
368
368
369
+ @mark .parametrize (
370
+ ["beta" , "num_effect_modifiers" , "num_samples" ],
371
+ [
372
+ (10 , 0 , 100 ),
373
+ (10 , 1 , 100 ),
374
+ ],
375
+ )
376
+ def test_cate_estimates_regression (self , beta , num_effect_modifiers , num_samples ):
377
+ data = dowhy .datasets .linear_dataset (
378
+ beta = beta ,
379
+ num_common_causes = 2 ,
380
+ num_samples = num_samples ,
381
+ num_treatments = 1 ,
382
+ treatment_is_binary = True ,
383
+ num_effect_modifiers = num_effect_modifiers ,
384
+ )
385
+ model = CausalModel (
386
+ data = data ["df" ],
387
+ treatment = data ["treatment_name" ],
388
+ outcome = data ["outcome_name" ],
389
+ graph = data ["gml_graph" ],
390
+ test_significance = None ,
391
+ )
392
+ identified_estimand = model .identify_effect ()
393
+ linear_estimate = model .estimate_effect (
394
+ identified_estimand , method_name = "backdoor.linear_regression" , control_value = 0 , treatment_value = 1
395
+ )
396
+ if num_effect_modifiers == 0 :
397
+ assert linear_estimate .conditional_estimates is None
398
+ else :
399
+ assert linear_estimate .conditional_estimates is not None
400
+
369
401
@mark .parametrize (
370
402
["num_variables" , "num_samples" ],
371
403
[
0 commit comments