diff --git a/dowhy/causal_graph.py b/dowhy/causal_graph.py index 764e726635..9e9eb9e71f 100755 --- a/dowhy/causal_graph.py +++ b/dowhy/causal_graph.py @@ -4,6 +4,7 @@ import networkx as nx +from dowhy.gcm.causal_models import ProbabilisticCausalModel from dowhy.utils.api import parse_state from dowhy.utils.graph_operations import daggity_to_dot from dowhy.utils.plotting import plot @@ -54,6 +55,10 @@ def __init__( if graph is None: self._graph = nx.DiGraph() self._graph = self.build_graph(common_cause_names, instrument_names, effect_modifier_names, mediator_names) + elif isinstance(graph, nx.DiGraph): + self._graph = nx.DiGraph(graph) + elif isinstance(graph, ProbabilisticCausalModel): + self._graph = nx.DiGraph(graph.graph) elif re.match(r".*\.dot", graph): # load dot file try: @@ -90,9 +95,18 @@ def __init__( elif re.match(".*graph\s*\[.*\]\s*", graph): self._graph = nx.DiGraph(nx.parse_gml(graph)) else: - self.logger.error("Error: Please provide graph (as string or text file) in dot or gml format.") + self.logger.error( + "Error: Please provide graph (as string or text file) in dot, gml format, networkx graph " + "or GCM model." + ) self.logger.error("Error: Incorrect graph format") raise ValueError + + if observed_node_names is None and ( + isinstance(graph, nx.DiGraph) or isinstance(graph, ProbabilisticCausalModel) + ): + observed_node_names = list(self._graph.nodes) + if missing_nodes_as_confounders: self._graph = self.add_missing_nodes_as_common_causes(observed_node_names) # Adding node attributes diff --git a/dowhy/causal_model.py b/dowhy/causal_model.py index b5a22c7c06..7fe43b9527 100755 --- a/dowhy/causal_model.py +++ b/dowhy/causal_model.py @@ -11,7 +11,6 @@ import dowhy.causal_estimators as causal_estimators import dowhy.causal_refuters as causal_refuters import dowhy.graph_learners as graph_learners -import dowhy.utils.cli_helpers as cli from dowhy.causal_estimator import CausalEstimate, estimate_effect from dowhy.causal_graph import CausalGraph from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, IDIdentifier @@ -23,7 +22,6 @@ class CausalModel: - """Main class for storing the causal model state.""" def __init__( diff --git a/dowhy/gcm/causal_models.py b/dowhy/gcm/causal_models.py index 2e0d366544..b6138f990b 100644 --- a/dowhy/gcm/causal_models.py +++ b/dowhy/gcm/causal_models.py @@ -39,8 +39,17 @@ def __init__( :param graph_copier: Optional function that can copy a causal graph. Defaults to a networkx.DiGraph constructor. """ + # Todo: Remove after https://github.com/py-why/dowhy/pull/943. + from dowhy.causal_graph import CausalGraph + from dowhy.causal_model import CausalModel + if graph is None: graph = nx.DiGraph() + elif isinstance(graph, CausalModel): + graph = graph_copier(graph._graph._graph) + elif isinstance(graph, CausalGraph): + graph = graph_copier(graph._graph) + self.graph = graph self.graph_copier = graph_copier diff --git a/tests/test_causal_model.py b/tests/test_causal_model.py index 85bc29ce25..4150b72260 100644 --- a/tests/test_causal_model.py +++ b/tests/test_causal_model.py @@ -1,3 +1,4 @@ +import networkx as nx import pandas as pd import pytest from flaky import flaky @@ -7,6 +8,8 @@ import dowhy import dowhy.datasets from dowhy import CausalModel +from dowhy.causal_graph import CausalGraph +from dowhy.gcm import ProbabilisticCausalModel, StructuralCausalModel class TestCausalModel(object): @@ -498,6 +501,30 @@ def test_unobserved_graph_variables_log_warning(self, caplog): f"Only the following log records were emitted instead: '{caplog.records}'." ) + def test_compability_with_gcm(self): + data = pd.DataFrame({"X": [0], "Y": [0], "Z": [0]}) + model = CausalModel( + data=data, + treatment="Y", + outcome="Z", + graph=StructuralCausalModel(nx.DiGraph([("X", "Y"), ("Y", "Z")])), + ) + + assert set(model._graph._graph.nodes) == {"X", "Y", "Z"} + assert set(model._graph._graph.edges) == {("X", "Y"), ("Y", "Z")} + + causal_graph = CausalGraph("Y", "Z", graph=StructuralCausalModel(nx.DiGraph([("X", "Y"), ("Y", "Z")]))) + assert set(causal_graph._graph.nodes) == {"X", "Y", "Z"} + assert set(causal_graph._graph.edges) == {("X", "Y"), ("Y", "Z")} + + pcm = ProbabilisticCausalModel(model) + assert set(pcm.graph.nodes) == {"X", "Y", "Z"} + assert set(pcm.graph.edges) == {("X", "Y"), ("Y", "Z")} + + pcm = ProbabilisticCausalModel(model._graph) + assert set(pcm.graph.nodes) == {"X", "Y", "Z"} + assert set(pcm.graph.edges) == {("X", "Y"), ("Y", "Z")} + if __name__ == "__main__": pytest.main([__file__])