Skip to content

Commit

Permalink
enhance(solver): checks the shape r, theta, and phi
Browse files Browse the repository at this point in the history
  • Loading branch information
shuheng-liu committed Aug 22, 2020
1 parent 45fac60 commit 0e47fc9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 9 additions & 0 deletions neurodiffeq/pde_spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,15 @@ def _auto_enforce(net, cond, r, theta, phi):
:rtype: torch.Tensor
"""
n_params = len(signature(cond.enforce).parameters)
if len(r.shape) != 2 or len(theta.shape) != 2 or len(phi.shape) != 2:
raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} are not (-1, 1)")

if r.shape[1] != 1 or theta.shape[1] != 1 or phi.shape[1] != 1:
raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} are not (-1, 1)")

if len(r) != len(theta) or len(r) != len(phi) or len(theta) != len(phi):
raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} differ in dim 0")

if n_params == 2:
# noinspection PyArgumentList
return cond.enforce(net, r)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_function_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_legendre_basis():


def test_zero_order_spherical_harmonics():
# note that in scipy, theta in azimuthal angle (0, 2 pi) while phi is polar angle (0, pi)
# note that in scipy, theta is azimuthal angle (0, 2 pi) while phi is polar angle (0, pi)
thetas1 = np.random.rand(*shape) * np.pi * 2
phis1 = np.random.rand(*shape) * np.pi
# in neurodiffeq, theta and phi should be exchanged
Expand Down

0 comments on commit 0e47fc9

Please sign in to comment.