Skip to content

Commit 45fac60

Browse files
committed
feat(monitor): register max_local_epochs to solver for callbacks to use
1 parent 84f1df8 commit 45fac60

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

neurodiffeq/pde_spherical.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,8 @@ def make_pair_dict(train=None, valid=None):
575575
self.lowest_loss = None
576576
# local epoch in a `.fit` call, should only be modified inside self.fit()
577577
self.local_epoch = 0
578+
# maximum local epochs to run in a `.fit()` call, should only set by inside self.fit()
579+
self._max_local_epoch = 0
578580
# controls early stopping, should be set to False at the beginning of a `.fit()` call
579581
# and optionally set to False by `callbacks` in `.fit()` to support early stopping
580582
self._stop_training = False
@@ -769,6 +771,7 @@ def fit(self, max_epochs, monitor=None, callbacks=None):
769771
:rtype callbacks: list[callable]
770772
"""
771773
self._stop_training = False
774+
self._max_local_epoch = max_epochs
772775

773776
for local_epoch in range(max_epochs):
774777
# stops training if self._stop_training is set to True by a callback

0 commit comments

Comments
 (0)