Skip to content

Commit

Permalink
changed raise to warning, moved warning to low level clone_graph, add…
Browse files Browse the repository at this point in the history
…ed doc example, updated pytest
  • Loading branch information
Dekermanjian committed Sep 9, 2024
1 parent fe4e0c5 commit bcb4309
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 27 deletions.
54 changes: 40 additions & 14 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,31 +1574,57 @@ def __contains__(self, key):
def __copy__(self):
"""
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph.
if guassian process variables are detected then an exception will be raised.
Constants are not cloned and if guassian process variables are detected then a warning will be triggered.
Examples
--------
.. code-block:: python
import pymc as pm
import copy
with pm.Model() as m:
p = pm.Beta("p", 1, 1)
x = pm.Bernoulli("x", p=p, shape=(3,))
clone_m = copy.copy(m)
# Access cloned variables by name
clone_x = clone_m["x"]
# z will be part of clone_m but not m
z = pm.Deterministic("z", clone_x + 1)
"""
from pymc.model.fgraph import clone_model

check_for_gp_vars = [
k for x in ["_rotated_", "_hsgp_coeffs_"] for k in self.named_vars.keys() if x in k
]
if len(check_for_gp_vars) > 0:
raise Exception("Unable to clone Gaussian Process Variables")

return clone_model(self)

def __deepcopy__(self, _):
"""
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph.
if guassian process variables are detected then an exception will be raised.
Constants are not cloned and if guassian process variables are detected then a warning will be triggered.
Examples
--------
.. code-block:: python
import pymc as pm
import copy
with pm.Model() as m:
p = pm.Beta("p", 1, 1)
x = pm.Bernoulli("x", p=p, shape=(3,))
clone_m = copy.deepcopy(m)
# Access cloned variables by name
clone_x = clone_m["x"]
# z will be part of clone_m but not m
z = pm.Deterministic("z", clone_x + 1)
"""
from pymc.model.fgraph import clone_model

check_for_gp_vars = [
k for x in ["_rotated_", "_hsgp_coeffs_"] for k in self.named_vars.keys() if x in k
]
if len(check_for_gp_vars) > 0:
raise Exception("Unable to clone Gaussian Process Variables")

return clone_model(self)

def replace_rvs_by_values(
Expand Down
9 changes: 8 additions & 1 deletion pymc/model/fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from copy import copy, deepcopy

import pytensor
Expand Down Expand Up @@ -369,7 +371,7 @@ def clone_model(model: Model) -> Model:
Recreates a PyMC model with clones of the original variables.
Shared variables will point to the same container but be otherwise different objects.
Constants are not cloned.
Constants are not cloned and if guassian process variables are detected then a warning will be triggered.
Examples
Expand All @@ -391,6 +393,11 @@ def clone_model(model: Model) -> Model:
z = pm.Deterministic("z", clone_x + 1)
"""
check_for_gp_vars = [
k for x in ["_rotated_", "_hsgp_coeffs_"] for k in model.named_vars.keys() if x in k
]
if len(check_for_gp_vars) > 0:
warnings.warn("Unable to clone Gaussian Process Variables", UserWarning)
return model_from_fgraph(fgraph_from_model(model)[0], mutate_fgraph=True)


Expand Down
34 changes: 22 additions & 12 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,15 +1764,15 @@ def test_graphviz_call_function(self, var_names, filenames) -> None:
)


class TestModelCopy(unittest.TestCase):
class TestModelCopy:
@staticmethod
def simple_model() -> pm.Model:
with pm.Model() as simple_model:
error = pm.HalfNormal("error", 0.5)
alpha = pm.Normal("alpha", 0, 1)
pm.Normal("y", alpha, error)
return simple_model

@staticmethod
def gp_model() -> pm.Model:
with pm.Model() as gp_model:
Expand All @@ -1782,7 +1782,7 @@ def gp_model() -> pm.Model:
f = gp.prior("f", X=np.arange(10)[:, None])
pm.Normal("y", f * 2)
return gp_model

def test_copy_model(self) -> None:
simple_model = self.simple_model()
copy_simple_model = copy.copy(simple_model)
Expand All @@ -1797,17 +1797,27 @@ def test_copy_model(self) -> None:
with deepcopy_simple_model:
deepcopy_simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42)

simple_model_prior_predictive_mean = simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw'))
copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw'))
deepcopy_simple_model_prior_predictive_mean = deepcopy_simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw'))
simple_model_prior_predictive_mean = simple_model_prior_predictive["prior"]["y"].mean(
("chain", "draw")
)
copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive["prior"][
"y"
].mean(("chain", "draw"))
deepcopy_simple_model_prior_predictive_mean = deepcopy_simple_model_prior_predictive[
"prior"
]["y"].mean(("chain", "draw"))

assert np.isclose(simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean)
assert np.isclose(simple_model_prior_predictive_mean, deepcopy_simple_model_prior_predictive_mean)
assert np.isclose(
simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean
)
assert np.isclose(
simple_model_prior_predictive_mean, deepcopy_simple_model_prior_predictive_mean
)

def test_guassian_process_copy_failure(self) -> None:
gaussian_process_model = self.gp_model()
with pytest.raises(Exception) as e:
with pytest.warns(UserWarning):
copy.copy(gaussian_process_model)
with pytest.raises(Exception) as e:
copy.deepcopy(gaussian_process_model)

with pytest.warns(UserWarning):
copy.deepcopy(gaussian_process_model)

0 comments on commit bcb4309

Please sign in to comment.