Skip to content

Commit

Permalink
Create save load n random scms (#4)
Browse files Browse the repository at this point in the history
* generate N SCMs at once

* saving/loading of sets of SCMs
  • Loading branch information
sa-and authored May 21, 2024
1 parent b2d1777 commit 24083d6
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 10 deletions.
59 changes: 59 additions & 0 deletions CausalPlayground/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import networkx as nx
from tqdm import tqdm
import pickle
import dill
import warnings
from CausalPlayground import StructuralCausalModel

Expand Down Expand Up @@ -57,6 +58,28 @@ def create_random(self, possible_functions: List[str], n_endo: int, n_exo: int,
return self.create_scm_from_graph(graph, possible_functions, exo_distribution, exo_distribution_kwargs),\
removed_edges

def create_n_random(self, N: int, possible_functions: List[str], n_endo: int, n_exo: int, exo_distribution: Callable = None,
                    exo_distribution_kwargs: dict = None, allow_exo_confounders: bool = False)\
        -> List[StructuralCausalModel]:
    """
    Generate a batch of N random StructuralCausalModels.

    Every model is produced by one call to `self.create_random` with the arguments given here. Only the first
    element of each result (the SCM itself) is kept; anything else `create_random` returns is discarded.
    :param N: number of SCMs to generate
    :param possible_functions: list of functions that can be used as causal relations. These should correspond to
    the strings defined in `self.all_functions`.
    :param n_endo: number of endogenous variables.
    :param n_exo: number of exogenous variables.
    :param exo_distribution: distribution of the exogenous variables. This distribution is applied to all
    exogenous variables.
    :param exo_distribution_kwargs: keyword arguments for the distribution of exogenous variables
    :param allow_exo_confounders: true if exogenous confounders should be generated.
    :return: the list of the N generated SCMs, in generation order
    """
    generated = []
    for _ in range(N):
        outcome = self.create_random(possible_functions, n_endo, n_exo, exo_distribution,
                                     exo_distribution_kwargs, allow_exo_confounders)
        # create_random returns a tuple whose first entry is the SCM
        generated.append(outcome[0])
    return generated

def create_scm_from_graph(self, graph: nx.DiGraph, possible_functions: List[str], exo_distribution: Callable,
exo_distribution_kwargs: dict) -> StructuralCausalModel:
"""
Expand Down Expand Up @@ -85,6 +108,42 @@ def create_scm_from_graph(self, graph: nx.DiGraph, possible_functions: List[str]
self.seed += 1
return scm

@staticmethod
def save_scms(scms: List[StructuralCausalModel], filepath: str, verbose: int = 0) -> None:
"""
Helper function for saving a list of StructuralCausalModels
:param scms: list of SCMs for saving
:param filepath: path to file for saving the SCMs
:param verbose: verbosity level, 0=silent, 1=print to console
"""
try:
with open(filepath, 'wb') as file:
dill.dump(scms, file)
message = f"SCM successfully saved to {filepath}"
except IOError as e:
message = f"Error saving object to file: {str(e)}"
if verbose == 1:
print(message)

@staticmethod
def load(filepath: str, verbose: int = 0) -> "StructuralCausalModel":
"""
Helper function for loading SCMs from a file
:param filepath: The path and filename of the file to load the object from.
:param verbose: verbosity level, 0=silent, 1=print to console
:return: the loaded SCMs
"""
try:
with open(filepath, 'rb') as file:
obj = dill.load(file)
return obj
except IOError as e:
if verbose == 1:
print(f"Error loading object from file: {str(e)}")
except dill.UnpicklingError as e:
if verbose == 1:
print(f"Error deserializing object: {str(e)}")


class CausalGraphGenerator:
"""Class for generating random directed acyclic graphs."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "causal-playground"
version = "0.1.1"
version = "0.1.2"
requires-python = ">= 3.10"
authors = [{name = "Andreas Sauter", email = "a.sauter@vu.nl"}]
readme = "README.md"
Expand Down
30 changes: 21 additions & 9 deletions tests/test_scm_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@


class TestSCM(unittest.TestCase):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs):
super(TestSCM, self).__init__(*args, **kwargs)
self.scm = SCMGenerator(all_functions={'linear': f_linear}, seed=42).create_random(possible_functions=["linear"], n_endo=20, n_exo=0)[0]
self.scm = \
SCMGenerator(all_functions={'linear': f_linear}, seed=42).create_random(possible_functions=["linear"],
n_endo=20, n_exo=0)[0]
self.test_scm = StructuralCausalModel()
self.test_scm.add_endogenous_var('A', lambda noise: noise + 5, {'noise': 'U'})
self.test_scm.add_exogenous_var('U', random.randint, {'a': 3, 'b': 8})
Expand All @@ -22,8 +24,8 @@ def test_creation(self):
a = sample[0]['A']
effect = sample[0]['EFFECT']
self.assertIn(u, [3, 4, 5, 6, 7, 8])
self.assertEqual(a, u+5)
self.assertEqual(effect, a*2)
self.assertEqual(a, u + 5)
self.assertEqual(effect, a * 2)

def test_create_from_graph(self):
graph = self.test_scm.create_graph()
Expand All @@ -36,12 +38,12 @@ def test_create_from_graph(self):
a = sample[0]['A']
effect = sample[0]['EFFECT']
self.assertIn(u, [2, 3, 4, 5])
self.assertTrue(effect <= 2*a)
self.assertTrue(effect <= 2 * a)

def test_intervention(self):
# do an intervention and compare before and after
x0 = self.scm.get_next_sample()[0]['X0']
self.scm.do_interventions([("X0", (lambda: 5, {})), ("X1", (lambda x0: x0+1, {'x0': 'X0'}))])
self.scm.do_interventions([("X0", (lambda: 5, {})), ("X1", (lambda x0: x0 + 1, {'x0': 'X0'}))])
x0_do = self.scm.get_next_sample()[0]['X0']
x1_do = self.scm.get_next_sample()[0]['X1']
self.assertTrue(x0_do == 5)
Expand Down Expand Up @@ -90,11 +92,20 @@ def test_save_and_load(self):
if os.path.exists("./delme.pkl"):
os.remove("./delme.pkl")

def test_create_n_random(self):
    """create_n_random must return exactly the requested number of StructuralCausalModels."""
    generator = SCMGenerator(all_functions={'linear': f_linear}, seed=1)
    scms = generator.create_n_random(10, ['linear'], 3, 4,
                                     random.random, {}, False)
    # assertEqual / assertIsInstance report the offending value on failure, unlike assertTrue(...)
    self.assertEqual(len(scms), 10)
    # check every generated element, not only the first one
    for scm in scms:
        self.assertIsInstance(scm, StructuralCausalModel)


class TestRandNN(unittest.TestCase):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs):
super(TestRandNN, self).__init__(*args, **kwargs)
self.scm1 = SCMGenerator(all_functions={'linear': f_linear}, seed=42).create_random(possible_functions=["linear"], n_endo=10, n_exo=0)[0]
self.scm1 = \
SCMGenerator(all_functions={'linear': f_linear}, seed=42).create_random(possible_functions=["linear"],
n_endo=10, n_exo=0)[0]
# self.scm2 = SCMGenerator(seed=42).create_random(possible_functions=["NN"], n_endo=6, n_exo=8,
# exo_distribution=random.random, exo_distribution_kwargs={})[0]

Expand Down Expand Up @@ -157,3 +168,4 @@ def test_sampling(self):

# Run the unittest test runner when this file is executed as a script.
if __name__ == '__main__':
unittest.main()

0 comments on commit 24083d6

Please sign in to comment.