Skip to content

Commit

Permalink
parametrized tests to be more efficient, added test for adding determ…
Browse files Browse the repository at this point in the history
…inistics to clone model, added copy method to Model class
  • Loading branch information
Dekermanjian committed Sep 29, 2024
1 parent 88fde25 commit 90419cb
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 57 deletions.
32 changes: 5 additions & 27 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,34 +1572,12 @@ def __contains__(self, key):
return key in self.named_vars or self.name_for(key) in self.named_vars

def __copy__(self):
"""
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph.
Constants are not cloned and if guassian process variables are detected then a warning will be triggered.
Examples
--------
.. code-block:: python
import pymc as pm
import copy
with pm.Model() as m:
p = pm.Beta("p", 1, 1)
x = pm.Bernoulli("x", p=p, shape=(3,))
clone_m = copy.copy(m)
# Access cloned variables by name
clone_x = clone_m["x"]
# z will be part of clone_m but not m
z = pm.Deterministic("z", clone_x + 1)
"""
from pymc.model.fgraph import clone_model

return clone_model(self)
return self.copy()

def __deepcopy__(self, _):
return self.copy()

def copy(self):
"""
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph.
Constants are not cloned and if guassian process variables are detected then a warning will be triggered.
Expand All @@ -1615,7 +1593,7 @@ def __deepcopy__(self, _):
p = pm.Beta("p", 1, 1)
x = pm.Bernoulli("x", p=p, shape=(3,))
clone_m = copy.deepcopy(m)
clone_m = copy.copy(m)
# Access cloned variables by name
clone_x = clone_m["x"]
Expand Down
61 changes: 31 additions & 30 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1773,51 +1773,52 @@ def simple_model() -> pm.Model:
pm.Normal("y", alpha, error)
return simple_model

@staticmethod
def gp_model() -> pm.Model:
with pm.Model() as gp_model:
ell = pm.Gamma("ell", alpha=2, beta=1)
cov = 2 * pm.gp.cov.ExpQuad(1, ell)
gp = pm.gp.Latent(cov_func=cov)
f = gp.prior("f", X=np.arange(10)[:, None])
pm.Normal("y", f * 2)
return gp_model

def test_copy_model(self) -> None:
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
def test_copy_model(self, copy_method) -> None:
simple_model = self.simple_model()
copy_simple_model = copy.copy(simple_model)
deepcopy_simple_model = copy.deepcopy(simple_model)
copy_simple_model = copy_method(simple_model)

with simple_model:
simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42)
simple_model_prior_predictive = pm.sample_prior_predictive(samples=1, random_seed=42)

with copy_simple_model:
copy_simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42)

with deepcopy_simple_model:
deepcopy_simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42)
copy_simple_model_prior_predictive = pm.sample_prior_predictive(
samples=1, random_seed=42
)

simple_model_prior_predictive_mean = simple_model_prior_predictive["prior"]["y"].mean(
("chain", "draw")
)
copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive["prior"][
"y"
].mean(("chain", "draw"))
deepcopy_simple_model_prior_predictive_mean = deepcopy_simple_model_prior_predictive[
"prior"
]["y"].mean(("chain", "draw"))

assert np.isclose(
simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean
)
assert np.isclose(
simple_model_prior_predictive_mean, deepcopy_simple_model_prior_predictive_mean
)

def test_guassian_process_copy_failure(self) -> None:
gaussian_process_model = self.gp_model()
with pytest.warns(UserWarning):
copy.copy(gaussian_process_model)
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
def test_guassian_process_copy_failure(self, copy_method) -> None:
with pm.Model() as gaussian_process_model:
ell = pm.Gamma("ell", alpha=2, beta=1)
cov = 2 * pm.gp.cov.ExpQuad(1, ell)
gp = pm.gp.Latent(cov_func=cov)
f = gp.prior("f", X=np.arange(10)[:, None])
pm.Normal("y", f * 2)

with pytest.warns(
UserWarning,
match="Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883",
):
copy_method(gaussian_process_model)

@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
def test_adding_deterministics_to_clone(self, copy_method) -> None:
simple_model = self.simple_model()
clone_model = copy_method(simple_model)

with clone_model:
z = pm.Deterministic("z", clone_model["alpha"] + 1)

with pytest.warns(UserWarning):
copy.deepcopy(gaussian_process_model)
assert "z" in clone_model.named_vars
assert "z" not in simple_model.named_vars

0 comments on commit 90419cb

Please sign in to comment.