Skip to content

Commit

Permalink
Delete save and restore methods from the Model class. Add train hooks…
Browse files Browse the repository at this point in the history
… to Arch's training process
  • Loading branch information
aicroe committed Oct 5, 2019
1 parent 843076a commit d671ad2
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 22 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="MlScratch",
version="0.1.0",
version="0.1.1",
url="https://github.com/aicroe/mlscratch",
project_urls={
"Bug Tracker": "https://github.com/aicroe/mlscratch/issues",
Expand Down
11 changes: 10 additions & 1 deletion src/mlscratch/arch/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,21 @@
class Arch(ABC):
"""Abstracts a machine learning instance."""

@abstractmethod
def train_initialize(self) -> None:
"""Train hook. Called once before the training has started."""

@abstractmethod
def train_finalize(self) -> None:
"""Train hook. Called once after the training has finalized."""

@abstractmethod
def update_params(
self,
dataset: Tensor,
labels: Tensor) -> Tuple[float, Tensor]:
"""Updates/Optimizes its trainable parameters."""
"""Updates/Optimizes its trainable parameters.
Called while training to update this instance parameters."""

@abstractmethod
def evaluate(self, dataset: Tensor) -> Tensor:
Expand Down
10 changes: 2 additions & 8 deletions src/mlscratch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def train(
int, List[float], List[float]]:
"""Trains this model's arch."""
train_recorder = _TrainRecorder(train_watcher)
self._arch.train_initialize()
epochs, validation_epochs = trainer.train(
_TrainableArchAdapter(self._arch, measurer),
dataset,
Expand All @@ -102,6 +103,7 @@ def train(
train_recorder,
**options,
)
self._arch.train_finalize()
return (epochs,
train_recorder.costs,
train_recorder.accuracies,
Expand All @@ -122,11 +124,3 @@ def measure(
cost, evaluations = self._arch.check_cost(dataset, labels)
measures = [measurer.measure(evaluations, labels) for measurer in measurers]
return (cost, measures)

def restore(self) -> None:
"""Restors a saved instance."""
self._arch = self._arch.restore()

def save(self) -> None:
"""Saves an instance."""
self._arch.save()
27 changes: 15 additions & 12 deletions tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,25 @@ def test_measure_assert_returned_measures(self):

self.assertEqual(results, [56.9, 70.1])

def test_restore_assert_arch_gets_called(self):
def test_train_assert_hooks_are_called(self):
arch = MagicMock()
model = Model(arch)
trainer = MagicMock()
trainer.train.return_value = (None, None)

model.restore()

self.assertEqual(arch.restore.call_count, 1)

def test_save_assert_arch_gets_called(self):
arch = MagicMock()
model = Model(arch)

model.save()

self.assertEqual(arch.save.call_count, 1)
model.train(
None,
None,
None,
None,
trainer,
None,
None,
epochs=0,
)

self.assertEqual(arch.train_initialize.call_count, 1)
self.assertEqual(arch.train_finalize.call_count, 1)

class ModelAndSimpleTrainerIntegrationTest(TestCase):

Expand Down

0 comments on commit d671ad2

Please sign in to comment.