Skip to content

Commit 6038f4d

Browse files
bhatt-priyadutt authored and bloebp committed
equation to graph implementation with test cases added.
Signed-off-by: priyadutt <bhattpriyadutt@gmail.com>
1 parent e66ed36 commit 6038f4d

File tree

3 files changed

+319
-0
lines changed

3 files changed

+319
-0
lines changed

dowhy/gcm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,5 @@
3232
from .unit_change import unit_change
3333
from .validation import RejectionResult, refute_causal_structure, refute_invertible_model
3434
from .whatif import average_causal_effect, counterfactual_samples, interventional_samples
35+
36+
from .equation_parser import create_causal_model_from_equations # isort:skip

dowhy/gcm/equation_parser.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
import ast
import logging
import re
from typing import Tuple

import networkx as nx
import numpy as np
import scipy.stats

from dowhy.gcm import AdditiveNoiseModel, EmpiricalDistribution, ScipyDistribution, StructuralCausalModel
from dowhy.gcm.causal_mechanisms import StochasticModel
from dowhy.gcm.causal_models import PARENTS_DURING_FIT
from dowhy.gcm.ml.prediction_model import PredictionModel
from dowhy.gcm.stochastic_models import BayesianGaussianMixtureDistribution
from dowhy.gcm.util.general import shape_into_2d
from dowhy.graph import get_ordered_predecessors
16+
17+
# Keywords that select one of the predefined stochastic models. Any other noise-model name is
# looked up in scipy.stats by _identify_noise_model.
_STOCHASTIC_MODEL_TYPES = {
    "empirical": EmpiricalDistribution,
    # Fixed: this keyword used to be mapped to EmpiricalDistribution, so requesting a Bayesian
    # Gaussian mixture silently produced an empirical distribution instead.
    "bayesiangaussianmixture": BayesianGaussianMixtureDistribution,
    "parametric": ScipyDistribution,
}
# "<name>(<args>)" spanning the whole string; group 1 is the name, group 2 the raw argument text.
_NOISE_MODEL_PATTERN = r"^\s*([\w]+)\(([^)]*)\)\s*$"
# A Python-identifier-like node name.
_NODE_NAME_PATTERN = r"[a-zA-Z_]\w*"
# A comma-separated list of node names. The pattern is not anchored at the end.
_UNKNOWN_MODEL_PATTERN = rf"\s*\b{_NODE_NAME_PATTERN}(?:\s*,\s*{_NODE_NAME_PATTERN})*\b"

# Names that are available inside custom equations. The same mapping is passed to eval() as the
# globals of an equation and is used to tell function names apart from parent node names.
_allowed_callables = {}
_np_functions = {func: getattr(np, func) for func in dir(np) if callable(getattr(np, func))}
_scipy_functions = {
    func: getattr(scipy.stats, func) for func in dir(scipy.stats) if callable(getattr(scipy.stats, func))
}
# The empty "__builtins__" entry keeps eval() from injecting the real builtins module.
_builtin_functions = {"len": len, "__builtins__": {}}
_allowed_callables.update(_np_functions)
_allowed_callables.update(_scipy_functions)
_allowed_callables.update(_builtin_functions)

