Skip to content

Commit

Permalink
corrected some logic and refactored, modularized
Browse files Browse the repository at this point in the history
  • Loading branch information
bhatt-priyadutt committed Nov 7, 2023
1 parent b94a3ec commit 9b1d042
Showing 1 changed file with 30 additions and 28 deletions.
58 changes: 30 additions & 28 deletions dowhy/gcm/causal_graph_from_eq.py → dowhy/gcm/equation_parser.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
import re
import ast
import networkx as nx
import numexpr as ne
import numpy as np
import scipy.stats
from dowhy.gcm.causal_mechanisms import StochasticModel
from dowhy.gcm import EmpiricalDistribution, ScipyDistribution, StructuralCausalModel, AdditiveNoiseModel
from dowhy.gcm.ml.prediction_model import PredictionModel

_TYPES_OF_STOCHASTIC_MODELS = {"empirical": EmpiricalDistribution, "bayesiangaussianmixture": EmpiricalDistribution,
STOCHASTIC_MODEL_TYPES = {"empirical": EmpiricalDistribution, "bayesiangaussianmixture": EmpiricalDistribution,
"parametric": ScipyDistribution}
stocastic_function_names = f'({"|".join(list(_TYPES_OF_STOCHASTIC_MODELS.keys()))})'
parent_node_pattern_for_eq = r'\b[A-Za-z_][A-Za-z_0-9 ]*\b'
# root_node_pattern = rf'^\s*(.+)\s*=\s*({stocastic_function_names})\(([^)]*)\)\s*$'
noise_model_pattern = rf'^\s*([\w]+)\(([^)]*)\)\s*$'
NOISE_MODEL_PATTERN = rf'^\s*([\w]+)\(([^)]*)\)\s*$'


def create_causal_model_from_eq(node_equations: str):
def create_causal_model_from_equations(node_equations: str):
graph_node_pairs = []
causal_nodes_info = {}
for equation in node_equations.split('\n'):
equation = equation.strip()
if equation:
node_name, expression = get_equation_components(equation)
node_name, expression = extract_equation_components(equation)
if not (node_name in causal_nodes_info):
causal_nodes_info[node_name] = {}
print("Variable Name:", node_name)
root_node_match = re.match(noise_model_pattern, expression)
root_node_match = re.match(NOISE_MODEL_PATTERN, expression)
if root_node_match:
causal_mechanism_name = root_node_match.group(1)
print(causal_mechanism_name)
Expand All @@ -35,16 +33,17 @@ def create_causal_model_from_eq(node_equations: str):
parsed_args)
else:
custom_func, noise_eq = expression.rsplit('+', 1)
graph_node_pairs += get_node_pairs(custom_func, node_name)
noise_model_name, parsed_args = get_noise_model_components(noise_eq)
parent_nodes = extract_parent_nodes(custom_func)
graph_node_pairs += [(parent_node, node_name) for parent_node in parent_nodes]
noise_model_name, parsed_args = extract_noise_model_components(noise_eq)
noise_model = identify_noise_model(noise_model_name, parsed_args)
causal_nodes_info[node_name]['causal_mechanism'] = AdditiveNoiseModel(MyCustomModel(custom_func),
causal_nodes_info[node_name]['causal_mechanism'] = AdditiveNoiseModel(MyCustomModel(custom_func, parent_nodes),
noise_model)

causal_graph = nx.DiGraph(graph_node_pairs)
causal_model = StructuralCausalModel(causal_graph)
for node in causal_graph.nodes:
causal_model.set_causal_mechanism(node, causal_nodes_info[node_name]['causal_mechanism'])
causal_model.set_causal_mechanism(node, causal_nodes_info[node]['causal_mechanism'])
return causal_model


Expand All @@ -60,52 +59,55 @@ def parse_args(args: str):


def identify_noise_model(causal_mechanism_name: str, parsed_args: dict) -> StochasticModel:
for model_type in _TYPES_OF_STOCHASTIC_MODELS:
for model_type in STOCHASTIC_MODEL_TYPES:
if model_type == causal_mechanism_name:
return _TYPES_OF_STOCHASTIC_MODELS[model_type](**parsed_args)
return _TYPES_OF_STOCHASTIC_MODELS['parametric'](
return STOCHASTIC_MODEL_TYPES[model_type](**parsed_args)
return STOCHASTIC_MODEL_TYPES['parametric'](
scipy_distribution=getattr(scipy.stats, causal_mechanism_name, None), **parsed_args)


class MyCustomModel(PredictionModel):
def __init__(self, custom_func: str):
def __init__(self, custom_func: str, parent_nodes: list):
self.custom_func = custom_func
self.parent_nodes = parent_nodes

def fit(self, X, Y):
# Nothing to fit here, since we know the ground truth.
pass

def predict(self, X):
return ne.evaluate(custom_func, sanitize=True)
local_dict = {self.parent_nodes[i]: X[:, i] for i in range(len(self.parent_nodes))}
return ne.evaluate(self.custom_func, local_dict=local_dict,sanitize=True)

def clone(self):
return MyCustomModel(self.custom_func)


def get_noise_model_components(noise_eq):
noise_model_match = re.match(noise_model_pattern, noise_eq)
def extract_noise_model_components(noise_eq):
noise_model_match = re.match(NOISE_MODEL_PATTERN, noise_eq)
if noise_model_match:
noise_model_name = noise_model_match.group(1)
args = noise_model_match.group(2)
parsed_args = parse_args(args)
return noise_model_name, parsed_args
else:
raise InputError("The format of the equation entered should follow : F(X) + N")
raise ValueError("Unable to recognise the format or function specified")


def get_equation_components(equation):
def extract_equation_components(equation):
node_name, expression = equation.split("=", 1)
node_name = node_name.strip()
expression = expression.strip()
return node_name, expression


def get_node_pairs(func_equation, child_node):
node_pairs = []
def extract_parent_nodes(func_equation):
parent_nodes = []
available_funcs = set(dir(__builtins__) + dir(np) + dir(scipy.stats))
# Find all node names in the expression string
parent_nodes = re.findall(parent_node_pattern_for_eq, func_equation)
for parent_node in parent_nodes:
if parent_node not in available_funcs:
node_pairs.append((parent_node, child_node))
return node_pairs
matched_node_names = re.findall(r'\b[A-Za-z_][A-Za-z_0-9 ]*\b', func_equation)
for matched_node in matched_node_names:
if matched_node not in available_funcs:
parent_nodes.append(matched_node)
parent_nodes.sort()
return parent_nodes

0 comments on commit 9b1d042

Please sign in to comment.