|
import ast
import logging
import re
from typing import Tuple

import networkx as nx
import numpy as np
import scipy.stats

from dowhy.gcm import (
    AdditiveNoiseModel,
    BayesianGaussianMixtureDistribution,
    EmpiricalDistribution,
    ScipyDistribution,
    StructuralCausalModel,
)
from dowhy.gcm.causal_mechanisms import StochasticModel
from dowhy.gcm.causal_models import PARENTS_DURING_FIT
from dowhy.gcm.ml.prediction_model import PredictionModel
from dowhy.gcm.util.general import shape_into_2d
from dowhy.graph import get_ordered_predecessors
| 16 | + |
# Maps the noise-model keyword used in an equation to the stochastic model class implementing it.
_STOCHASTIC_MODEL_TYPES = {
    "empirical": EmpiricalDistribution,
    # The Bayesian Gaussian mixture keyword must resolve to its own model; mapping it to
    # EmpiricalDistribution would silently build a different noise model than the one requested.
    "bayesiangaussianmixture": BayesianGaussianMixtureDistribution,
    "parametric": ScipyDistribution,
}

# Matches "<name>(<args>)" where <args> contains no closing parenthesis; groups are (name, args).
_NOISE_MODEL_PATTERN = r"^\s*([\w]+)\(([^)]*)\)\s*$"
# A single node name: a Python-identifier-like token.
_NODE_NAME_PATTERN = r"[a-zA-Z_]\w*"
# Matches a comma-separated list of node names and nothing else. The trailing "\s*$" matters:
# without it, re.match() accepts any expression that merely *starts* with a name, e.g.
# "X + norm(loc=0, scale=1)", which is a regular equation and not an edge list.
_UNKNOWN_MODEL_PATTERN = rf"\s*\b{_NODE_NAME_PATTERN}(?:\s*,\s*{_NODE_NAME_PATTERN})*\b\s*$"

# Names that equations may call. They double as the "not a node name" filter when parent nodes
# are extracted from an expression.
# NOTE(review): this exposes every callable attribute of numpy and scipy.stats to evaluated
# equations (numpy also ships file I/O helpers such as save/load) - consider an explicit allow-list.
_np_functions = {name: getattr(np, name) for name in dir(np) if callable(getattr(np, name))}
_scipy_functions = {
    name: getattr(scipy.stats, name) for name in dir(scipy.stats) if callable(getattr(scipy.stats, name))
}
# The empty "__builtins__" entry keeps Python's builtins out of reach when this dictionary is
# used as the globals of eval().
_builtin_functions = {"len": len, "__builtins__": {}}
# Later entries win on name clashes: scipy.stats over numpy, then the builtin overrides.
_allowed_callables = {**_np_functions, **_scipy_functions, **_builtin_functions}