logger = logging.getLogger(__name__)
36+
37+
38+
def create_causal_model_from_equations(node_equations: str) -> StructuralCausalModel:
    """
    Create a causal model from a set of equations defining causal relationships between nodes.

    Every non-empty line defines exactly one node and must follow one of these formats:

    1. Root node equation:
        >>> "<node_name> = <noise_model_name>(<optional_arguments>)"
       The noise model name can be one of the following:
        - empirical()
        - bayesiangaussianmixture()
        - parametric()
        - <scipy.stats.*>
       The empirical and Bayesian models are predefined; their description can be found in the
       dowhy library. Use parametric when you want to find the best continuous distribution for
       the data. Any other name is looked up in the scipy.stats library. Arguments have to be
       given as ``name=value`` pairs with literal values.
    2. Non-root node equation:
        >>> "<node_name> = <function-expression> + <noise_model_name>(<optional_arguments>)"
       The function-expression can be any expression containing arithmetic operations on the nodes
       and calls to functions defined in numpy. The noise model must be the last summand; its
       format is the same as in point one.
    3. Unknown causal model equation:
        >>> "<node_name> -> <node_name1>, <node_name2>, ..."
       If the functional relationship is not known, this format only defines edges: every listed
       node becomes a parent of ``<node_name>`` and no causal mechanism is assigned to it.

    Nodes that only occur as parents and have no equation of their own are treated as root nodes
    with an empirical distribution (a warning is logged). A node whose noise model was given
    explicit arguments is marked as fully defined and gets its ordered parents stored under
    ``PARENTS_DURING_FIT``.

    Example:
        >>> scm = \"\"\"
        X = empirical()
        Z = norm(loc=0, scale=1)
        Y = 12 * X + log(Z) + norm(loc=0, scale=1)
        \"\"\"

    :param node_equations: A string containing equations defining the relationships between nodes.
                           Each equation should be separated by a newline.
    :return: A StructuralCausalModel object representing the created causal model.
    :raises ValueError: If a line contains banned characters, has no noise term where one is
                        required, or names an unknown noise model.
    :raises Exception: If the same node is defined more than once.
    """
    banned_characters = [":", ";", "[", "__", "import", "lambda"]
    causal_nodes_info = {}
    causal_graph = nx.DiGraph()
    for equation in node_equations.split("\n"):
        equation = equation.strip()
        _sanitize_input_expression(equation, banned_characters)
        if not equation:
            continue
        parsed_args = {}
        node_name, expression = _extract_equation_components(equation)
        _check_node_redundancy(causal_nodes_info, node_name)
        causal_nodes_info[node_name] = {}
        root_node_match = re.match(_NOISE_MODEL_PATTERN, expression)
        # fullmatch is required here: the pattern is not anchored at its end, so a plain match would
        # also accept "X + norm()" (it starts with a node name) and treat it as an edge list.
        unknown_model_match = re.fullmatch(_UNKNOWN_MODEL_PATTERN, expression)
        causal_graph.add_node(node_name)
        if root_node_match:
            causal_mechanism_name = root_node_match.group(1)
            parsed_args = _parse_args(root_node_match.group(2))
            causal_nodes_info[node_name]["causal_mechanism"] = _identify_noise_model(
                causal_mechanism_name, parsed_args
            )
        elif unknown_model_match:
            parent_nodes = _get_sorted_parent_nodes(expression.split(","))
            _add_parent_nodes_to_graph(causal_graph, parent_nodes, node_name)
            causal_nodes_info[node_name]["unknown"] = True
        else:
            if "+" not in expression:
                raise ValueError(
                    f"The equation '{equation}' must have the form "
                    f"'<function-expression> + <noise_model_name>(<optional_arguments>)'"
                )
            # The noise model is the last summand; everything before it is the custom function.
            custom_func, noise_eq = expression.rsplit("+", 1)
            # Find all node names in the expression string.
            parent_node_candidates = re.findall(_NODE_NAME_PATTERN, custom_func)
            parent_nodes = _get_sorted_parent_nodes(parent_node_candidates)
            _add_parent_nodes_to_graph(causal_graph, parent_nodes, node_name)
            noise_model_name, parsed_args = _extract_noise_model_components(noise_eq)
            noise_model = _identify_noise_model(noise_model_name, parsed_args)
            causal_nodes_info[node_name]["causal_mechanism"] = AdditiveNoiseModel(
                CustomEquationModel(custom_func, parent_nodes), noise_model
            )
        # Only noise models with explicit arguments count as fully defined.
        causal_nodes_info[node_name]["fully_defined"] = bool(parsed_args)
    _add_undefined_nodes_info(causal_nodes_info, list(causal_graph.nodes))
    causal_model = StructuralCausalModel(causal_graph)
    for node in causal_graph.nodes:
        node_info = causal_nodes_info[node]
        if "unknown" in node_info:
            # Nodes defined with "->" only contribute edges and get no mechanism.
            continue
        causal_model.set_causal_mechanism(node, node_info["causal_mechanism"])
        if node_info["fully_defined"]:
            causal_model.graph.nodes[node][PARENTS_DURING_FIT] = get_ordered_predecessors(causal_model.graph, node)
    return causal_model
