diff --git a/setup.py b/setup.py index ab16123..b2b5892 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="MlScratch", - version="0.1.0", + version="0.1.1", url="https://github.com/aicroe/mlscratch", project_urls={ "Bug Tracker": "https://github.com/aicroe/mlscratch/issues", diff --git a/src/mlscratch/arch/arch.py b/src/mlscratch/arch/arch.py index 3c92f8c..da50dc1 100644 --- a/src/mlscratch/arch/arch.py +++ b/src/mlscratch/arch/arch.py @@ -8,12 +8,21 @@ class Arch(ABC): """Abstracts a machine learning instance.""" + @abstractmethod + def train_initialize(self) -> None: + """Train hook. Called once before the training has started.""" + + @abstractmethod + def train_finalize(self) -> None: + """Train hook. Called once after the training has finalized.""" + @abstractmethod def update_params( self, dataset: Tensor, labels: Tensor) -> Tuple[float, Tensor]: - """Updates/Optimizes its trainable parameters.""" + """Updates/Optimizes its trainable parameters. + Called while training to update this instance parameters.""" @abstractmethod def evaluate(self, dataset: Tensor) -> Tensor: diff --git a/src/mlscratch/model.py b/src/mlscratch/model.py index 76d6aff..76f81fc 100644 --- a/src/mlscratch/model.py +++ b/src/mlscratch/model.py @@ -93,6 +93,7 @@ def train( int, List[float], List[float]]: """Trains this model's arch.""" train_recorder = _TrainRecorder(train_watcher) + self._arch.train_initialize() epochs, validation_epochs = trainer.train( _TrainableArchAdapter(self._arch, measurer), dataset, @@ -102,6 +103,7 @@ def train( train_recorder, **options, ) + self._arch.train_finalize() return (epochs, train_recorder.costs, train_recorder.accuracies, @@ -122,11 +124,3 @@ def measure( cost, evaluations = self._arch.check_cost(dataset, labels) measures = [measurer.measure(evaluations, labels) for measurer in measurers] return (cost, measures) - - def restore(self) -> None: - """Restors a saved instance.""" - self._arch = self._arch.restore() - - def save(self) -> None: - """Saves an instance.""" - self._arch.save() diff --git a/tests/model_test.py b/tests/model_test.py index a2aaba2..4671b60 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -43,22 +43,25 @@ def test_measure_assert_returned_measures(self): self.assertEqual(results, [56.9, 70.1]) - def test_restore_assert_arch_gets_called(self): + def test_train_assert_hooks_are_called(self): arch = MagicMock() model = Model(arch) + trainer = MagicMock() + trainer.train.return_value = (None, None) - model.restore() - - self.assertEqual(arch.restore.call_count, 1) - - def test_save_assert_arch_gets_called(self): - arch = MagicMock() - model = Model(arch) - - model.save() - - self.assertEqual(arch.save.call_count, 1) + model.train( + None, + None, + None, + None, + trainer, + None, + None, + epochs=0, + ) + self.assertEqual(arch.train_initialize.call_count, 1) + self.assertEqual(arch.train_finalize.call_count, 1) class ModelAndSimpleTrainerIntegrationTest(TestCase):