logger = logging.getLogger(__name__)
| 36 | + |
| 37 | + |
def create_causal_model_from_equations(node_equations: str) -> StructuralCausalModel:
    """
    Create a causal model from a set of equations defining causal relationships between nodes.

    The equation format supports the following cases in which an expression can be defined:

    1. Specifying a root node equation:
        >>> "<node_name> = <noise_model_name>(<optional_arguments>)"
        The noise model name can be one of the following:
            - empirical()
            - bayesiangaussianmixture()
            - parametric()
            - <scipy.stats.*>
        Empirical and Bayesian models are already defined and one can find the description
        of those in the dowhy library.
        Use parametric when you want to find the best continuous distribution for the data.
        You can specify any noise function defined in the scipy.stats library.
    2. Specifying a non-root node equation:
        >>> "<node_name> = <function-expression> + <noise_model_name>(<optional_arguments>)"
        The function-expression can be any expression containing arithmetic operations of the nodes
        and calling functions defined under numpy. The format/definition of noise for the non-root node
        remains the same as in point one. The noise term must be the last summand.
    3. Specifying an unknown causal model equation:
        >>> "<node_name> -> <node_name1>, <node_name2>, ..."
        In case we don't know the causal relationship model between nodes then we can
        use the above format to just define the edges between the nodes. Every node listed on the
        right-hand side is added as a parent of the node on the left-hand side, and no causal
        mechanism is assigned to the left-hand node.

    Nodes that are referenced but never defined get an empirical distribution assigned (a warning is logged).

    Example:
        >>> scm = \"""
        X = empirical()
        Z = norm(loc=0, scale=1)
        Y = 12 * X + log(Z) + norm(loc=0, scale=1)
        \"""

    :param node_equations: A string containing equations defining the relationships between nodes.
                           Each equation should be separated by a newline.
    :return: StructuralCausalModel: A StructuralCausalModel object representing the created causal model.
    :raises ValueError: If an equation contains a banned token, misses its additive noise term or names an
                        unknown noise model.
    :raises Exception: If a node is defined twice or the noise term cannot be parsed.
    """
    # NOTE(review): function expressions are later evaluated with eval() by CustomEquationModel;
    # this token list is part of the (best-effort) input sanitization for that.
    banned_characters = [":", ";", "[", "__", "import", "lambda"]
    # A node name must not start in the middle of another token; this keeps the exponent of a
    # float literal such as "1e3" from being picked up as a parent node called "e3".
    node_name_regex = re.compile(rf"(?<![\w.]){_NODE_NAME_PATTERN}")
    causal_nodes_info = {}
    causal_graph = nx.DiGraph()

    for equation in node_equations.split("\n"):
        equation = equation.strip()
        if not equation:
            continue
        _sanitize_input_expression(equation, banned_characters)
        node_name, expression = _extract_equation_components(equation)
        _check_node_redundancy(causal_nodes_info, node_name)
        node_info = causal_nodes_info[node_name] = {}
        causal_graph.add_node(node_name)
        parsed_args = {}

        root_node_match = re.match(_NOISE_MODEL_PATTERN, expression)
        if root_node_match:
            # Case 1: the whole expression is a single noise model call.
            causal_mechanism_name, args = root_node_match.groups()
            parsed_args = _parse_args(args)
            node_info["causal_mechanism"] = _identify_noise_model(causal_mechanism_name, parsed_args)
        elif re.fullmatch(_UNKNOWN_MODEL_PATTERN, expression):
            # Case 3: only edges are known. fullmatch is required here: a prefix match would also
            # accept regular equations that merely start with a node name ("X + norm(...)").
            parent_nodes = _get_sorted_parent_nodes(expression.split(","))
            _add_parent_nodes_to_graph(causal_graph, parent_nodes, node_name)
            node_info["unknown"] = True
        else:
            # Case 2: "<function-expression> + <noise model>"; the noise model is the last summand.
            if "+" not in expression:
                raise ValueError(
                    f"Unable to parse the expression '{expression}' of node '{node_name}': "
                    "expected '<function-expression> + <noise_model_name>(<optional_arguments>)'"
                )
            custom_func, noise_eq = expression.rsplit("+", 1)
            # Every identifier in the function expression that is not an allowed callable is a parent.
            parent_node_candidates = node_name_regex.findall(custom_func)
            parent_nodes = _get_sorted_parent_nodes(parent_node_candidates)
            _add_parent_nodes_to_graph(causal_graph, parent_nodes, node_name)
            noise_model_name, parsed_args = _extract_noise_model_components(noise_eq)
            noise_model = _identify_noise_model(noise_model_name, parsed_args)
            node_info["causal_mechanism"] = AdditiveNoiseModel(
                CustomEquationModel(custom_func, parent_nodes), noise_model
            )
        # A noise model given explicit arguments is treated as fully specified.
        node_info["fully_defined"] = bool(parsed_args)

    _add_undefined_nodes_info(causal_nodes_info, list(causal_graph.nodes))
    causal_model = StructuralCausalModel(causal_graph)
    for node in causal_graph.nodes:
        node_info = causal_nodes_info[node]
        if "unknown" not in node_info:
            causal_model.set_causal_mechanism(node, node_info["causal_mechanism"])
        if node_info["fully_defined"]:
            # Record the parents for fully specified nodes (presumably so that the mechanism is
            # treated as already fitted - confirm against dowhy's fitting logic).
            causal_model.graph.nodes[node][PARENTS_DURING_FIT] = get_ordered_predecessors(causal_model.graph, node)
    return causal_model
| 118 | + |
| 119 | + |
| 120 | +def _parse_args(args: str) -> dict: |
| 121 | + str_args_list = args.split(",") |
| 122 | + kwargs = {} |
| 123 | + for str_arg in str_args_list: |
| 124 | + if str_arg: |
| 125 | + arg_value_pairs = str_arg.split("=") |
| 126 | + kwargs[arg_value_pairs[0].strip()] = ast.literal_eval(arg_value_pairs[1].strip()) |
| 127 | + return kwargs |
| 128 | + |
| 129 | + |
def _add_parent_nodes_to_graph(causal_graph: nx.DiGraph, parent_nodes: list, node_name: str) -> None:
    """
    Add a directed edge from every parent node to the given node.

    :param causal_graph: Graph that is modified in place.
    :param parent_nodes: Names of the parent nodes.
    :param node_name: Name of the child node every edge points to.
    """
    causal_graph.add_edges_from((parent_node, node_name) for parent_node in parent_nodes)
| 133 | + |
| 134 | + |
def _identify_noise_model(causal_mechanism_name: str, parsed_args: dict) -> StochasticModel:
    """
    Build the stochastic model that a noise model name refers to.

    The name is first looked up among the predefined model types; otherwise it is resolved as an
    attribute of ``scipy.stats`` and wrapped in the "parametric" model type.

    :param causal_mechanism_name: Noise model name as written in the equation.
    :param parsed_args: Keyword arguments forwarded to the model constructor.
    :return: The instantiated stochastic model.
    :raises ValueError: If the name is neither a predefined model type nor found in ``scipy.stats``.
    """
    predefined_model = _STOCHASTIC_MODEL_TYPES.get(causal_mechanism_name)
    if predefined_model is not None:
        return predefined_model(**parsed_args)

    distribution = getattr(scipy.stats, causal_mechanism_name, None)
    if not distribution:
        raise ValueError(f"Unable to recognise the noise model: {causal_mechanism_name}")
    return _STOCHASTIC_MODEL_TYPES["parametric"](scipy_distribution=distribution, **parsed_args)