118+
119+
120+
def _parse_args(args: str) -> dict:
121+
str_args_list = args.split(",")
122+
kwargs = {}
123+
for str_arg in str_args_list:
124+
if str_arg:
125+
arg_value_pairs = str_arg.split("=")
126+
kwargs[arg_value_pairs[0].strip()] = ast.literal_eval(arg_value_pairs[1].strip())
127+
return kwargs
128+
129+
130+
def _add_parent_nodes_to_graph(causal_graph: nx.DiGraph, parent_nodes: list, node_name: str) -> None:
    """
    Add a directed edge from every given parent node to ``node_name``.

    :param causal_graph: The graph that is modified in place.
    :param parent_nodes: Names of the nodes that become parents of ``node_name``.
    :param node_name: The child node all new edges point to.
    """
    causal_graph.add_edges_from((parent, node_name) for parent in parent_nodes)
133+
134+
135+
def _identify_noise_model(causal_mechanism_name: str, parsed_args: dict) -> StochasticModel:
    """
    Create the stochastic model that belongs to a noise-model name.

    Predefined keywords from ``_STOCHASTIC_MODEL_TYPES`` take precedence; any other name is looked
    up as an attribute of ``scipy.stats`` and wrapped in the "parametric" model type.

    :param causal_mechanism_name: The noise-model name used in the equation.
    :param parsed_args: Keyword arguments forwarded to the model constructor.
    :return: The instantiated stochastic model.
    :raises ValueError: If the name is neither a predefined keyword nor found in ``scipy.stats``.
    """
    predefined_model = _STOCHASTIC_MODEL_TYPES.get(causal_mechanism_name)
    if predefined_model is not None:
        return predefined_model(**parsed_args)

    distribution = getattr(scipy.stats, causal_mechanism_name, None)
    if not distribution:
        raise ValueError(f"Unable to recognise the noise model: {causal_mechanism_name}")
    return _STOCHASTIC_MODEL_TYPES["parametric"](scipy_distribution=distribution, **parsed_args)
144+
145+
146+
def _extract_noise_model_components(noise_eq: str) -> Tuple[str, dict]:
    """
    Split a noise term such as ``norm(loc=0, scale=1)`` into its name and parsed arguments.

    :param noise_eq: The noise part of a non-root node equation.
    :return: A tuple of the noise-model name and its keyword arguments.
    :raises Exception: If the text does not have the form ``<name>(<args>)``.
    """
    noise_model_match = re.match(_NOISE_MODEL_PATTERN, noise_eq)
    if not noise_model_match:
        raise Exception("Unable to recognise the format or function specified")
    noise_model_name, raw_args = noise_model_match.group(1, 2)
    return noise_model_name, _parse_args(raw_args)
155+
156+
157+
def _extract_equation_components(equation: str) -> Tuple[str, str]:
158+
if "->" in equation:
159+
node_name, expression = equation.split("->", 1)
160+
else:
161+
node_name, expression = equation.split("=", 1)
162+
node_name = node_name.strip()
163+
expression = expression.strip()
164+
return node_name, expression
165+
166+
167+
def _get_sorted_parent_nodes(parent_node_candidates: list) -> list:
    """
    Return the unique parent node names among the candidates in alphabetical order.

    Candidates that name an allowed callable (numpy/scipy.stats functions, ``len``) are function
    names rather than nodes and are dropped, as are empty names. Duplicates are removed because a
    node that occurs several times in an expression (e.g. ``X*X``) is still a single parent;
    keeping the duplicates would make the custom prediction model expect one input column per
    occurrence.

    :param parent_node_candidates: Identifier-like strings found in an expression.
    :return: A sorted list of distinct parent node names.
    """
    stripped_candidates = (candidate.strip() for candidate in parent_node_candidates)
    unique_parents = {name for name in stripped_candidates if name and name not in _allowed_callables}
    return sorted(unique_parents)
