Skip to content

Commit

Permalink
Merge pull request #10 from Johnny-Wish/topic-early-stopping
Browse files Browse the repository at this point in the history
Topic early stopping
  • Loading branch information
shuheng-liu authored Aug 22, 2020
2 parents 7c647b0 + 0e47fc9 commit 9981313
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
22 changes: 22 additions & 0 deletions neurodiffeq/pde_spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,11 @@ def make_pair_dict(train=None, valid=None):
self.lowest_loss = None
# local epoch in a `.fit` call, should only be modified inside self.fit()
self.local_epoch = 0
# maximum local epochs to run in a `.fit()` call, should only set by inside self.fit()
self._max_local_epoch = 0
# controls early stopping, should be set to False at the beginning of a `.fit()` call
# and optionally set to False by `callbacks` in `.fit()` to support early stopping
self._stop_training = False

@property
def global_epoch(self):
Expand Down Expand Up @@ -605,6 +610,15 @@ def _auto_enforce(net, cond, r, theta, phi):
:rtype: torch.Tensor
"""
n_params = len(signature(cond.enforce).parameters)
if len(r.shape) != 2 or len(theta.shape) != 2 or len(phi.shape) != 2:
raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} are not (-1, 1)")

if r.shape[1] != 1 or theta.shape[1] != 1 or phi.shape[1] != 1:
raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} are not (-1, 1)")

if len(r) != len(theta) or len(r) != len(phi) or len(theta) != len(phi):
raise ValueError(f"{r.shape}, {theta.shape}, or {phi.shape} differ in dim 0")

if n_params == 2:
# noinspection PyArgumentList
return cond.enforce(net, r)
Expand Down Expand Up @@ -757,14 +771,22 @@ def fit(self, max_epochs, monitor=None, callbacks=None):
"""Run multiple epochs of training and validation, update best loss at the end of each epoch.
This method does not return solution, which is done in the `.get_solution` method.
If `callbacks` is passed, callbacks are run one at a time, after training, validating, updaing best model and before monitor checking
A callback function `cb(solver)` can set `solver._stop_training` to True to perform early stopping,
:param max_epochs: number of epochs to run
:type max_epochs: int
:param monitor: monitor for visualizing solution and metrics
:rtype monitor: `neurodiffeq.pde_spherical.MonitorSpherical`
:param callbacks: a list of callback functions, each accepting the solver instance itself as its only argument
:rtype callbacks: list[callable]
"""
self._stop_training = False
self._max_local_epoch = max_epochs

for local_epoch in range(max_epochs):
# stops training if self._stop_training is set to True by a callback
if self._stop_training:
break

# register local epoch so it can be accessed by callbacks
self.local_epoch = local_epoch
self._resample_train()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_function_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_legendre_basis():


def test_zero_order_spherical_harmonics():
# note that in scipy, theta in azimuthal angle (0, 2 pi) while phi is polar angle (0, pi)
# note that in scipy, theta is azimuthal angle (0, 2 pi) while phi is polar angle (0, pi)
thetas1 = np.random.rand(*shape) * np.pi * 2
phis1 = np.random.rand(*shape) * np.pi
# in neurodiffeq, theta and phi should be exchanged
Expand Down

0 comments on commit 9981313

Please sign in to comment.