Skip to content

Commit

Permalink
implement SCM saving and loading (#3)
Browse files Browse the repository at this point in the history
* saving and loading of SCMs

* bump version
  • Loading branch information
sa-and authored Apr 5, 2024
1 parent df294dd commit b2d1777
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 17 deletions.
39 changes: 38 additions & 1 deletion CausalPlayground/scm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import networkx as nx
import matplotlib.pyplot as plt
import random
import dill


class StructuralCausalModel:
Expand All @@ -19,7 +20,7 @@ class StructuralCausalModel:
"""

endogenous_vars: Dict[str, Any]
"""Dictionary storing the value of each endogenoous variable."""
"""Dictionary storing the value of each endogenous variable."""
functions: Dict[str, Tuple[Callable, dict]]
"""Functional assignments of the endogenous variables. for each endogenous variables, the functional assignments are
stored as well as a dictionary mapping the parameters of the callable (key) to the name of the causes (values)."""
Expand Down Expand Up @@ -197,3 +198,39 @@ def draw_graph(self):
labels={key: str(key) + ':\n' + str(values[key]) for key in values}, pos=nx.planar_layout(graph))
plt.show()

def save(self, filepath: str, verbose: int = 0) -> None:
"""
Save the structural causal model to a file
:param filepath: path to file for saving the SCM
:param verbose: verbosity level, 0=silent, 1=print to console
"""
try:
with open(filepath, 'wb') as file:
dill.dump(self, file)
message = f"SCM successfully saved to {filepath}"
except IOError as e:
message = f"Error saving object to file: {str(e)}"
if verbose == 1:
print(message)

@staticmethod
def load(filepath: str, verbose: int = 0) -> "StructuralCausalModel":
"""
Load an object from a file using dill deserialization.
:param filepath: The path and filename of the file to load the object from.
:param verbose: verbosity level, 0=silent, 1=print to console
:return: the loaded SCM
"""
try:
with open(filepath, 'rb') as file:
obj = dill.load(file)
return obj
except IOError as e:
if verbose == 1:
print(f"Error loading object from file: {str(e)}")
except dill.UnpicklingError as e:
if verbose == 1:
print(f"Error deserializing object: {str(e)}")
6 changes: 3 additions & 3 deletions docs/CausalPlayground/scm_environment.html
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ <h1 class="modulename">
</span><span id="L-125"><a href="#L-125"><span class="linenos">125</span></a><span class="sd"> :param options: Additional information to specify how the environment is reset</span>
</span><span id="L-126"><a href="#L-126"><span class="linenos">126</span></a><span class="sd"> :return: The current observation and additional information.</span>
</span><span id="L-127"><a href="#L-127"><span class="linenos">127</span></a><span class="sd"> &quot;&quot;&quot;</span>
</span><span id="L-128"><a href="#L-128"><span class="linenos">128</span></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">reset</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">)</span>
</span><span id="L-128"><a href="#L-128"><span class="linenos">128</span></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">reset</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">,</span> <span class="n">options</span><span class="o">=</span><span class="n">options</span><span class="p">)</span>
</span><span id="L-129"><a href="#L-129"><span class="linenos">129</span></a> <span class="bp">self</span><span class="o">.</span><span class="n">steps_this_episode</span> <span class="o">=</span> <span class="mi">0</span>
</span><span id="L-130"><a href="#L-130"><span class="linenos">130</span></a>
</span><span id="L-131"><a href="#L-131"><span class="linenos">131</span></a> <span class="bp">self</span><span class="o">.</span><span class="n">update_values_from_scm_sample</span><span class="p">()</span>
Expand Down Expand Up @@ -507,7 +507,7 @@ <h1 class="modulename">
</span><span id="SCMEnvironment-126"><a href="#SCMEnvironment-126"><span class="linenos">126</span></a><span class="sd"> :param options: Additional information to specify how the environment is reset</span>
</span><span id="SCMEnvironment-127"><a href="#SCMEnvironment-127"><span class="linenos">127</span></a><span class="sd"> :return: The current observation and additional information.</span>
</span><span id="SCMEnvironment-128"><a href="#SCMEnvironment-128"><span class="linenos">128</span></a><span class="sd"> &quot;&quot;&quot;</span>
</span><span id="SCMEnvironment-129"><a href="#SCMEnvironment-129"><span class="linenos">129</span></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">reset</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">)</span>
</span><span id="SCMEnvironment-129"><a href="#SCMEnvironment-129"><span class="linenos">129</span></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">reset</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">,</span> <span class="n">options</span><span class="o">=</span><span class="n">options</span><span class="p">)</span>
</span><span id="SCMEnvironment-130"><a href="#SCMEnvironment-130"><span class="linenos">130</span></a> <span class="bp">self</span><span class="o">.</span><span class="n">steps_this_episode</span> <span class="o">=</span> <span class="mi">0</span>
</span><span id="SCMEnvironment-131"><a href="#SCMEnvironment-131"><span class="linenos">131</span></a>
</span><span id="SCMEnvironment-132"><a href="#SCMEnvironment-132"><span class="linenos">132</span></a> <span class="bp">self</span><span class="o">.</span><span class="n">update_values_from_scm_sample</span><span class="p">()</span>
Expand Down Expand Up @@ -903,7 +903,7 @@ <h6 id="returns">Returns</h6>
</span><span id="SCMEnvironment.reset-126"><a href="#SCMEnvironment.reset-126"><span class="linenos">126</span></a><span class="sd"> :param options: Additional information to specify how the environment is reset</span>
</span><span id="SCMEnvironment.reset-127"><a href="#SCMEnvironment.reset-127"><span class="linenos">127</span></a><span class="sd"> :return: The current observation and additional information.</span>
</span><span id="SCMEnvironment.reset-128"><a href="#SCMEnvironment.reset-128"><span class="linenos">128</span></a><span class="sd"> &quot;&quot;&quot;</span>
</span><span id="SCMEnvironment.reset-129"><a href="#SCMEnvironment.reset-129"><span class="linenos">129</span></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">reset</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">)</span>
</span><span id="SCMEnvironment.reset-129"><a href="#SCMEnvironment.reset-129"><span class="linenos">129</span></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">reset</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">,</span> <span class="n">options</span><span class="o">=</span><span class="n">options</span><span class="p">)</span>
</span><span id="SCMEnvironment.reset-130"><a href="#SCMEnvironment.reset-130"><span class="linenos">130</span></a> <span class="bp">self</span><span class="o">.</span><span class="n">steps_this_episode</span> <span class="o">=</span> <span class="mi">0</span>
</span><span id="SCMEnvironment.reset-131"><a href="#SCMEnvironment.reset-131"><span class="linenos">131</span></a>
</span><span id="SCMEnvironment.reset-132"><a href="#SCMEnvironment.reset-132"><span class="linenos">132</span></a> <span class="bp">self</span><span class="o">.</span><span class="n">update_values_from_scm_sample</span><span class="p">()</span>
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "causal-playground"
version = "0.1.0"
version = "0.1.1"
requires-python = ">= 3.10"
authors = [{name = "Andreas Sauter", email = "a.sauter@vu.nl"}]
readme = "README.md"
Expand Down
20 changes: 8 additions & 12 deletions tests/test_graph_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,17 @@ def test_seeding(self):
# these specific seeds and 50 samples

def test_graphset_save_load(self):
if os.path.exists("./delme.pkl"):
os.remove("./delme.pkl")
graph_set = CausalGraphSetGenerator(n_endo=5, n_exo=10, allow_exo_confounders=False)
graph_set.generate(20)
graph_set.save('delme.pkl')
graph_set.load('delme.pkl')
graph_set.save('./delme.pkl')
graph_set.load('./delme.pkl')
self.assertTrue(len(graph_set.graphs), 20)
os.remove('delme.pkl')
np.array([1, 2]).dump('delme.pkl')
with self.assertRaises(TypeError):
graph_set.load('delme.pkl')
os.remove('./delme.pkl')

def test_unique_graph_set(self):
graph_set = CausalGraphSetGenerator(n_endo=4, n_exo=0, allow_exo_confounders=False)
Expand All @@ -79,15 +84,6 @@ def test_unique_graph_set(self):
edges = [list(g.edges()) for g in graph_set.graphs]
self.assertTrue(all([len(e) == len(set(e)) for e in edges]))

def test_save_load_graph_set(self):
graph_set = CausalGraphSetGenerator(n_endo=4, n_exo=0, allow_exo_confounders=False)
graph_set.generate(300)
graph_set.save('delme.pkl')
graph_set.load('delme.pkl')
np.array([1, 2]).dump('delme.pkl')
with self.assertRaises(TypeError):
graph_set.load('delme.pkl')


if __name__ == '__main__':
unittest.main()
18 changes: 18 additions & 0 deletions tests/test_scm_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import unittest
from functions import *
import random
import os


class TestSCM(unittest.TestCase):
Expand Down Expand Up @@ -72,6 +73,23 @@ def test_intervention_targets(self):
self.scm.undo_interventions()
self.assertTrue(self.scm.get_intervention_targets() == [])

def test_save_and_load(self):
if os.path.exists("./delme.pkl"):
os.remove("./delme.pkl")
# Test saving the object to a file
self.test_scm.save("./delme.pkl")
self.assertTrue(os.path.exists("./delme.pkl"))

# Test loading the object from the file
loaded_scm = StructuralCausalModel.load("./delme.pkl")
self.assertIsInstance(loaded_scm, StructuralCausalModel)
self.assertEqual(list(loaded_scm.endogenous_vars.keys()), ['A', 'EFFECT'])
self.assertEqual(list(loaded_scm.exogenous_vars.keys()), ['U'])
self.assertEqual(loaded_scm.create_graph().edges, self.test_scm.create_graph().edges)

if os.path.exists("./delme.pkl"):
os.remove("./delme.pkl")


class TestRandNN(unittest.TestCase):
def __init__(self, *args, **kwargs):
Expand Down

0 comments on commit b2d1777

Please sign in to comment.