-
Notifications
You must be signed in to change notification settings - Fork 2
/
grid_cell_demo.py
156 lines (134 loc) · 4.61 KB
/
grid_cell_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/python
import argparse
import numpy as np
import random
import itertools
import math
import scipy.signal
import time
import matplotlib.pyplot as plt
from grid_cells import GridCells
from nupic.encoders.coordinate import CoordinateEncoder
class Environment(object):
    """
    Environment is a 2D square, in first quadrant with corner at origin.

    An agent starts at the centre of the square and wanders through it with a
    constant step length (``speed``) and a randomly drifting heading
    (``angle``). Every position accepted by :meth:`move` is appended to
    ``course`` so the path can be plotted afterwards.
    """

    def __init__(self, size):
        """
        Args:
            size: Side length of the square. Valid coordinates satisfy
                ``0 <= x < size`` and ``0 <= y < size``.
        """
        self.size = size
        self.position = (size / 2, size / 2)
        # Length of every step the agent takes.
        self.speed = 2.0 ** .5
        # Current heading in radians.
        self.angle = 0
        # Positions accepted by move(), in order (the start position is not included).
        self.course = []

    def in_bounds(self, position):
        """Return True if the (x, y) ``position`` lies in [0, size) x [0, size)."""
        x, y = position
        return 0 <= x < self.size and 0 <= y < self.size

    def move(self, max_attempts=10000):
        """
        Advance the agent by one step of length ``self.speed``.

        The heading is perturbed by a uniform random rotation of at most
        +/- 2*pi/20 radians. If the resulting step would leave the square, a
        new heading is drawn uniformly from [0, 2*pi) and the step is retried
        (including a fresh perturbation) until an in-bounds step is found.

        Args:
            max_attempts: Maximum number of candidate steps to try before
                giving up. It guards against environments in which no step of
                length ``speed`` can stay inside the square.

        Raises:
            ValueError: if a step fails while the current position is itself
                outside the square (no retry could be trusted to succeed).
            RuntimeError: if no in-bounds step is found within
                ``max_attempts`` attempts.
        """
        max_rotation = 2 * math.pi / 20
        # The position only changes on success, so it is read once.
        x, y = self.position
        attempts = 0
        # A plain counter (not range) avoids building a list per call on Python 2.
        while attempts < max_attempts:
            attempts += 1
            self.angle += random.uniform(-max_rotation, max_rotation)
            new_position = (x + self.speed * math.cos(self.angle),
                            y + self.speed * math.sin(self.angle))
            if self.in_bounds(new_position):
                self.position = new_position
                self.course.append(self.position)
                return
            if not self.in_bounds(self.position):
                raise ValueError(
                    "current position %r is outside the environment" %
                    (self.position,))
            # Bounce off the wall: pick a completely new random heading.
            self.angle = random.uniform(0, 2 * math.pi)
        raise RuntimeError(
            "no in-bounds step of length %r found after %d attempts" %
            (self.speed, max_attempts))

    def plot_course(self, show=True):
        """
        Draw the recorded course as a black line in the figure named "Path".

        Axes are fixed to the extent of the square. Nothing is drawn when no
        moves have been recorded yet.

        Args:
            show: If True, call ``plt.show()`` after drawing.
        """
        plt.figure("Path")
        plt.ylim([0, self.size])
        plt.xlim([0, self.size])
        # zip(*[]) cannot be unpacked into two sequences, so skip empty courses.
        if self.course:
            xs, ys = zip(*self.course)
            plt.plot(xs, ys, 'k-')
        if show:
            plt.show()
