Skip to content

Commit 3debe8f

Browse files
rvar
1 parent 690c05c commit 3debe8f

File tree

3 files changed

+25
-12
lines changed

3 files changed

+25
-12
lines changed

momentum/functions.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,14 @@ def kurtosis_update(m: dict, x: float) -> dict:
4444
return m
4545

4646

47-
def rvar(m: dict, x: float = None, alpha=0.01, n=10):
48-
""" Mimics the skater style where we pass {} to initialize """
47+
def rvar(m: dict, x: float = None, rho=0.01, n=10):
48+
""" One function that performs either initialization or an update.
49+
Pass m={} to initialize
50+
"""
4951
if m:
5052
return rvar_update(m=m, x=x)
5153
else:
52-
return rvar_init(rho=alpha, n=n)
54+
return rvar_init(rho=rho, n=n)
5355

5456

5557
def rvar_init(rho: float, n=10) -> dict:
@@ -59,21 +61,21 @@ def rvar_init(rho: float, n=10) -> dict:
5961
"""
6062
assert 0 <= rho <= 1
6163
state = var_init()
62-
state.update({'alpha': rho, 'burnin': n})
64+
state.update({'rho': rho, 'n': n})
6365
return state
6466

6567

6668
def rvar_update(m: dict, x: float) -> dict:
67-
if m['count'] < m['burnin']:
68-
alpha = m['alpha']
69+
if m['count'] < m['n']:
70+
rho = m['rho']
6971
m = var_update(m, x)
70-
m['alpha'] = alpha
72+
m['rho'] = rho
7173
return m
7274
else:
7375
m['count'] += 1
74-
alpha = m['alpha']
75-
m['var'] = (1 - alpha) * (m['var'] + alpha * ((x - m['mean']) ** 2))
76-
m['mean'] = (1 - alpha) * m['mean'] + alpha * x
76+
rho = m['rho']
77+
m['var'] = (1 - rho) * (m['var'] + rho * ((x - m['mean']) ** 2))
78+
m['mean'] = (1 - rho) * m['mean'] + rho * x
7779
m['pvar'] = ((m['count'] - 1) / m['count']) * m['var'] # Not sure this really makes sense :)
7880
if m.get('M2'):
7981
del m['M2']

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
setup(
99
name="momentum",
10-
version="0.1.1",
10+
version="0.1.2",
1111
description="Running estimates of moments",
1212
long_description=README,
1313
long_description_content_type="text/markdown",

tests/test_forgetting.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from momentum.functions import rvar_init, rvar_update
1+
from momentum.functions import rvar_init, rvar_update, rvar
22
from pprint import pprint
33

44

@@ -12,6 +12,17 @@ def test_forget():
1212
assert -0.5<m['mean']<0.5
1313

1414

15+
def test_forget_other_style():
16+
import numpy as np
17+
xs = list(np.random.randn(5000))+list(2*np.random.randn(5000))
18+
m = rvar(m={},rho=0.01)
19+
for x in xs:
20+
m = rvar_update(m, x)
21+
assert 1.6<m['std']<2.4
22+
assert -0.5<m['mean']<0.5
23+
24+
25+
1526
if __name__=='__main__':
1627
test_forget()
1728

0 commit comments

Comments
 (0)