| 144 | + |
| 145 | + |
def _extract_noise_model_components(noise_eq: str) -> Tuple[str, dict]:
    """
    Split a noise term such as ``norm(loc=0, scale=1)`` into its name and parsed arguments.

    :param noise_eq: The noise part of an equation.
    :return: Tuple of the noise model name and its keyword arguments.
    :raises Exception: If the term does not have the form ``<name>(<arguments>)``.
    """
    matched = re.match(_NOISE_MODEL_PATTERN, noise_eq)
    if matched is None:
        raise Exception("Unable to recognise the format or function specified")
    model_name, raw_args = matched.groups()
    return model_name, _parse_args(raw_args)
| 155 | + |
| 156 | + |
| 157 | +def _extract_equation_components(equation: str) -> Tuple[str, str]: |
| 158 | + if "->" in equation: |
| 159 | + node_name, expression = equation.split("->", 1) |
| 160 | + else: |
| 161 | + node_name, expression = equation.split("=", 1) |
| 162 | + node_name = node_name.strip() |
| 163 | + expression = expression.strip() |
| 164 | + return node_name, expression |
| 165 | + |
| 166 | + |
def _get_sorted_parent_nodes(parent_node_candidates: list) -> list:
    """
    Reduce candidate names to the sorted list of distinct parent node names.

    Candidates are stripped of whitespace; empty names and names of allowed callables
    (functions usable inside equations) are dropped.

    :param parent_node_candidates: Names found in an expression; may contain duplicates.
    :return: Alphabetically sorted parent node names without duplicates.
    """
    # Duplicates must go: a node used twice in an expression ("X * X") is still a single parent,
    # and a repeated entry would shift the name-to-column mapping of the prediction model.
    unique_names = {candidate.strip() for candidate in parent_node_candidates}
    return sorted(name for name in unique_names if name and name not in _allowed_callables)
| 175 | + |
| 176 | + |
def _add_undefined_nodes_info(causal_nodes_info: dict, present_nodes: list) -> None:
    """
    Register every graph node that has no equation of its own as a root node.

    Each such node gets an empirical distribution as causal mechanism, is marked as not fully
    defined, and a warning is logged for it.

    :param causal_nodes_info: Per-node information collected while parsing; modified in place.
    :param present_nodes: All nodes currently present in the causal graph.
    """
    for present_node in present_nodes:
        if present_node in causal_nodes_info:
            continue
        logger.warning(f"{present_node} is undefined and will be considered as root node by default.")
        causal_nodes_info[present_node] = {
            "causal_mechanism": EmpiricalDistribution(),
            "fully_defined": False,
        }
| 184 | + |
| 185 | + |
| 186 | +def _check_node_redundancy(causal_nodes_info: dict, node_name: str) -> None: |
| 187 | + if node_name in causal_nodes_info: |
| 188 | + raise Exception(f"The node {node_name} is specified twice which is not allowed.") |
| 189 | + |
| 190 | + |
| 191 | +def _sanitize_input_expression(expression: str, banned_characters: list) -> None: |
| 192 | + for char in banned_characters: |
| 193 | + if char in expression: |
| 194 | + raise ValueError(f"'{char}' in the expression '{expression}' is not allowed because of security reasons") |
| 195 | + if re.search(r"[^0-9\+\-\*\/]+\.[^0-9\+\-\*\/]+", expression): |
| 196 | + raise ValueError(f"'.' can only be used incase of specifying decimals because of security reasons") |
| 197 | + |
| 198 | + |
class CustomEquationModel(PredictionModel):
    """
    Prediction model defined by a user-provided expression over the parent nodes.

    The model is fully specified by the expression, so there is nothing to fit.
    """

    def __init__(self, custom_func: str, parent_nodes: list):
        """
        :param custom_func: Expression string evaluated to produce the prediction.
        :param parent_nodes: Names of the parent nodes; the i-th name refers to the i-th column
                             of the input passed to :meth:`predict`.
        """
        self.custom_func = custom_func
        self.parent_nodes = parent_nodes

    def fit(self, X, Y) -> None:
        """Do nothing; the ground truth equation is already known."""
        return None

    def predict(self, X) -> np.ndarray:
        """
        Evaluate the expression with every parent name bound to its column of ``X``.

        :param X: Two-dimensional input with one column per parent node.
        :return: The evaluated expression, reshaped by ``shape_into_2d``.
        """
        columns_by_node = {node: X[:, column] for column, node in enumerate(self.parent_nodes)}
        # NOTE(review): eval() runs the equation text. Only the names in _allowed_callables and
        # the parent columns are in scope, so its safety rests on the sanitization applied when
        # the equations are parsed.
        prediction = eval(self.custom_func, _allowed_callables, columns_by_node)
        return shape_into_2d(prediction)

    def clone(self):
        """Return a new model with the same expression and parent nodes."""
        return CustomEquationModel(self.custom_func, self.parent_nodes)
0 commit comments