175+
176+
177+
def _add_undefined_nodes_info(causal_nodes_info: dict, present_nodes: list) -> None:
    """
    Register default information for graph nodes that have no equation of their own.

    Such nodes get an empirical distribution as causal mechanism, are marked as not fully defined,
    and a warning is logged for each of them.

    :param causal_nodes_info: Per-node information collected so far; extended in place.
    :param present_nodes: All nodes currently present in the causal graph.
    """
    undefined_nodes = (node for node in present_nodes if node not in causal_nodes_info)
    for present_node in undefined_nodes:
        logger.warning(f"{present_node} is undefined and will be considered as root node by default.")
        causal_nodes_info[present_node] = {
            "causal_mechanism": EmpiricalDistribution(),
            "fully_defined": False,
        }
184+
185+
186+
def _check_node_redundancy(causal_nodes_info: dict, node_name: str) -> None:
187+
if node_name in causal_nodes_info:
188+
raise Exception(f"The node {node_name} is specified twice which is not allowed.")
189+
190+
191+
def _sanitize_input_expression(expression: str, banned_characters: list) -> None:
192+
for char in banned_characters:
193+
if char in expression:
194+
raise ValueError(f"'{char}' in the expression '{expression}' is not allowed because of security reasons")
195+
if re.search(r"[^0-9\+\-\*\/]+\.[^0-9\+\-\*\/]+", expression):
196+
raise ValueError(f"'.' can only be used incase of specifying decimals because of security reasons")
197+
198+
199+
class CustomEquationModel(PredictionModel):
    """
    Prediction model that evaluates a user-supplied expression on the parent values.

    The expression is given as a string and therefore fully defines the model; fitting is a no-op.
    """

    def __init__(self, custom_func: str, parent_nodes: list):
        """
        :param custom_func: The expression to evaluate, written in terms of the parent node names.
        :param parent_nodes: Parent node names; the i-th name is bound to the i-th input column.
        """
        self.custom_func = custom_func
        self.parent_nodes = parent_nodes

    def fit(self, X, Y) -> None:
        """Do nothing: the function is given explicitly and has no parameters to learn."""
        return None

    def predict(self, X) -> np.ndarray:
        """
        Evaluate the expression with every parent name bound to its column of ``X``.

        :param X: Two-dimensional input whose columns are ordered like ``parent_nodes``.
        :return: The evaluated expression, reshaped by ``shape_into_2d``.
        """
        columns_by_parent = {parent: X[:, column] for column, parent in enumerate(self.parent_nodes)}
        # NOTE(security): eval() runs arbitrary expression code. The globals are restricted to
        # _allowed_callables, but the expression must still come from sanitized input only.
        evaluated = eval(self.custom_func, _allowed_callables, columns_by_parent)
        return shape_into_2d(evaluated)

    def clone(self):
        """Return a new model with the same expression and parent nodes."""
        return CustomEquationModel(self.custom_func, self.parent_nodes)

