Skip to content

Commit

Permalink
added tf scalar logging
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon committed Jun 29, 2019
1 parent 3e38dca commit 1998dec
Showing 1 changed file with 39 additions and 7 deletions.
46 changes: 39 additions & 7 deletions test_tube/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
debug=False,
version=None,
save_dir=None,
autosave=True,
autosave=False,
description=None,
create_git_tag=False,
*args, **kwargs
Expand All @@ -34,7 +34,6 @@ def __init__(
:param name:
:param debug:
"""
super(SummaryWriter, self).__init__(*args, **kwargs)

# change where the save dir is if requested
if save_dir is not None:
Expand All @@ -53,6 +52,8 @@ def __init__(
self.exp_hash = '{}_v{}'.format(self.name, version)
self.created_at = str(datetime.utcnow())

init_with_save = False

# update version hash if we need to increase version on our own
# we will increase the previous version, so do it now so the hash
# is accurate
Expand All @@ -71,7 +72,7 @@ def __init__(
# when no version and no file, create it
if not os.path.exists(self.__get_log_name()):
self.__create_exp_file(self.version)
self.save()
init_with_save = True
else:
# otherwise load it
self.__load()
Expand All @@ -81,7 +82,7 @@ def __init__(
old_version = self.__get_last_experiment_version()
self.version = old_version
self.__create_exp_file(self.version + 1)
self.save()
init_with_save = True

# create a git tag if requested
if self.create_git_tag == True:
Expand All @@ -92,7 +93,11 @@ def __init__(
print('Test tube created git tag:', 'tt_{}'.format(self.exp_hash))

# set the tensorboardx log path to the /tf folder in the exp folder
self.logdir = self.get_tensorboardx_path(self.name, self.version)
logdir = self.get_tensorboardx_path(self.name, self.version)
super().__init__(logdir=logdir, *args, **kwargs)

if init_with_save:
self.save()

def argparse(self, argparser):
parsed = vars(argparser)
Expand Down Expand Up @@ -185,26 +190,35 @@ def tag(self, tag_dict):
if self.autosave == True:
self.save()

def log(self, metrics_dict):
def log(self, metrics_dict, main_tag='', global_step=None, walltime=None):
"""
Adds a json dict of metrics.
>> e.log({"loss": 23, "coeff_a": 0.2})
:param metrics_dict:
:tag optional tfx tag
:return:
"""
if self.debug: return

# handle tfx metrics
if global_step is None:
global_step = len(self.metrics)
self.add_scalars(main_tag, metrics_dict, global_step, walltime)

# timestamp
if 'created_at' not in metrics_dict:
metrics_dict['created_at'] = str(datetime.utcnow())

self.__convert_numpy_types(metrics_dict)

self.metrics.append(metrics_dict)

if self.autosave == True:
self.save()


def __convert_numpy_types(self, metrics_dict):
for k, v in metrics_dict.items():
if v.__class__.__name__ == 'float32':
Expand Down Expand Up @@ -249,6 +263,10 @@ def save(self):
df = pd.DataFrame(self.metrics)
df.to_csv(metrics_file_path, index=False)

# whenever we save, we also save tfx

self.export_scalars_to_json(self.get_tensorboardx_scalars_path(self.name, self.version))

def __save_images(self, metrics):
"""
Save tags that have a png_ prefix (as images)
Expand Down Expand Up @@ -347,6 +365,15 @@ def get_tensorboardx_path(self, exp_name, exp_version):
"""
return os.path.join(self.get_data_path(exp_name, exp_version), 'tf')

def get_tensorboardx_scalars_path(self, exp_name, exp_version):
"""
Returns the path to the local package cache
:param path:
:return:
"""
tfx_path = self.get_tensorboardx_path(exp_name, exp_version)
return os.path.join(tfx_path, 'scalars.json')

# ----------------------------
# OVERWRITES
# ----------------------------
Expand All @@ -359,6 +386,11 @@ def __hash__(self):


if __name__ == '__main__':
import math
from time import sleep
e = Experiment()
e.log({'val_loss': 1})

for n_iter in range(2000):
e.log({'xsinx': n_iter * np.sin(n_iter)})
print('done')
e.save()

0 comments on commit 1998dec

Please sign in to comment.