This repository has been archived by the owner on Jun 22, 2022. It is now read-only.

Commit: Dev segmentation (#36)
* added loaders

* fix pytorch callbacks

* prepare for v0.1.9, dropped unused tests

* Update requirements.txt
Kamil A. Kaczmarek authored Sep 20, 2018
1 parent 8cdd18a commit 78a1a09
Showing 10 changed files with 650 additions and 99 deletions.
10 changes: 3 additions & 7 deletions requirements.txt
@@ -1,7 +1,3 @@
-steppy==0.1.4
-neptune-cli==2.8.5
-attrdict==2.0.0
-numpy==1.14.3
-pandas==0.23.0
-pytest==3.6.0
-setuptools==39.2.0
+neptune-cli>=2.8.0
+setuptools>=39.2.0
+steppy>=0.1.9
14 changes: 5 additions & 9 deletions setup.py
@@ -15,22 +15,18 @@

 setup(name='steppy-toolkit',
       packages=find_packages(),
-      version='0.1.8',
+      version='0.1.9',
       description='Set of tools to make your work with steppy faster and more effective.',
       long_description=long_description,
       url='https://github.com/minerva-ml/steppy-toolkit',
-      download_url='https://github.com/minerva-ml/steppy-toolkit/archive/0.1.8.tar.gz',
+      download_url='https://github.com/minerva-ml/steppy-toolkit/archive/0.1.9.tar.gz',
       author='Kamil A. Kaczmarek, Jakub Czakon',
       author_email='kamil.kaczmarek@neptune.ml, jakub.czakon@neptune.ml',
       keywords=['machine-learning', 'reproducibility', 'pipeline', 'tools'],
       license='MIT',
       install_requires=[
-          'steppy>=0.1.4',
-          'neptune-cli>=2.8.5',
-          'attrdict>=2.0.0',
-          'numpy>=1.14.0',
-          'pandas>=0.23.0',
-          'pytest>=3.6.0',
-          'setuptools>=39.2.0'],
+          'neptune-cli>=2.8.0',
+          'setuptools>=39.2.0',
+          'steppy>=0.1.9'],
       zip_safe=False,
       classifiers=[])
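
A quick sanity check that an environment picked up the bumped release (a hedged sketch; pkg_resources ships with setuptools, which is already a declared dependency here):

    import pkg_resources

    # Prints the installed version; expected to be 0.1.9 after this release.
    print(pkg_resources.get_distribution('steppy-toolkit').version)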
Empty file removed: tests/sklearn/__init__.py
66 changes: 0 additions & 66 deletions tests/sklearn/test_models.py

This file was deleted.

36 changes: 20 additions & 16 deletions toolkit/pytorch_transformers/callbacks.py
@@ -44,21 +44,21 @@ def on_epoch_begin(self, *args, **kwargs):
     def on_epoch_end(self, *args, **kwargs):
         self.epoch_id += 1

-    def training_break(self, *args, **kwargs):
-        return False
-
     def on_batch_begin(self, *args, **kwargs):
         pass

     def on_batch_end(self, *args, **kwargs):
         self.batch_id += 1

+    def training_break(self, *args, **kwargs):
+        return False
+
     def get_validation_loss(self):
         if self.validation_loss is None:
             self.validation_loss = {}
-        return self.validation_loss.setdefault(self.epoch_id, score_model(self.model,
-                                                                          self.loss_function,
-                                                                          self.validation_datagen))
+        if self.epoch_id not in self.validation_loss.keys():
+            self.validation_loss[self.epoch_id] = score_model(self.model,
+                                                              self.loss_function,
+                                                              self.validation_datagen)
+        return self.validation_loss[self.epoch_id]
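
One note on the get_validation_loss change above: dict.setdefault(key, default) evaluates its default argument before checking the key, so the old line called score_model on every invocation even when the epoch's result was already cached. The explicit membership test the commit switches to only computes on a cache miss. A minimal sketch of the difference (the expensive function is hypothetical, not from this repo):

    calls = 0

    def expensive():
        global calls
        calls += 1
        return 42

    cache = {}
    cache.setdefault(0, expensive())   # expensive() runs, result cached
    cache.setdefault(0, expensive())   # expensive() runs AGAIN: setdefault
                                       # evaluates its argument before the lookup
    assert calls == 2

    calls = 0
    if 1 not in cache:                 # the pattern adopted by this commit
        cache[1] = expensive()
    if 1 not in cache:
        cache[1] = expensive()         # skipped: key already present
    assert calls == 1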


 class CallbackList:
@@ -93,10 +93,6 @@ def on_epoch_end(self, *args, **kwargs):
         for callback in self.callbacks:
             callback.on_epoch_end(*args, **kwargs)

-    def training_break(self, *args, **kwargs):
-        callback_out = [callback.training_break(*args, **kwargs) for callback in self.callbacks]
-        return any(callback_out)
-
     def on_batch_begin(self, *args, **kwargs):
         for callback in self.callbacks:
             callback.on_batch_begin(*args, **kwargs)
@@ -105,6 +101,10 @@ def on_batch_end(self, *args, **kwargs):
         for callback in self.callbacks:
             callback.on_batch_end(*args, **kwargs)

+    def training_break(self, *args, **kwargs):
+        callback_out = [callback.training_break(*args, **kwargs) for callback in self.callbacks]
+        return any(callback_out)
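
For context on where training_break is polled: the aggregate version above returns True as soon as any registered callback requests a stop. A hedged sketch of the kind of loop that would consume it (the fit function and the caller-supplied step are illustrative, not part of this commit):

    def fit(model, loader, callbacks, step, max_epochs):
        """Illustrative loop only; steppy's real trainer lives elsewhere."""
        for epoch in range(max_epochs):
            callbacks.on_epoch_begin()
            for batch in loader:
                callbacks.on_batch_begin()
                step(model, batch)              # caller-supplied optimization step
                callbacks.on_batch_end()
            callbacks.on_epoch_end()            # e.g. EarlyStopping evaluates here
            if callbacks.training_break():      # stop if any callback asks to
                break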


 class TrainingMonitor(Callback):
     def __init__(self, epoch_every=None, batch_every=None):
@@ -176,8 +176,9 @@ def __init__(self, patience, minimize=True):
         self.minimize = minimize
         self.best_score = None
         self.epoch_since_best = 0
+        self._training_break = False

-    def training_break(self, *args, **kwargs):
+    def on_epoch_end(self, *args, **kwargs):
         self.model.eval()
         val_loss = self.get_validation_loss()
         loss_sum = val_loss['sum']
@@ -195,9 +196,12 @@ def training_break(self, *args, **kwargs):
             self.epoch_since_best += 1

         if self.epoch_since_best > self.patience:
-            return True
-            else:
-            return False
+            self._training_break = True
+
+        self.epoch_id += 1
+
+    def training_break(self, *args, **kwargs):
+        return self._training_break
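
This refactor moves EarlyStopping's validation pass into on_epoch_end and leaves training_break as a cheap read of the cached flag, so the trainer can poll it as often as it likes. A hedged usage sketch (the CallbackList(callbacks=...) signature is an assumption based on the class shown above):

    early_stopping = EarlyStopping(patience=5, minimize=True)
    callbacks = CallbackList(callbacks=[early_stopping])   # assumed constructor signature

    # After callbacks.on_epoch_end() has run for an epoch, the decision is cached:
    # with patience=5 and the epoch_since_best > patience test, the flag flips once
    # the validation loss has failed to improve for 6 consecutive epochs.
    if callbacks.training_break():
        print('early stopping triggered')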


class ExponentialLRScheduler(Callback):
File renamed without changes.
