From 07106ec26860a156525db0350291401fb61a6fb4 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Mon, 30 Sep 2024 06:31:41 -0600 Subject: [PATCH] updated copy method docs and simplified TestModelCopy tests --- pymc/model/core.py | 5 +++-- tests/model/test_core.py | 42 +++++++++++++++++----------------------- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index 9235524157..0d3ad555ed 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -1579,8 +1579,9 @@ def __deepcopy__(self, _): def copy(self): """ - Clone a pymc model by overiding the python copy method using the clone_model method from fgraph. - Constants are not cloned and if guassian process variables are detected then a warning will be triggered. + Clone the model + + To access variables in the cloned model use `cloned_model["var_name"]`. Examples -------- diff --git a/tests/model/test_core.py b/tests/model/test_core.py index 61e502d6cd..6678af3b17 100644 --- a/tests/model/test_core.py +++ b/tests/model/test_core.py @@ -1765,17 +1765,13 @@ def test_graphviz_call_function(self, var_names, filenames) -> None: class TestModelCopy: - @staticmethod - def simple_model() -> pm.Model: + @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy)) + def test_copy_model(self, copy_method) -> None: with pm.Model() as simple_model: error = pm.HalfNormal("error", 0.5) alpha = pm.Normal("alpha", 0, 1) pm.Normal("y", alpha, error) - return simple_model - @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy)) - def test_copy_model(self, copy_method) -> None: - simple_model = self.simple_model() copy_simple_model = copy_method(simple_model) with simple_model: @@ -1786,15 +1782,24 @@ def test_copy_model(self, copy_method) -> None: samples=1, random_seed=42 ) - simple_model_prior_predictive_mean = simple_model_prior_predictive["prior"]["y"].mean( - ("chain", "draw") - ) - copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive["prior"][ + simple_model_prior_predictive_val = simple_model_prior_predictive["prior"]["y"].values + copy_simple_model_prior_predictive_val = copy_simple_model_prior_predictive["prior"][ "y" - ].mean(("chain", "draw")) + ].values - assert np.isclose( - simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean + assert simple_model_prior_predictive_val == copy_simple_model_prior_predictive_val + + with copy_simple_model: + z = pm.Deterministic("z", copy_simple_model["alpha"] + 1) + copy_simple_model_prior_predictive = pm.sample_prior_predictive( + samples=1, random_seed=42 + ) + + assert "z" in copy_simple_model.named_vars + assert "z" not in simple_model.named_vars + assert ( + copy_simple_model_prior_predictive["prior"]["z"].values + == 1 + copy_simple_model_prior_predictive["prior"]["alpha"].values ) @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy)) @@ -1811,14 +1816,3 @@ def test_guassian_process_copy_failure(self, copy_method) -> None: match="Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883", ): copy_method(gaussian_process_model) - - @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy)) - def test_adding_deterministics_to_clone(self, copy_method) -> None: - simple_model = self.simple_model() - clone_model = copy_method(simple_model) - - with clone_model: - z = pm.Deterministic("z", clone_model["alpha"] + 1) - - assert "z" in clone_model.named_vars - assert "z" not in simple_model.named_vars