Skip to content

Commit

Permalink
Add compatibility between GCM and CausalModel class
Browse files Browse the repository at this point in the history
Before, the causal graphs between the GCM and CausalModel part required some additional work to be converted.
Now, a GCM can be used to initiate a CausalModel or CausalGraph and vice versa.

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Nov 21, 2023
1 parent bd4f95f commit 952c6ea
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 3 deletions.
16 changes: 15 additions & 1 deletion dowhy/causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import networkx as nx

from dowhy.gcm.causal_models import ProbabilisticCausalModel
from dowhy.utils.api import parse_state
from dowhy.utils.graph_operations import daggity_to_dot
from dowhy.utils.plotting import plot
Expand Down Expand Up @@ -54,6 +55,10 @@ def __init__(
if graph is None:
self._graph = nx.DiGraph()
self._graph = self.build_graph(common_cause_names, instrument_names, effect_modifier_names, mediator_names)
elif isinstance(graph, nx.DiGraph):
self._graph = nx.DiGraph(graph)
elif isinstance(graph, ProbabilisticCausalModel):
self._graph = nx.DiGraph(graph.graph)
elif re.match(r".*\.dot", graph):
# load dot file
try:
Expand Down Expand Up @@ -90,9 +95,18 @@ def __init__(
elif re.match(".*graph\s*\[.*\]\s*", graph):
self._graph = nx.DiGraph(nx.parse_gml(graph))
else:
self.logger.error("Error: Please provide graph (as string or text file) in dot or gml format.")
self.logger.error(
"Error: Please provide graph (as string or text file) in dot, gml format, networkx graph "
"or GCM model."
)
self.logger.error("Error: Incorrect graph format")
raise ValueError

if observed_node_names is None and (
isinstance(graph, nx.DiGraph) or isinstance(graph, ProbabilisticCausalModel)
):
observed_node_names = list(self._graph.nodes)

if missing_nodes_as_confounders:
self._graph = self.add_missing_nodes_as_common_causes(observed_node_names)
# Adding node attributes
Expand Down
2 changes: 0 additions & 2 deletions dowhy/causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import dowhy.causal_estimators as causal_estimators
import dowhy.causal_refuters as causal_refuters
import dowhy.graph_learners as graph_learners
import dowhy.utils.cli_helpers as cli
from dowhy.causal_estimator import CausalEstimate, estimate_effect
from dowhy.causal_graph import CausalGraph
from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, IDIdentifier
Expand All @@ -23,7 +22,6 @@


class CausalModel:

"""Main class for storing the causal model state."""

def __init__(
Expand Down
9 changes: 9 additions & 0 deletions dowhy/gcm/causal_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,17 @@ def __init__(
:param graph_copier: Optional function that can copy a causal graph. Defaults to a networkx.DiGraph
constructor.
"""
# Todo: Remove after https://github.com/py-why/dowhy/pull/943.
from dowhy.causal_graph import CausalGraph
from dowhy.causal_model import CausalModel

if graph is None:
graph = nx.DiGraph()
elif isinstance(graph, CausalModel):
graph = graph_copier(graph._graph._graph)
elif isinstance(graph, CausalGraph):
graph = graph_copier(graph._graph)

self.graph = graph
self.graph_copier = graph_copier

Expand Down
27 changes: 27 additions & 0 deletions tests/test_causal_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import networkx as nx
import pandas as pd
import pytest
from flaky import flaky
Expand All @@ -7,6 +8,8 @@
import dowhy
import dowhy.datasets
from dowhy import CausalModel
from dowhy.causal_graph import CausalGraph
from dowhy.gcm import ProbabilisticCausalModel, StructuralCausalModel


class TestCausalModel(object):
Expand Down Expand Up @@ -498,6 +501,30 @@ def test_unobserved_graph_variables_log_warning(self, caplog):
f"Only the following log records were emitted instead: '{caplog.records}'."
)

def test_compability_with_gcm(self):
data = pd.DataFrame({"X": [0], "Y": [0], "Z": [0]})
model = CausalModel(
data=data,
treatment="Y",
outcome="Z",
graph=StructuralCausalModel(nx.DiGraph([("X", "Y"), ("Y", "Z")])),
)

assert set(model._graph._graph.nodes) == {"X", "Y", "Z"}
assert set(model._graph._graph.edges) == {("X", "Y"), ("Y", "Z")}

causal_graph = CausalGraph("Y", "Z", graph=StructuralCausalModel(nx.DiGraph([("X", "Y"), ("Y", "Z")])))
assert set(causal_graph._graph.nodes) == {"X", "Y", "Z"}
assert set(causal_graph._graph.edges) == {("X", "Y"), ("Y", "Z")}

pcm = ProbabilisticCausalModel(model)
assert set(pcm.graph.nodes) == {"X", "Y", "Z"}
assert set(pcm.graph.edges) == {("X", "Y"), ("Y", "Z")}

pcm = ProbabilisticCausalModel(model._graph)
assert set(pcm.graph.nodes) == {"X", "Y", "Z"}
assert set(pcm.graph.edges) == {("X", "Y"), ("Y", "Z")}


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 952c6ea

Please sign in to comment.