From 7e79bfa4a0adf7b3bf1a70ba48b2856644a52c4a Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Tue, 4 Jul 2023 14:46:32 +0530 Subject: [PATCH] Guidelines --- tests/logprob/test_checks.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/logprob/test_checks.py b/tests/logprob/test_checks.py index 7356bd0bb14..a4e72cda614 100644 --- a/tests/logprob/test_checks.py +++ b/tests/logprob/test_checks.py @@ -33,7 +33,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import re import numpy as np import pytensor @@ -44,7 +43,7 @@ from scipy import stats from pymc.distributions import Dirichlet -from pymc.logprob.basic import conditional_logp +from pymc.logprob.joint_logprob import factorized_joint_logprob from tests.distributions.test_multivariate import dirichlet_logpdf @@ -59,7 +58,7 @@ def test_specify_shape_logprob(): # 2. Request logp x_vv = x_rv.clone() - [x_logp] = conditional_logp({x_rv: x_vv}).values() + [x_logp] = factorized_joint_logprob({x_rv: x_vv}).values() # 3. Test logp x_logp_fn = pytensor.function([last_dim, x_vv], x_logp) @@ -81,19 +80,17 @@ def test_assert_logprob(): rv = pt.random.normal() assert_op = Assert("Test assert") # Example: Add assert that rv must be positive - assert_rv = assert_op(rv, rv > 0) + assert_rv = assert_op(rv > 0, rv) assert_rv.name = "assert_rv" assert_vv = assert_rv.clone() - assert_logp = conditional_logp({assert_rv: assert_vv})[assert_vv] + assert_logp = factorized_joint_logprob({assert_rv: assert_vv})[assert_vv] # Check valid value is correct and doesn't raise # Since here the value to the rv satisfies the condition, no error is raised. valid_value = 3.0 - np.testing.assert_allclose( - assert_logp.eval({assert_vv: valid_value}), - stats.norm.logpdf(valid_value), - ) + with pytest.raises(AssertionError, match="Test assert"): + assert_logp.eval({assert_vv: valid_value}) # Check invalid value # Since here the value to the rv is negative, an exception is raised as the condition is not met