From bcb4309e0e750c47502d3fa61b57110f78ab3f16 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sun, 8 Sep 2024 11:05:04 -0600 Subject: [PATCH] changed raise to warning, moved warning to low level clone_graph, added doc example, updated pytest --- pymc/model/core.py | 54 +++++++++++++++++++++++++++++----------- pymc/model/fgraph.py | 9 ++++++- tests/model/test_core.py | 34 ++++++++++++++++--------- 3 files changed, 70 insertions(+), 27 deletions(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index 99f8b2c96c..bd62fa4369 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -1574,31 +1574,57 @@ def __contains__(self, key): def __copy__(self): """ Clone a pymc model by overiding the python copy method using the clone_model method from fgraph. - if guassian process variables are detected then an exception will be raised. + Constants are not cloned and if guassian process variables are detected then a warning will be triggered. + + Examples + -------- + .. code-block:: python + + import pymc as pm + import copy + + with pm.Model() as m: + p = pm.Beta("p", 1, 1) + x = pm.Bernoulli("x", p=p, shape=(3,)) + + clone_m = copy.copy(m) + + # Access cloned variables by name + clone_x = clone_m["x"] + + # z will be part of clone_m but not m + z = pm.Deterministic("z", clone_x + 1) """ from pymc.model.fgraph import clone_model - check_for_gp_vars = [ - k for x in ["_rotated_", "_hsgp_coeffs_"] for k in self.named_vars.keys() if x in k - ] - if len(check_for_gp_vars) > 0: - raise Exception("Unable to clone Gaussian Process Variables") - return clone_model(self) def __deepcopy__(self, _): """ Clone a pymc model by overiding the python copy method using the clone_model method from fgraph. - if guassian process variables are detected then an exception will be raised. + Constants are not cloned and if guassian process variables are detected then a warning will be triggered. + + Examples + -------- + .. code-block:: python + + import pymc as pm + import copy + + with pm.Model() as m: + p = pm.Beta("p", 1, 1) + x = pm.Bernoulli("x", p=p, shape=(3,)) + + clone_m = copy.deepcopy(m) + + # Access cloned variables by name + clone_x = clone_m["x"] + + # z will be part of clone_m but not m + z = pm.Deterministic("z", clone_x + 1) """ from pymc.model.fgraph import clone_model - check_for_gp_vars = [ - k for x in ["_rotated_", "_hsgp_coeffs_"] for k in self.named_vars.keys() if x in k - ] - if len(check_for_gp_vars) > 0: - raise Exception("Unable to clone Gaussian Process Variables") - return clone_model(self) def replace_rvs_by_values( diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py index b1d67fd07b..b61ff06da3 100644 --- a/pymc/model/fgraph.py +++ b/pymc/model/fgraph.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings + from copy import copy, deepcopy import pytensor @@ -369,7 +371,7 @@ def clone_model(model: Model) -> Model: Recreates a PyMC model with clones of the original variables. Shared variables will point to the same container but be otherwise different objects. - Constants are not cloned. + Constants are not cloned and if guassian process variables are detected then a warning will be triggered. Examples @@ -391,6 +393,11 @@ def clone_model(model: Model) -> Model: z = pm.Deterministic("z", clone_x + 1) """ + check_for_gp_vars = [ + k for x in ["_rotated_", "_hsgp_coeffs_"] for k in model.named_vars.keys() if x in k + ] + if len(check_for_gp_vars) > 0: + warnings.warn("Unable to clone Gaussian Process Variables", UserWarning) return model_from_fgraph(fgraph_from_model(model)[0], mutate_fgraph=True) diff --git a/tests/model/test_core.py b/tests/model/test_core.py index 747040ecdc..eaf711ab39 100644 --- a/tests/model/test_core.py +++ b/tests/model/test_core.py @@ -1764,7 +1764,7 @@ def test_graphviz_call_function(self, var_names, filenames) -> None: ) -class TestModelCopy(unittest.TestCase): +class TestModelCopy: @staticmethod def simple_model() -> pm.Model: with pm.Model() as simple_model: @@ -1772,7 +1772,7 @@ def simple_model() -> pm.Model: alpha = pm.Normal("alpha", 0, 1) pm.Normal("y", alpha, error) return simple_model - + @staticmethod def gp_model() -> pm.Model: with pm.Model() as gp_model: @@ -1782,7 +1782,7 @@ def gp_model() -> pm.Model: f = gp.prior("f", X=np.arange(10)[:, None]) pm.Normal("y", f * 2) return gp_model - + def test_copy_model(self) -> None: simple_model = self.simple_model() copy_simple_model = copy.copy(simple_model) @@ -1797,17 +1797,27 @@ def test_copy_model(self) -> None: with deepcopy_simple_model: deepcopy_simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42) - simple_model_prior_predictive_mean = simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw')) - copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw')) - deepcopy_simple_model_prior_predictive_mean = deepcopy_simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw')) + simple_model_prior_predictive_mean = simple_model_prior_predictive["prior"]["y"].mean( + ("chain", "draw") + ) + copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive["prior"][ + "y" + ].mean(("chain", "draw")) + deepcopy_simple_model_prior_predictive_mean = deepcopy_simple_model_prior_predictive[ + "prior" + ]["y"].mean(("chain", "draw")) - assert np.isclose(simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean) - assert np.isclose(simple_model_prior_predictive_mean, deepcopy_simple_model_prior_predictive_mean) + assert np.isclose( + simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean + ) + assert np.isclose( + simple_model_prior_predictive_mean, deepcopy_simple_model_prior_predictive_mean + ) def test_guassian_process_copy_failure(self) -> None: gaussian_process_model = self.gp_model() - with pytest.raises(Exception) as e: + with pytest.warns(UserWarning): copy.copy(gaussian_process_model) - - with pytest.raises(Exception) as e: - copy.deepcopy(gaussian_process_model) \ No newline at end of file + + with pytest.warns(UserWarning): + copy.deepcopy(gaussian_process_model)