Skip to content

Commit 8c0e668

Browse files
test case added for the bug
Signed-off-by: priyadutt <bhattpriyadutt@gmail.com>
1 parent 320b96f commit 8c0e668

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

tests/gcm/test_equation_parser.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22
import numpy as np
33
import pandas as pd
44
from flaky import flaky
5-
from pytest import approx
65

76
from dowhy.gcm import (
87
AdditiveNoiseModel,
98
EmpiricalDistribution,
109
ProbabilisticCausalModel,
1110
create_causal_model_from_equations,
1211
fit,
13-
interventional_samples,
1412
)
13+
from dowhy.gcm.causal_mechanisms import ConditionalStochasticModel
1514
from dowhy.gcm.ml import create_linear_regressor_with_given_parameters
1615

1716

@@ -66,6 +65,22 @@ def test_unknown_causal_model_relationship_is_undefined():
6665
pass
6766

6867

68+
def test_known_causal_model_node_is_correctly_identified():
69+
causal_model = create_causal_model_from_equations(
70+
"""
71+
A = norm(loc=0,scale=0.1)
72+
B = norm(loc=0, scale=0.1)
73+
Y = 0.5*B + 2*A+ norm(loc=0, scale=0.1)
74+
Z->Y,A
75+
C = exp(A) + 5 * Z + parametric()
76+
"""
77+
)
78+
list_of_nodes = {"A", "B", "C", "Z", "Y"}
79+
list_of_nodes_from_graph = set(causal_model.graph.nodes)
80+
assert list_of_nodes.issubset(list_of_nodes_from_graph) and list_of_nodes_from_graph.issubset(list_of_nodes)
81+
assert isinstance(causal_model.causal_mechanism("C"), ConditionalStochasticModel)
82+
83+
6984
def _generate_data():
7085
X0 = np.random.normal(0, 0.1, 100)
7186
X1 = 2 * X0

0 commit comments

Comments
 (0)