-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
142 lines (105 loc) · 3.47 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import functools
import heapq
import time
from contextlib import contextmanager
from blox import AttrDict, rmap, rmap_list
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        val: The most recent value passed to :meth:`update`.
        sum: Weighted running total (each value multiplied by its weight ``n``).
        count: Total accumulated weight.
        avg: Weighted mean ``sum / count`` (left at its previous value while
            ``count`` is zero).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The newly observed value.
            n: Weight of ``val``; it multiplies ``val`` in ``sum`` and is added
                to ``count``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when no weight has been accumulated yet
        # (e.g. a first update with n=0); the average is left unchanged.
        if self.count:
            self.avg = self.sum / self.count

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`update`, so the meter can be called directly."""
        self.update(*args, **kwargs)
class PriorityQueue:
    """Max-priority queue built on :mod:`heapq`.

    Items pushed with a higher ``priority`` are popped first; items sharing a
    priority are popped in the order they were pushed.
    """

    def __init__(self):
        self._queue = []
        # Insertion counter: breaks priority ties first-in-first-out and, being
        # unique, keeps heapq from ever comparing the items themselves.
        self._index = 0

    def push(self, item, priority):
        """Insert ``item`` with the given ``priority``."""
        # heapq is a min-heap, so the priority is negated to surface the largest.
        entry = (-priority, self._index, item)
        heapq.heappush(self._queue, entry)
        self._index = self._index + 1

    def pop(self):
        """Remove and return the highest-priority item (IndexError when empty)."""
        _neg_priority, _order, item = heapq.heappop(self._queue)
        return item
class RecursiveAverageMeter(object):
    """Running mean whose values are combined through blox's ``rmap`` helpers.

    The running total is accumulated with ``rmap_list`` and averaged with
    ``rmap``. These are presumably recursive maps over nested containers, so
    that ``val`` may be a structure rather than a scalar; confirm against blox.
    Every update carries weight 1.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all recorded values; statistics stay None until the first update."""
        self.val = None
        self.avg = None
        self.sum = None
        self.count = 0

    def update(self, val):
        """Record ``val`` and recompute ``sum`` and ``avg``."""
        self.val = val
        if self.sum is not None:
            # Add the new value onto the running total pairwise via rmap_list.
            self.sum = rmap_list(lambda acc, new: acc + new, [self.sum, val])
        else:
            # First value: it becomes the running total as-is (no copy is made).
            self.sum = val
        self.count += 1
        self.avg = rmap(lambda total: total / self.count, self.sum)
@contextmanager
def dummy_context():
    """No-op context manager: ``with dummy_context():`` just runs the body.

    Binds ``None`` when used with ``as``. Exceptions raised in the body
    propagate unchanged.
    """
    yield None
@contextmanager
def timing(text, name=None, interval=10):
    """Time the enclosed block and print the result in seconds.

    Args:
        text: Label printed before the measured time.
        name: If given, durations are accumulated in an ``AverageMeter`` stored
            as the attribute ``name`` on this function object (``timing.<name>``).
            The running average is printed only every ``interval`` measurements.
            If omitted, this run's elapsed time is printed.
        interval: Print period, in number of measurements, for named timers.

    Nothing is printed if the enclosed block raises.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system clock, which can be adjusted mid-measurement (e.g. by NTP) and
    # yield wrong or even negative durations.
    start = time.perf_counter()
    yield
    elapsed = time.perf_counter() - start
    if name:
        # Lazily create one meter per timer name and keep it on the function
        # object so that it persists across calls.
        if not hasattr(timing, name):
            setattr(timing, name, AverageMeter())
        meter = getattr(timing, name)
        meter.update(elapsed)
        if meter.count % interval == 0:
            print("{} {}".format(text, meter.avg))
        return
    print("{} {}".format(text, elapsed))
class timed:
    """Function decorator that prints the elapsed time of every call.

    Usage::

        @timed("load took")
        def load(): ...
    """

    def __init__(self, text):
        """
        Args:
            text: Label printed before the elapsed time of each call.
        """
        self.text = text

    def __call__(self, func):
        """Wrap ``func`` so that each call runs inside ``timing(self.text)``."""
        # functools.wraps copies __name__, __doc__, __module__, etc. and sets
        # __wrapped__, so the decorated function stays introspectable.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Returning inside the `with` still exits the timing context first,
            # so the elapsed time is printed before the result is handed back.
            with timing(self.text):
                return func(*args, **kwargs)
        return wrapper
def lazy_property(function):
    """Turn ``function`` into a read-only property computed at most once per instance.

    The first access calls ``function(self)`` and stores the result on the
    instance as ``_cache_<function name>``. Later accesses return that stored
    value. Adapted from Dani Hafner
    (https://danijar.com/structuring-your-tensorflow-models/).

    Note: ``functools.cached_property`` is the closest stdlib analogue, but it
    caches under the function's own name in the instance ``__dict__`` and is
    writable. This helper keeps the ``_cache_`` attribute and a property with
    no setter.
    """
    cache_name = '_cache_' + function.__name__

    def cached_getter(self):
        if hasattr(self, cache_name):
            return getattr(self, cache_name)
        setattr(self, cache_name, function(self))
        return getattr(self, cache_name)

    return property(functools.wraps(function)(cached_getter))
class HasParameters:
    """Mixin that fills ``self.params`` from ``get_default_params()`` plus kwargs.

    ``get_default_params`` is not defined here; subclasses are expected to
    provide it, and it must return an object with an ``update`` method.
    """
    def __init__(self, **kwargs):
        # All keyword arguments are forwarded to build_params as one dict.
        self.build_params(kwargs)

    def build_params(self, inputs):
        """Create ``self.params`` from defaults and apply ``inputs`` on top.

        Args:
            inputs: Mapping of parameter overrides (the constructor kwargs).
        """
        # If params undefined define params. The lookup also succeeds if
        # `params` already exists as an instance or class attribute, in which
        # case the defaults are not fetched.
        # NOTE(review): indentation was lost in this copy of the file; confirm
        # whether `update(inputs)` should run only when params was just created
        # (as written here) or also when params already existed.
        # NOTE(review): if get_default_params returns a shared object rather
        # than a fresh one, `update` mutates that shared object. Confirm.
        try:
            self.params
        except AttributeError:
            self.params = self.get_default_params()
            self.params.update(inputs)

    # TODO allow to access parameters by self.<param>
class ParamDict(AttrDict):
    """``AttrDict`` for parameters, with a verbose bulk-override helper."""

    def overwrite(self, new_params):
        """Assign every entry of ``new_params`` onto this dict as an attribute.

        Each override is announced on stdout before it is applied.

        Args:
            new_params: Mapping of parameter name to new value.

        Returns:
            ``self``, so that calls can be chained.
        """
        for key in new_params:
            value = new_params[key]
            print('overriding param {} to value {}'.format(key, value))
            setattr(self, key, value)
        return self