diff --git a/neurodiffeq/pde_spherical.py b/neurodiffeq/pde_spherical.py index b9fc5b2..48879d2 100644 --- a/neurodiffeq/pde_spherical.py +++ b/neurodiffeq/pde_spherical.py @@ -575,6 +575,11 @@ def make_pair_dict(train=None, valid=None): self.lowest_loss = None # local epoch in a `.fit` call, should only be modified inside self.fit() self.local_epoch = 0 + # maximum local epochs to run in a `.fit()` call, should only set by inside self.fit() + self._max_local_epoch = 0 + # controls early stopping, should be set to False at the beginning of a `.fit()` call + # and optionally set to False by `callbacks` in `.fit()` to support early stopping + self._stop_training = False @property def global_epoch(self): @@ -605,6 +610,15 @@ def _auto_enforce(net, cond, r, theta, phi): :rtype: torch.Tensor """ n_params = len(signature(cond.enforce).parameters) + if len(r.shape) != 2 or len(theta.shape) != 2 or len(phi.shape) != 2: + raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} are not (-1, 1)") + + if r.shape[1] != 1 or theta.shape[1] != 1 or phi.shape[1] != 1: + raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} are not (-1, 1)") + + if len(r) != len(theta) or len(r) != len(phi) or len(theta) != len(phi): + raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} differ in dim 0") + if n_params == 2: # noinspection PyArgumentList return cond.enforce(net, r) @@ -757,6 +771,7 @@ def fit(self, max_epochs, monitor=None, callbacks=None): """Run multiple epochs of training and validation, update best loss at the end of each epoch. This method does not return solution, which is done in the `.get_solution` method. If `callbacks` is passed, callbacks are run one at a time, after training, validating, updaing best model and before monitor checking + A callback function `cb(solver)` can set `solver._stop_training` to True to perform early stopping, :param max_epochs: number of epochs to run :type max_epochs: int :param monitor: monitor for visualizing solution and metrics @@ -764,7 +779,14 @@ def fit(self, max_epochs, monitor=None, callbacks=None): :param callbacks: a list of callback functions, each accepting the solver instance itself as its only argument :rtype callbacks: list[callable] """ + self._stop_training = False + self._max_local_epoch = max_epochs + for local_epoch in range(max_epochs): + # stops training if self._stop_training is set to True by a callback + if self._stop_training: + break + # register local epoch so it can be accessed by callbacks self.local_epoch = local_epoch self._resample_train() diff --git a/tests/test_function_basis.py b/tests/test_function_basis.py index 998ca9b..fbb7468 100644 --- a/tests/test_function_basis.py +++ b/tests/test_function_basis.py @@ -46,7 +46,7 @@ def test_legendre_basis(): def test_zero_order_spherical_harmonics(): - # note that in scipy, theta in azimuthal angle (0, 2 pi) while phi is polar angle (0, pi) + # note that in scipy, theta is azimuthal angle (0, 2 pi) while phi is polar angle (0, pi) thetas1 = np.random.rand(*shape) * np.pi * 2 phis1 = np.random.rand(*shape) * np.pi # in neurodiffeq, theta and phi should be exchanged