-
Notifications
You must be signed in to change notification settings - Fork 0
/
statemodifier.py
48 lines (36 loc) · 1.22 KB
/
statemodifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
class DefaultModifier:
    """Pass-through state modifier: states are returned exactly as given.

    Keeps no internal statistics, so its checkpoint hooks are no-ops.
    """

    def __init__(self):
        """Create the modifier; there is no state to initialise."""

    def apply(self, state):
        """Return ``modify(state)``; nothing is tracked between calls."""
        modified = self.modify(state)
        return modified

    def modify(self, state):
        """Return ``state`` itself, untouched."""
        return state

    def set_ckpt(self, ckpt):
        """Accept a checkpoint mapping and ignore it."""
        return None

    def get_ckpt(self):
        """Return a fresh empty dict, as there is nothing to persist."""
        return dict()
class ClassicModifier():
    """Normalise states with running, per-element statistics.

    Every ``apply`` call folds the new state into a running mean and a
    running sum of squared deviations (Welford's online update).
    ``modify`` converts a state to a z-score with the current statistics
    and clamps the result to ``[-5, 5]``.
    """

    def __init__(self):
        # Number of states folded into the statistics so far.
        self._n = 0
        # Defined up front so get_ckpt() works before the first apply().
        self._mean = None
        # NOTE: despite the name, this holds the running *sum of squared
        # deviations* (M2); modify() divides by (n - 1) and takes the
        # square root. The name is kept for checkpoint compatibility.
        self._std = None

    @staticmethod
    def _zeros_like(state):
        """Return zeros with ``state``'s shape and device, in a float dtype."""
        if state.is_floating_point():
            dtype = state.dtype
        else:
            dtype = torch.get_default_dtype()
        return torch.zeros_like(state, dtype=dtype)

    def apply(self, state):
        """Update the running statistics with ``state`` and normalise it.

        Args:
            state: tensor of any shape; all later states must be
                broadcast-compatible with the first one.

        Returns:
            The normalised, clamped state (see ``modify``).
        """
        self._n += 1
        if self._n == 1:
            # Copy so that later in-place edits of the caller's tensor
            # cannot corrupt the stored mean.
            self._mean = state.clone()
            self._std = self._zeros_like(state)
        else:
            prev_mean = self._mean
            self._mean = prev_mean + (state - prev_mean) / self._n
            # M2 update uses both the old and the new mean.
            self._std = self._std + (state - prev_mean) * (state - self._mean)
        return self.modify(state)

    def modify(self, state):
        """Normalise ``state`` with the current statistics, without updating.

        Returns:
            ``state`` itself when no sample has been seen; zeros when exactly
            one sample has been seen (the variance is undefined); otherwise
            the z-score based on the sample variance, clamped to ``[-5, 5]``.
        """
        if self._n == 0:
            return state
        if self._n == 1:
            norm = self._zeros_like(state)
        else:
            variance = torch.div(self._std, self._n - 1)
            # The 1e-8 keeps the division finite for zero-variance elements.
            norm = (state - self._mean) / (1e-8 + torch.sqrt(variance))
        return torch.clamp(norm, -5, 5)

    def set_ckpt(self, ckpt):
        """Restore the statistics from a dict produced by ``get_ckpt``."""
        self._n = ckpt['mod_n']
        self._mean = ckpt['mod_mean']
        self._std = ckpt['mod_std']

    def get_ckpt(self):
        """Return the statistics as a dict with keys ``mod_n``, ``mod_mean``, ``mod_std``.

        Before the first ``apply`` call, the mean and std entries are ``None``.
        """
        return {'mod_n' : self._n, 'mod_mean' : self._mean, 'mod_std' : self._std}