diff --git a/dowhy/gcm/causal_graph_from_eq.py b/dowhy/gcm/equation_parser.py similarity index 63% rename from dowhy/gcm/causal_graph_from_eq.py rename to dowhy/gcm/equation_parser.py index 5a236aa407..8b2603cee0 100644 --- a/dowhy/gcm/causal_graph_from_eq.py +++ b/dowhy/gcm/equation_parser.py @@ -1,31 +1,29 @@ import re import ast import networkx as nx +import numexpr as ne import numpy as np import scipy.stats from dowhy.gcm.causal_mechanisms import StochasticModel from dowhy.gcm import EmpiricalDistribution, ScipyDistribution, StructuralCausalModel, AdditiveNoiseModel from dowhy.gcm.ml.prediction_model import PredictionModel -_TYPES_OF_STOCHASTIC_MODELS = {"empirical": EmpiricalDistribution, "bayesiangaussianmixture": EmpiricalDistribution, +STOCHASTIC_MODEL_TYPES = {"empirical": EmpiricalDistribution, "bayesiangaussianmixture": EmpiricalDistribution, "parametric": ScipyDistribution} -stocastic_function_names = f'({"|".join(list(_TYPES_OF_STOCHASTIC_MODELS.keys()))})' -parent_node_pattern_for_eq = r'\b[A-Za-z_][A-Za-z_0-9 ]*\b' -# root_node_pattern = rf'^\s*(.+)\s*=\s*({stocastic_function_names})\(([^)]*)\)\s*$' -noise_model_pattern = rf'^\s*([\w]+)\(([^)]*)\)\s*$' +NOISE_MODEL_PATTERN = rf'^\s*([\w]+)\(([^)]*)\)\s*$' -def create_causal_model_from_eq(node_equations: str): +def create_causal_model_from_equations(node_equations: str): graph_node_pairs = [] causal_nodes_info = {} for equation in node_equations.split('\n'): equation = equation.strip() if equation: - node_name, expression = get_equation_components(equation) + node_name, expression = extract_equation_components(equation) if not (node_name in causal_nodes_info): causal_nodes_info[node_name] = {} print("Variable Name:", node_name) - root_node_match = re.match(noise_model_pattern, expression) + root_node_match = re.match(NOISE_MODEL_PATTERN, expression) if root_node_match: causal_mechanism_name = root_node_match.group(1) print(causal_mechanism_name) @@ -35,16 +33,17 @@ def create_causal_model_from_eq(node_equations: str): parsed_args) else: custom_func, noise_eq = expression.rsplit('+', 1) - graph_node_pairs += get_node_pairs(custom_func, node_name) - noise_model_name, parsed_args = get_noise_model_components(noise_eq) + parent_nodes = extract_parent_nodes(custom_func) + graph_node_pairs += [(parent_node, node_name) for parent_node in parent_nodes] + noise_model_name, parsed_args = extract_noise_model_components(noise_eq) noise_model = identify_noise_model(noise_model_name, parsed_args) - causal_nodes_info[node_name]['causal_mechanism'] = AdditiveNoiseModel(MyCustomModel(custom_func), + causal_nodes_info[node_name]['causal_mechanism'] = AdditiveNoiseModel(MyCustomModel(custom_func, parent_nodes), noise_model) causal_graph = nx.DiGraph(graph_node_pairs) causal_model = StructuralCausalModel(causal_graph) for node in causal_graph.nodes: - causal_model.set_causal_mechanism(node, causal_nodes_info[node_name]['causal_mechanism']) + causal_model.set_causal_mechanism(node, causal_nodes_info[node]['causal_mechanism']) return causal_model @@ -60,52 +59,55 @@ def parse_args(args: str): def identify_noise_model(causal_mechanism_name: str, parsed_args: dict) -> StochasticModel: - for model_type in _TYPES_OF_STOCHASTIC_MODELS: + for model_type in STOCHASTIC_MODEL_TYPES: if model_type == causal_mechanism_name: - return _TYPES_OF_STOCHASTIC_MODELS[model_type](**parsed_args) - return _TYPES_OF_STOCHASTIC_MODELS['parametric']( + return STOCHASTIC_MODEL_TYPES[model_type](**parsed_args) + return STOCHASTIC_MODEL_TYPES['parametric']( scipy_distribution=getattr(scipy.stats, causal_mechanism_name, None), **parsed_args) class MyCustomModel(PredictionModel): - def __init__(self, custom_func: str): + def __init__(self, custom_func: str, parent_nodes: list): self.custom_func = custom_func + self.parent_nodes = parent_nodes def fit(self, X, Y): # Nothing to fit here, since we know the ground truth. pass def predict(self, X): - return ne.evaluate(custom_func, sanitize=True) + local_dict = {self.parent_nodes[i]: X[:, i] for i in range(len(self.parent_nodes))} + return ne.evaluate(self.custom_func, local_dict=local_dict,sanitize=True) def clone(self): return MyCustomModel(self.custom_func) -def get_noise_model_components(noise_eq): - noise_model_match = re.match(noise_model_pattern, noise_eq) +def extract_noise_model_components(noise_eq): + noise_model_match = re.match(NOISE_MODEL_PATTERN, noise_eq) if noise_model_match: noise_model_name = noise_model_match.group(1) args = noise_model_match.group(2) parsed_args = parse_args(args) return noise_model_name, parsed_args else: - raise InputError("The format of the equation entered should follow : F(X) + N") + raise ValueError("Unable to recognise the format or function specified") -def get_equation_components(equation): +def extract_equation_components(equation): node_name, expression = equation.split("=", 1) node_name = node_name.strip() expression = expression.strip() return node_name, expression -def get_node_pairs(func_equation, child_node): - node_pairs = [] +def extract_parent_nodes(func_equation): + parent_nodes = [] available_funcs = set(dir(__builtins__) + dir(np) + dir(scipy.stats)) # Find all node names in the expression string - parent_nodes = re.findall(parent_node_pattern_for_eq, func_equation) - for parent_node in parent_nodes: - if parent_node not in available_funcs: - node_pairs.append((parent_node, child_node)) - return node_pairs \ No newline at end of file + matched_node_names = re.findall(r'\b[A-Za-z_][A-Za-z_0-9 ]*\b', func_equation) + for matched_node in matched_node_names: + if matched_node not in available_funcs: + parent_nodes.append(matched_node) + parent_nodes.sort() + return parent_nodes \ No newline at end of file