def subplot_grid_shape(count):
    """
    Return ``(nrows, ncols)`` for laying out ``count`` subplots in a near-square grid.

    ``nrows`` is floor(sqrt(count)) and ``ncols`` is the smallest integer with
    ``nrows * ncols >= count``. Both are ints, as ``plt.subplot`` expects.

    Raises:
        ValueError: if ``count`` is not positive.
    """
    if count <= 0:
        raise ValueError("count must be positive, got %r" % (count,))
    nrows = int(count ** .5)
    # Integer ceiling division; math.ceil would return a float on Python 2.
    ncols = -(-count // nrows)
    return nrows, ncols


def plot_image_grid(title, images):
    """
    Draw each 2D array in ``images`` with ``plt.imshow`` in its own subplot.

    All subplots go into the figure named ``title``, laid out by
    :func:`subplot_grid_shape`. Does nothing (creates no figure) when
    ``images`` is empty.
    """
    if not images:
        return
    plt.figure(title)
    nrows, ncols = subplot_grid_shape(len(images))
    for subplot_idx, image in enumerate(images):
        plt.subplot(nrows, ncols, subplot_idx + 1)
        plt.imshow(image)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_time', type=int, default=1000 * 1000)
    # NOTE(review): --debug is parsed but never read in this script.
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    # Setup
    env = Environment(size=200)
    enc = CoordinateEncoder(w=75, n=2500)
    assert enc.w > 50
    # Radius passed to the encoder together with each rounded position.
    enc_radius = 5
    gcm = GridCells(
        b1=0.05,
        b2=0.05 / 3,
        inputDimensions=(enc.n,),
        columnDimensions=(100,),
        potentialPct=0.95,
        numActiveColumnsPerInhArea=.2 * 100,
        synPermInactiveDec=0.0008,
        synPermActiveInc=0.005,
        synPermConnected=0.25,
        stimulusThreshold=0,
        boostStrength=0.,
        globalInhibition=True,
        potentialRadius=enc.n,
        wrapAround=True,
    )

    def compute(learn=True):
        """
        Encode the agent's current position and run it through the grid cells.

        The position is rounded to the nearest integer coordinates before
        encoding.

        Returns:
            (enc_sdr, gc_act): float arrays of length ``enc.n`` and
            ``gcm.getNumColumns()`` filled by the encoder and the model.
        """
        # np.int was only an alias of the builtin int and was removed in NumPy 1.24.
        nearest_position = np.array(np.rint(env.position), dtype=int)
        enc_sdr = np.zeros(enc.n)
        enc.encodeIntoArray((nearest_position, enc_radius), enc_sdr)
        gc_act = np.zeros(gcm.getNumColumns())
        gcm.compute(enc_sdr, learn, gc_act)
        return enc_sdr, gc_act

    print("Training for %d cycles ..."%args.train_time)
    start_time = time.time()
    gcm.reset()
    for step in range(args.train_time):
        if step % 1000 == 0:
            print("Cycle %d"%step)
        env.move()
        compute()
    train_time = time.time()
    print("Elapsed time (training): %d seconds."%int(round(train_time - start_time)))

    print("Testing ...")
    # Show how the agent traversed the environment.
    env.plot_course(show=False)

    # Measure Receptive Fields: for every integer position, record the
    # response of a random sample of encoder bits and grid cells.
    enc_num_samples = 12
    gc_num_samples = 20
    enc_samples = random.sample(range(enc.n), enc_num_samples)
    gc_samples = random.sample(range(gcm.getNumColumns()), gc_num_samples)
    enc_rfs = [np.zeros((env.size, env.size)) for _ in enc_samples]
    gc_rfs = [np.zeros((env.size, env.size)) for _ in gc_samples]
    for position in itertools.product(range(env.size), range(env.size)):
        env.position = position
        gcm.reset()
        enc_sdr, gc_sdr = compute(learn=False)
        # NOTE: fields are indexed [x, y]; imshow draws the first axis
        # vertically, so these images are transposed relative to "Path".
        for rf, enc_idx in zip(enc_rfs, enc_samples):
            rf[position] = enc_sdr[enc_idx]
        for rf, gc_idx in zip(gc_rfs, gc_samples):
            rf[position] = gc_sdr[gc_idx]

    # Show the Input/Encoder Receptive Fields.
    plot_image_grid("Input Receptive Fields", enc_rfs)
    # Show the Grid Cells Receptive Fields.
    plot_image_grid("Grid Cell Receptive Fields", gc_rfs)
    # Show the autocorrelations of the grid cell receptive fields.
    # correlate(a, a) == convolve(a, a[::-1, ::-1]) for real arrays; the FFT
    # route avoids correlate2d's direct O(N^4) cost on a 200x200 field.
    plot_image_grid(
        "Grid Cell RF Autocorrelations",
        [scipy.signal.fftconvolve(rf, rf[::-1, ::-1]) for rf in gc_rfs])

    test_time = time.time()
    print("Elapsed time (testing): %d seconds."%int(round(test_time - train_time)))
    plt.show()