tests/gcm/test_equation_parser.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import networkx as nx
2+
import numpy as np
3+
import pandas as pd
4+
from flaky import flaky
5+
from pytest import approx
6+
7+
from dowhy.gcm import (
8+
AdditiveNoiseModel,
9+
EmpiricalDistribution,
10+
ProbabilisticCausalModel,
11+
create_causal_model_from_equations,
12+
fit,
13+
interventional_samples,
14+
)
15+
from dowhy.gcm.ml import create_linear_regressor_with_given_parameters
16+
17+
18+
@flaky(max_runs=2)
def test_equation_parser_fit_func_is_giving_correct_results():
    """A model built from equations predicts X1 like the equivalent manually assembled model."""
    observations = _generate_data()

    def _rounded_x1_prediction(model):
        # Predict X1 from the observed X0 values and round to two decimals for comparison.
        prediction = model.causal_mechanism("X1")._prediction_model.predict(observations[["X0"]].to_numpy())
        return np.around(prediction, 2)

    causal_model = ProbabilisticCausalModel(nx.DiGraph([("X0", "X1"), ("X0", "X2"), ("X2", "X3")]))
    _assign_causal_mechanisms(causal_model)
    fit(causal_model, observations)
    normal_results = _rounded_x1_prediction(causal_model)

    causal_model_from_eq = _get_causal_model_from_eq()
    fit(causal_model_from_eq, observations)
    eq_results = _rounded_x1_prediction(causal_model_from_eq)

    assert np.array_equal(normal_results, eq_results)
33+
34+
35+
def test_variables_are_sorted_alphabetically_in_custom_predict_method():
    """Input columns are matched to parents alphabetically, not in order of appearance in the equation."""
    equations = """
    A = norm(loc=0,scale=0.1)
    B = norm(loc=0, scale=0.1)
    Y = 0.5*B + 2*A+ norm(loc=0, scale=0.1)
    """
    causal_model = create_causal_model_from_equations(equations)

    A = np.random.normal(0, 0.1, 10)
    B = np.random.normal(0, 0.1, 10)
    Y = 0.5 * B + 2 * A
    observations = pd.DataFrame({"A": A, "B": B, "Y": Y})

    # The columns are passed in alphabetical order (A, B) although the equation mentions B first.
    parent_values = observations[["A", "B"]].to_numpy()
    eq_results = causal_model.causal_mechanism("Y")._prediction_model.predict(parent_values)

    expected = np.around(Y, 2)
    actual = np.around(eq_results.ravel(), 2)
    assert np.array_equal(expected, actual)
50+
51+
52+
def test_unknown_causal_model_relationship_is_undefined():
    """A node declared with '->' is part of the graph but has no causal mechanism assigned."""
    import pytest

    causal_model = create_causal_model_from_equations(
        """
    A = norm(loc=0,scale=0.1)
    B = norm(loc=0, scale=0.1)
    Y = 0.5*B + 2*A+ norm(loc=0, scale=0.1)
    Z->Y,A
    """
    )
    assert "Z" in causal_model.graph.nodes
    # Looking up the mechanism of a node without one is expected to fail with a KeyError.
    with pytest.raises(KeyError):
        causal_model.causal_mechanism("Z")
67+
68+
69+
def _generate_data():
70+
X0 = np.random.normal(0, 0.1, 100)
71+
X1 = 2 * X0
72+
X2 = 0.5 * X0
73+
X3 = 0.5 * X2
74+
observations = pd.DataFrame({"X0": X0, "X1": X1, "X2": X2, "X3": X3})
75+
return observations
76+
77+
78+
def _get_causal_model_from_eq():
    """Build the X0..X3 test model from its textual equation representation."""
    equations = """
    X0 = norm(loc=0,scale=0.1)
    X1 = 2*X0 + norm(loc=0, scale=0.1)
    X2 = 0.5*X0 + norm(loc=0, scale=0.1)
    X3 = 0.5*X2 + norm(loc=0, scale=0.1)
    """
    return create_causal_model_from_equations(equations)
88+
89+
90+
def _assign_causal_mechanisms(causal_model):
    """Assign an empirical root and linear additive-noise mechanisms with fixed coefficients."""
    causal_model.set_causal_mechanism("X0", EmpiricalDistribution())
    # Each non-root node depends linearly on its single parent with the given coefficient.
    for node, coefficient in (("X1", 2), ("X2", 0.5), ("X3", 0.5)):
        regressor = create_linear_regressor_with_given_parameters(coefficients=np.array([coefficient]))
        causal_model.set_causal_mechanism(node, AdditiveNoiseModel(regressor))

0 commit comments

Comments
 (0)