diff --git a/neurodiffeq/pde_spherical.py b/neurodiffeq/pde_spherical.py index f87cb91..48879d2 100644 --- a/neurodiffeq/pde_spherical.py +++ b/neurodiffeq/pde_spherical.py @@ -610,6 +610,15 @@ def _auto_enforce(net, cond, r, theta, phi): :rtype: torch.Tensor """ n_params = len(signature(cond.enforce).parameters) + if len(r.shape) != 2 or len(theta.shape) != 2 or len(phi.shape) != 2: + raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} are not (-1, 1)") + + if r.shape[1] != 1 or theta.shape[1] != 1 or phi.shape[1] != 1: + raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} are not (-1, 1)") + + if len(r) != len(theta) or len(r) != len(phi) or len(theta) != len(phi): + raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} differ in dim 0") + if n_params == 2: # noinspection PyArgumentList return cond.enforce(net, r) diff --git a/tests/test_function_basis.py b/tests/test_function_basis.py index 05e97d8..77a9c23 100644 --- a/tests/test_function_basis.py +++ b/tests/test_function_basis.py @@ -46,7 +46,7 @@ def test_legendre_basis(): def test_zero_order_spherical_harmonics(): - # note that in scipy, theta in azimuthal angle (0, 2 pi) while phi is polar angle (0, pi) + # note that in scipy, theta is azimuthal angle (0, 2 pi) while phi is polar angle (0, pi) thetas1 = np.random.rand(*shape) * np.pi * 2 phis1 = np.random.rand(*shape) * np.pi # in neurodiffeq, theta and phi should be exchanged