diff --git a/docs/source/example_notebooks/dowhy_causal_discovery_example.ipynb b/docs/source/example_notebooks/dowhy_causal_discovery_example.ipynb index f1f3096095..22e5b8a53a 100644 --- a/docs/source/example_notebooks/dowhy_causal_discovery_example.ipynb +++ b/docs/source/example_notebooks/dowhy_causal_discovery_example.ipynb @@ -6,12 +6,12 @@ "source": [ "# Causal Discovery example\n", "\n", - "The goal of this notebook is to show how causal discovery methods can work with DoWhy. We use discovery methods from [Causal Discovery Tool (CDT)](https://github.com/FenTechSolutions/CausalDiscoveryToolbox) repo. As we will see, causal discovery methods are not fool-proof and there is no guarantee that they will recover the correct causal graph. Even for the simple examples below, there is a large variance in results. These methods, however, may be combined usefully with domain knowledge to construct the final causal graph." + "The goal of this notebook is to show how causal discovery methods can work with DoWhy. We use discovery methods from the [causal-learn](https://github.com/py-why/causal-learn) repo. As we will see, causal discovery methods require appropriate assumptions for their correctness guarantees, and thus there will be variance across the results returned by different methods in practice. These methods, however, may be combined usefully with domain knowledge to construct the final causal graph." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -67,7 +67,7 @@ "source": [ "# Experiments on the Auto-MPG dataset\n", "\n", - "In this section, we will use a dataset on the technical specification of cars. The dataset is downloaded from UCI Machine Learning Repository. The dataset contains 9 attributes and 398 instances. 
We do not know the true causal graph for the dataset and will use CDT to discover it. The causal graph obtained will then be used to estimate the causal effect.\n" + "In this section, we will use a dataset on the technical specification of cars. The dataset is downloaded from UCI Machine Learning Repository. The dataset contains 9 attributes and 398 instances. We do not know the true causal graph for the dataset and will use causal-learn to discover it. The causal graph obtained will then be used to estimate the causal effect.\n" ] }, { @@ -79,9 +79,109 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(392, 6)\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " mpg cylinders displacement horsepower weight acceleration\n", + "0 18.0 8.0 307.0 130.0 3504.0 12.0\n", + "1 15.0 8.0 350.0 165.0 3693.0 11.5\n", + "2 18.0 8.0 318.0 150.0 3436.0 11.0\n", + "3 16.0 8.0 304.0 150.0 3433.0 12.0\n", + "4 17.0 8.0 302.0 140.0 3449.0 10.5" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "data_mpg = pd.read_csv('http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data-original',\n", " delim_whitespace=True, header=None,\n", @@ -98,354 +198,992 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Causal Discovery with Causal Discovery Tool (CDT)\n", + "# Causal Discovery with causal-learn\n", "\n", - "We use the CDT library to perform causal discovery on the Auto-MPG dataset. We use three methods for causal discovery here -LiNGAM, PC and GES. These methods are widely used and do not take much time to run. Hence, these are ideal for an introduction to the topic. Other neural network based methods are also available in CDT and the users are encouraged to try them out by themselves. \n", + "We use the causal-learn library to perform causal discovery on the Auto-MPG dataset. We use three methods for causal discovery here: PC, GES and LiNGAM. These methods are widely used and do not take much time to run. Hence, these are ideal for an introduction to the topic. 
Causal-learn provides a comprehensive list of well-tested causal-discovery methods, and readers are welcome to explore them.\n", "\n", "The documentation for the methods used are as follows:\n", - "- LiNGAM [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/LiNGAM.html)\n", - "- PC [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/PC.html)\n", - "- GES [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/GES.html)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from cdt.causality.graph import LiNGAM, PC, GES\n", + "- PC [[link]](https://causal-learn.readthedocs.io/en/latest/search_methods_index/Constraint-based%20causal%20discovery%20methods/PC.html)\n", + "- GES [[link]](https://causal-learn.readthedocs.io/en/latest/search_methods_index/Score-based%20causal%20discovery%20methods/GES.html)\n", + "- LiNGAM [[link]](https://causal-learn.readthedocs.io/en/latest/search_methods_index/Causal%20discovery%20methods%20based%20on%20constrained%20functional%20causal%20models/lingam.html#ica-based-lingam)\n", "\n", - "graphs = {}\n", - "labels = [f'{col}' for i, col in enumerate(data_mpg.columns)]\n", - "functions = {\n", - " 'LiNGAM' : LiNGAM,\n", - " 'PC' : PC,\n", - " 'GES' : GES,\n", - "}\n", - "\n", - "for method, lib in functions.items():\n", - " obj = lib()\n", - " output = obj.predict(data_mpg)\n", - " adj_matrix = nx.to_numpy_array(output)\n", - " adj_matrix = np.asarray(adj_matrix)\n", - " graph_dot = make_graph(adj_matrix, labels)\n", - " graphs[method] = graph_dot\n", - "\n", - "# Visualize graphs\n", - "for method, graph in graphs.items():\n", - " print(\"Method : %s\"%(method))\n", - " display(graph)" + "More methods can be found in the causal-learn documentation [[link]](https://causal-learn.readthedocs.io/en/latest/)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "As you can see, no two methods agree on the graphs. PC and GES effectively produce an undirected graph whereas LiNGAM produces a directed graph. We use only the LiNGAM method in the next section." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Estimate causal effects using Linear Regression\n", - "\n", - "Now let us see whether these differences in the graphs also lead to significant differences in the causal estimate of effect of *mpg* on *weight*." + "We first try the PC algorithm with default parameters." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1ed197e9f5ec42c8bf7fc51c5ece4485", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/6 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "for method, graph in graphs.items():\n", - " if method != \"LiNGAM\":\n", - " continue\n", - " print('\\n*****************************************************************************\\n')\n", - " print(\"Causal Discovery Method : %s\"%(method))\n", - " \n", - " # Obtain valid dot format\n", - " graph_dot = str_to_dot(graph.source)\n", - "\n", - " # Define Causal Model\n", - " model=CausalModel(\n", - " data = data_mpg,\n", - " treatment='mpg',\n", - " outcome='weight',\n", - " graph=graph_dot)\n", - "\n", - " # Identification\n", - " identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)\n", - " print(identified_estimand)\n", - " \n", - " # Estimation\n", - " estimate = model.estimate_effect(identified_estimand,\n", - " method_name=\"backdoor.linear_regression\",\n", - " control_value=0,\n", - " treatment_value=1,\n", - " confidence_intervals=True,\n", - " test_significance=True)\n", - " print(\"Causal Estimate is \" + str(estimate.value))" + 
"from causallearn.search.ConstraintBased.PC import pc\n", + "\n", + "labels = [f'{col}' for i, col in enumerate(data_mpg.columns)]\n", + "data = data_mpg.to_numpy()\n", + "\n", + "cg = pc(data)\n", + "\n", + "# Visualization using pydot\n", + "from causallearn.utils.GraphUtils import GraphUtils\n", + "import matplotlib.image as mpimg\n", + "import matplotlib.pyplot as plt\n", + "import io\n", + "\n", + "pyd = GraphUtils.to_pydot(cg.G, labels=labels)\n", + "tmp_png = pyd.create_png(f=\"png\")\n", + "fp = io.BytesIO(tmp_png)\n", + "img = mpimg.imread(fp, format='png')\n", + "plt.axis('off')\n", + "plt.imshow(img)\n", + "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "As mentioned earlier, due to the absence of directed edges, no backdoor, instrmental or frontdoor variables can be found out for PC and GES. Thus, causal effect estimation is not possible for these methods. However, LiNGAM does discover a DAG and hence, its possible to output a causal estimate for LiNGAM. The estimate is still pretty far from the original estimate of -70.466 (which can be calculated from the graph)." + "Then we have a causal graph discovered by PC. Let us also try GES to see its result." 
] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 5, "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# Experiments on the Sachs dataset\n", + "from causallearn.search.ScoreBased.GES import ges\n", "\n", - "The dataset consists of the simultaneous measurements of 11 phosphorylated proteins and phospholipids derived from thousands of individual primary immune system cells, subjected to both general and specific molecular interventions (Sachs et al., 2005).\n", + "# default parameters\n", + "Record = ges(data)\n", "\n", - "The specifications of the dataset are as follows - \n", - "- Number of nodes: 11\n", - "- Number of arcs: 17\n", - "- Number of parameters: 178\n", - "- Average Markov blanket size: 3.09\n", - "- Average degree: 3.09\n", - "- Maximum in-degree: 3\n", - "- Number of instances: 7466\n", + "# Visualization using pydot\n", + "from causallearn.utils.GraphUtils import GraphUtils\n", + "import matplotlib.image as mpimg\n", + "import matplotlib.pyplot as plt\n", + "import io\n", "\n", - "The original causal graph is known for the Sachs dataset and we compare the original graph with the ones discovered using CDT in this section." + "pyd = GraphUtils.to_pydot(Record['G'], labels=labels)\n", + "tmp_png = pyd.create_png(f=\"png\")\n", + "fp = io.BytesIO(tmp_png)\n", + "img = mpimg.imread(fp, format='png')\n", + "plt.axis('off')\n", + "plt.imshow(img)\n", + "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 1. Load the data" + "These two results are different, which is not unusual when applying causal discovery to real-world datasets, since the required assumptions on the data-generating process are hard to verify.\n", + "\n", + "In addition, the graphs returned by PC and GES are CPDAGs instead of DAGs, so it is possible to have undirected edges (e.g., the result returned by GES). Thus, causal effect estimation is difficult for those methods, since there may be no backdoor, instrumental or frontdoor variables. 
In order to get a DAG, we decide to try LiNGAM on our dataset." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "\n", + "mpg\n", + "\n", + "mpg\n", + "\n", + "\n", + "\n", + "displacement\n", + "\n", + "displacement\n", + "\n", + "\n", + "\n", + "mpg->displacement\n", + "\n", + "\n", + "-0.64\n", + "\n", + "\n", + "\n", + "horsepower\n", + "\n", + "horsepower\n", + "\n", + "\n", + "\n", + "mpg->horsepower\n", + "\n", + "\n", + "-1.40\n", + "\n", + "\n", + "\n", + "weight\n", + "\n", + "weight\n", + "\n", + "\n", + "\n", + "mpg->weight\n", + "\n", + "\n", + "-17.70\n", + "\n", + "\n", + "\n", + "cylinders\n", + "\n", + "cylinders\n", + "\n", + "\n", + "\n", + "cylinders->mpg\n", + "\n", + "\n", + "-3.55\n", + "\n", + "\n", + "\n", + "cylinders->displacement\n", + "\n", + "\n", + "40.12\n", + "\n", + "\n", + "\n", + "cylinders->horsepower\n", + "\n", + "\n", + "10.14\n", + "\n", + "\n", + "\n", + "acceleration\n", + "\n", + "acceleration\n", + "\n", + "\n", + "\n", + "cylinders->acceleration\n", + "\n", + "\n", + "-0.82\n", + "\n", + "\n", + "\n", + "displacement->weight\n", + "\n", + "\n", + "5.24\n", + "\n", + "\n", + "\n", + "horsepower->displacement\n", + "\n", + "\n", + "0.83\n", + "\n", + "\n", + "\n", + "horsepower->weight\n", + "\n", + "\n", + "6.49\n", + "\n", + "\n", + "\n", + "acceleration->horsepower\n", + "\n", + "\n", + "-4.77\n", + "\n", + "\n", + "\n", + "acceleration->weight\n", + "\n", + "\n", + "61.92\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "from cdt.data import load_dataset\n", - "data_sachs, graph_sachs = load_dataset(\"sachs\")\n", + "from causallearn.search.FCMBased import lingam\n", + "model = lingam.ICALiNGAM()\n", + 
"model.fit(data)\n", "\n", - "data_sachs.dropna(inplace=True)\n", - "print(data_sachs.shape)\n", - "data_sachs.head()" + "from causallearn.search.FCMBased.lingam.utils import make_dot\n", + "make_dot(model.adjacency_matrix_, labels=labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Ground truth of the causal graph" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "labels = [f'{col}' for i, col in enumerate(data_sachs.columns)]\n", - "adj_matrix = nx.to_numpy_array(graph_sachs)\n", - "adj_matrix = np.asarray(adj_matrix)\n", - "graph_dot = make_graph(adj_matrix, labels)\n", - "display(graph_dot)" + "Now we have a DAG and are ready to estimate the causal effects based on that." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Causal Discovery with Causal Discovery Tool (CDT)\n", - "\n", - "We use the CDT library to perform causal discovery on the Auto-MPG dataset. We use three methods for causal discovery here -LiNGAM, PC and GES. These methods are widely used and do not take much time to run. Hence, these are ideal for an introduction to the topic. Other neural network based methods are also available in CDT and the users the encourages to try them out by themselves. \n", + "## Estimate causal effects using Linear Regression\n", "\n", - "The documentation for the methods used in as follows:\n", - "- LiNGAM [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/LiNGAM.html)\n", - "- PC [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/PC.html)\n", - "- GES [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/GES.html)" + "Now let us see the estimate of causal effect of *mpg* on *weight*." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Estimand type: EstimandType.NONPARAMETRIC_ATE\n", + "\n", + "### Estimand : 1\n", + "Estimand name: backdoor\n", + "Estimand expression:\n", + " d \n", + "──────(E[weight|cylinders])\n", + "d[mpg] \n", + "Estimand assumption 1, Unconfoundedness: If U→{mpg} and U→weight then P(weight|mpg,cylinders,U) = P(weight|mpg,cylinders)\n", + "\n", + "### Estimand : 2\n", + "Estimand name: iv\n", + "No such variable(s) found!\n", + "\n", + "### Estimand : 3\n", + "Estimand name: frontdoor\n", + "No such variable(s) found!\n", + "\n", + "Causal Estimate is -38.940973656209735\n" + ] + } + ], "source": [ - "from cdt.causality.graph import LiNGAM, PC, GES\n", + "# Obtain valid dot format\n", + "graph_dot = make_graph(model.adjacency_matrix_, labels=labels)\n", "\n", - "graphs = {}\n", - "graphs_nx = {}\n", - "labels = [f'{col}' for i, col in enumerate(data_sachs.columns)]\n", - "functions = {\n", - " 'LiNGAM' : LiNGAM,\n", - " 'PC' : PC,\n", - " 'GES' : GES,\n", - "}\n", - "\n", - "for method, lib in functions.items():\n", - " obj = lib()\n", - " output = obj.predict(data_sachs)\n", - " graphs_nx[method] = output\n", - " adj_matrix = nx.to_numpy_array(output)\n", - " adj_matrix = np.asarray(adj_matrix)\n", - " graph_dot = make_graph(adj_matrix, labels)\n", - " graphs[method] = graph_dot\n", - "\n", - "# Visualize graphs\n", - "for method, graph in graphs.items():\n", - " print(\"Method : %s\"%(method))\n", - " display(graph)" + "# Define Causal Model\n", + "model=CausalModel(\n", + " data = data_mpg,\n", + " treatment='mpg',\n", + " outcome='weight',\n", + " graph=str_to_dot(graph_dot.source))\n", + "\n", + "# Identification\n", + "identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)\n", + "print(identified_estimand)\n", + "\n", + "# Estimation\n", + "estimate = 
model.estimate_effect(identified_estimand,\n", + " method_name=\"backdoor.linear_regression\",\n", + " control_value=0,\n", + " treatment_value=1,\n", + " confidence_intervals=True,\n", + " test_significance=True)\n", + "print(\"Causal Estimate is \" + str(estimate.value))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "As you can see, no two methods agree on the graphs. Next we study the causal effects of these different graphs" + "# Experiments on the Sachs dataset\n", + "\n", + "The dataset consists of the simultaneous measurements of 11 phosphorylated proteins and phospholipids derived from thousands of individual primary immune system cells, subjected to both general and specific molecular interventions (Sachs et al., 2005).\n", + "\n", + "The specifications of the dataset are as follows - \n", + "- Number of nodes: 11\n", + "- Number of arcs: 17\n", + "- Number of parameters: 178\n", + "- Average Markov blanket size: 3.09\n", + "- Average degree: 3.09\n", + "- Maximum in-degree: 3\n", + "- Number of instances: 7466" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Estimate effects using Linear Regression\n", - "\n", - "Now let us see whether these differences in the graphs also lead to significant differences in the causal estimate of effect of *PIP2* on *PKC*." + "## 1. 
Load the data" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(7466, 11)\n", + "['raf', 'mek', 'plc', 'pip2', 'pip3', 'erk', 'akt', 'pka', 'pkc', 'p38', 'jnk']\n" + ] + } + ], "source": [ - "for method, graph in graphs.items():\n", - " if method != \"LiNGAM\":\n", - " continue\n", - " print('\\n*****************************************************************************\\n')\n", - " print(\"Causal Discovery Method : %s\"%(method))\n", - "\n", - " # Obtain valid dot format\n", - " graph_dot = str_to_dot(graph.source)\n", - "\n", - " # Define Causal Model\n", - " model=CausalModel(\n", - " data = data_sachs,\n", - " treatment='PIP2',\n", - " outcome='PKC',\n", - " graph=graph_dot)\n", - "\n", - " # Identification\n", - " identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)\n", - " print(identified_estimand)\n", - "\n", - " # Estimation\n", - " estimate = model.estimate_effect(identified_estimand,\n", - " method_name=\"backdoor.linear_regression\",\n", - " control_value=0,\n", - " treatment_value=1,\n", - " confidence_intervals=True,\n", - " test_significance=True)\n", - " print(\"Causal Estimate is \" + str(estimate.value))" + "from causallearn.utils.Dataset import load_dataset\n", + "\n", + "data_sachs, labels = load_dataset(\"sachs\")\n", + "\n", + "print(data_sachs.shape)\n", + "print(labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "From the causal estimates obtained, it can be seen that the three estimates differ in different aspects. The graph obtained using LiNGAM contains a backdoor path and instrumental variables. On the other hand, the graph obtained using PC contains a backdoor path and a frontdoor path. 
However, despite these differences, both obtain the same mean causal estimate.\n", + "# Causal Discovery with causal-learn\n", "\n", - "The graph obtained using GES contains only a backdoor path with different backdoor variables and obtains a different causal estimate than the first two cases. " + "We use the three causal discovery methods mentioned above (PC, GES, and LiNGAM) to find the causal graphs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Graph Validation\n", - "\n", - "We compare the graphs obtained with the true causal graph using the causal discovery methods using 2 graph distance metrics - Structural Hamming Distance (SHD) and Structural Intervention Distance (SID). SHD between two graphs is, in simple terms, the number of edge insertions, deletions or flips in order to transform one graph to another graph. SID, on the other hand, is based on a graphical criterion only and quantifies the closeness between two DAGs in terms of their corresponding causal inference statements." + "First, let us take a look at how PC works." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bc0f31d1492e4934994a6d4ba68f1ad3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/11 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "from cdt.metrics import SHD, SHD_CPDAG, SID, SID_CPDAG\n", - "from numpy.random import randint\n", - "\n", - "for method, graph in graphs_nx.items():\n", - " print(\"***********************************************************\")\n", - " print(\"Method: %s\"%(method))\n", - " tar, pred = graph_sachs, graph\n", - " print(\"SHD_CPDAG = %f\"%(SHD_CPDAG(tar, pred)))\n", - " print(\"SHD = %f\"%(SHD(tar, pred, double_for_anticausal=False)))\n", - " print(\"SID_CPDAG = [%f, %f]\"%(SID_CPDAG(tar, pred)))\n", - " print(\"SID = %f\"%(SID(tar, pred)))" + "graphs = {}\n", + "graphs_nx = {}\n", + "labels = [f'{col}' for i, col in enumerate(labels)]\n", + "data = data_sachs\n", + "\n", + "from causallearn.search.ConstraintBased.PC import pc\n", + "\n", + "cg = pc(data)\n", + "\n", + "# Visualization using pydot\n", + "from causallearn.utils.GraphUtils import GraphUtils\n", + "import matplotlib.image as mpimg\n", + "import matplotlib.pyplot as plt\n", + "import io\n", + "\n", + "pyd = GraphUtils.to_pydot(cg.G, labels=labels)\n", + "tmp_png = pyd.create_png(f=\"png\")\n", + "fp = io.BytesIO(tmp_png)\n", + "img = mpimg.imread(fp, format='png')\n", + "plt.axis('off')\n", + "plt.imshow(img)\n", + "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The graph similarity metrics show that the scores are the lowest for the LiNGAM method of graph extraction. Hence, of the three methods used, LiNGAM provides the graph that is most similar to the original graph." + "Then, let us try GES." 
] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 16, "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "## Graph Refutation\n", + "from causallearn.search.ScoreBased.GES import ges\n", + "\n", + "# default parameters\n", + "Record = ges(data)\n", "\n", - "Here, we use the same SHD and SID metric to find out how different the discovered graph are from each other." + "# Visualization using pydot\n", + "from causallearn.utils.GraphUtils import GraphUtils\n", + "import matplotlib.image as mpimg\n", + "import matplotlib.pyplot as plt\n", + "import io\n", + "\n", + "pyd = GraphUtils.to_pydot(Record['G'], labels=labels)\n", + "tmp_png = pyd.create_png(f=\"png\")\n", + "fp = io.BytesIO(tmp_png)\n", + "img = mpimg.imread(fp, format='png')\n", + "plt.axis('off')\n", + "plt.imshow(img)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And also LiNGAM." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "\n", + "raf\n", + "\n", + "raf\n", + "\n", + "\n", + "\n", + "mek\n", + "\n", + "mek\n", + "\n", + "\n", + "\n", + "raf->mek\n", + "\n", + "\n", + "1.48\n", + "\n", + "\n", + "\n", + "pka\n", + "\n", + "pka\n", + "\n", + "\n", + "\n", + "raf->pka\n", + "\n", + "\n", + "0.55\n", + "\n", + "\n", + "\n", + "pkc\n", + "\n", + "pkc\n", + "\n", + "\n", + "\n", + "raf->pkc\n", + "\n", + "\n", + "-0.13\n", + "\n", + "\n", + "\n", + "jnk\n", + "\n", + "jnk\n", + "\n", + "\n", + "\n", + "raf->jnk\n", + "\n", + "\n", + "-0.02\n", + "\n", + "\n", + "\n", + "mek->pka\n", + "\n", + "\n", + "-0.50\n", + "\n", + "\n", + "\n", + "mek->pkc\n", + "\n", + "\n", + "0.10\n", + "\n", + "\n", + "\n", + "p38\n", + "\n", + "p38\n", + "\n", + "\n", + "\n", + "mek->p38\n", + "\n", + "\n", + "0.03\n", + "\n", + "\n", + "\n", + "plc\n", + "\n", + "plc\n", + "\n", + "\n", + "\n", + "plc->raf\n", 
+ "\n", + "\n", + "0.14\n", + "\n", + "\n", + "\n", + "plc->mek\n", + "\n", + "\n", + "0.04\n", + "\n", + "\n", + "\n", + "pip2\n", + "\n", + "pip2\n", + "\n", + "\n", + "\n", + "plc->pip2\n", + "\n", + "\n", + "1.58\n", + "\n", + "\n", + "\n", + "akt\n", + "\n", + "akt\n", + "\n", + "\n", + "\n", + "plc->akt\n", + "\n", + "\n", + "0.28\n", + "\n", + "\n", + "\n", + "plc->pka\n", + "\n", + "\n", + "-0.49\n", + "\n", + "\n", + "\n", + "plc->pkc\n", + "\n", + "\n", + "0.05\n", + "\n", + "\n", + "\n", + "plc->p38\n", + "\n", + "\n", + "0.06\n", + "\n", + "\n", + "\n", + "plc->jnk\n", + "\n", + "\n", + "0.10\n", + "\n", + "\n", + "\n", + "pip2->pkc\n", + "\n", + "\n", + "0.03\n", + "\n", + "\n", + "\n", + "pip3\n", + "\n", + "pip3\n", + "\n", + "\n", + "\n", + "pip3->mek\n", + "\n", + "\n", + "-0.06\n", + "\n", + "\n", + "\n", + "pip3->plc\n", + "\n", + "\n", + "0.37\n", + "\n", + "\n", + "\n", + "pip3->pip2\n", + "\n", + "\n", + "0.80\n", + "\n", + "\n", + "\n", + "pip3->akt\n", + "\n", + "\n", + "-0.17\n", + "\n", + "\n", + "\n", + "pip3->pkc\n", + "\n", + "\n", + "-0.10\n", + "\n", + "\n", + "\n", + "pip3->jnk\n", + "\n", + "\n", + "-0.05\n", + "\n", + "\n", + "\n", + "erk\n", + "\n", + "erk\n", + "\n", + "\n", + "\n", + "erk->raf\n", + "\n", + "\n", + "-1.47\n", + "\n", + "\n", + "\n", + "erk->mek\n", + "\n", + "\n", + "-0.24\n", + "\n", + "\n", + "\n", + "erk->plc\n", + "\n", + "\n", + "0.59\n", + "\n", + "\n", + "\n", + "erk->akt\n", + "\n", + "\n", + "1.90\n", + "\n", + "\n", + "\n", + "erk->pka\n", + "\n", + "\n", + "4.81\n", + "\n", + "\n", + "\n", + "erk->pkc\n", + "\n", + "\n", + "-0.33\n", + "\n", + "\n", + "\n", + "erk->p38\n", + "\n", + "\n", + "-0.16\n", + "\n", + "\n", + "\n", + "erk->jnk\n", + "\n", + "\n", + "-0.29\n", + "\n", + "\n", + "\n", + "akt->raf\n", + "\n", + "\n", + "0.75\n", + "\n", + "\n", + "\n", + "akt->mek\n", + "\n", + "\n", + "0.15\n", + "\n", + "\n", + "\n", + "akt->pka\n", + "\n", + "\n", + "-0.58\n", + "\n", + "\n", + "\n", + 
"akt->pkc\n", + "\n", + "\n", + "0.25\n", + "\n", + "\n", + "\n", + "akt->p38\n", + "\n", + "\n", + "0.15\n", + "\n", + "\n", + "\n", + "akt->jnk\n", + "\n", + "\n", + "0.27\n", + "\n", + "\n", + "\n", + "pka->p38\n", + "\n", + "\n", + "-0.02\n", + "\n", + "\n", + "\n", + "pkc->pka\n", + "\n", + "\n", + "-0.59\n", + "\n", + "\n", + "\n", + "pkc->p38\n", + "\n", + "\n", + "4.95\n", + "\n", + "\n", + "\n", + "pkc->jnk\n", + "\n", + "\n", + "1.47\n", + "\n", + "\n", + "\n", + "p38->jnk\n", + "\n", + "\n", + "0.04\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "import itertools\n", - "from numpy.random import randint\n", - "from cdt.metrics import SHD, SHD_CPDAG, SID, SID_CPDAG\n", - "\n", - "# Find combinations of pair of methods to compare\n", - "combinations = list(itertools.combinations(graphs_nx, 2))\n", - "\n", - "for pair in combinations:\n", - " print(\"***********************************************************\")\n", - " graph1 = graphs_nx[pair[0]]\n", - " graph2 = graphs_nx[pair[1]]\n", - " print(\"Methods: %s and %s\"%(pair[0], pair[1]))\n", - " print(\"SHD_CPDAG = %f\"%(SHD_CPDAG(graph1, graph2)))\n", - " print(\"SHD = %f\"%(SHD(graph1, graph2, double_for_anticausal=False)))\n", - " print(\"SID_CPDAG = [%f, %f]\"%(SID_CPDAG(graph1, graph2)))\n", - " print(\"SID = %f\"%(SID(graph1, graph2)))" + "from causallearn.search.FCMBased import lingam\n", + "model = lingam.ICALiNGAM()\n", + "model.fit(data)\n", + "\n", + "from causallearn.search.FCMBased.lingam.utils import make_dot\n", + "make_dot(model.adjacency_matrix_, labels=labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The values for the metrics show how different the graphs are from each other. A higher distance value implies that the difference between the graphs is more." 
+ "## Estimate effects using Linear Regression\n", + "\n", + "Similarly, let us use the DAG returned by LiNGAM to estimate the causal effect of *PIP2* on *PKC*." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 29, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Estimand type: EstimandType.NONPARAMETRIC_ATE\n", + "\n", + "### Estimand : 1\n", + "Estimand name: backdoor\n", + "Estimand expression:\n", + " d \n", + "───────(E[pkc|plc,pip3])\n", + "d[pip₂] \n", + "Estimand assumption 1, Unconfoundedness: If U→{pip2} and U→pkc then P(pkc|pip2,plc,pip3,U) = P(pkc|pip2,plc,pip3)\n", + "\n", + "### Estimand : 2\n", + "Estimand name: iv\n", + "No such variable(s) found!\n", + "\n", + "### Estimand : 3\n", + "Estimand name: frontdoor\n", + "No such variable(s) found!\n", + "\n", + "Causal Estimate is 0.03397189228452291\n" + ] + } + ], + "source": [ + "# Obtain valid dot format\n", + "graph_dot = make_graph(model.adjacency_matrix_, labels=labels)\n", + "\n", + "data_df = pd.DataFrame(data=data, columns=labels)\n", + "\n", + "# Define Causal Model\n", + "model_est=CausalModel(\n", + " data = data_df,\n", + " treatment='pip2',\n", + " outcome='pkc',\n", + " graph=str_to_dot(graph_dot.source))\n", + "\n", + "# Identification\n", + "identified_estimand = model_est.identify_effect(proceed_when_unidentifiable=False)\n", + "print(identified_estimand)\n", + "\n", + "# Estimation\n", + "estimate = model_est.estimate_effect(identified_estimand,\n", + " method_name=\"backdoor.linear_regression\",\n", + " control_value=0,\n", + " treatment_value=1,\n", + " confidence_intervals=True,\n", + " test_significance=True)\n", + "print(\"Causal Estimate is \" + str(estimate.value))" + ] } ], "metadata": { @@ -464,7 +1202,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.17" }, "metadata": { "interpreter": {