Skip to content

Commit 4af646d

Browse files
committed
self._history_arrays and its getter-function allow for fasting animations which repeatedly converting lists to arrays
1 parent abc66be commit 4af646d

File tree

2 files changed

+57
-31
lines changed

2 files changed

+57
-31
lines changed

ratinabox/Agent.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class Agent:
4545
• initialise_position_and_velocity()
4646
• get_history_slice()
4747
• get_all_default_params()
48-
cache_history_as_arrays()
48+
get_history_arrays()
4949
5050
The default params for this agent are:
5151
default_params = {
@@ -115,7 +115,9 @@ def __init__(self, Environment, params={}):
115115
self.history["vel"] = []
116116
self.history["rot_vel"] = []
117117
self.history["head_direction"] = []
118-
self.history_array_cache = {"last_cache_time":None} # this is used to cache the history data as an arrays for faster plotting/animating
118+
119+
self._last_history_array_cache_time = None
120+
self._history_arrays = {} # this is used to cache the history data as an arrays for faster plotting/animating
119121

120122
self.Neurons = [] # each new Neurons class belonging to this Agent will append itself to this list
121123

@@ -711,11 +713,10 @@ def plot_trajectory(
711713
#get times and trjectory from history data (normal)
712714
t_end = t_end or self_.history["t"][-1]
713715
slice = self_.get_history_slice(t_start=t_start, t_end=t_end, framerate=framerate)
714-
if (self_.history_array_cache["last_cache_time"] != self.t):
715-
self_.cache_history_as_arrays()
716-
time = self_.history_array_cache["t"][slice]
717-
trajectory = self_.history_array_cache["pos"][slice]
718-
head_direction = self_.history_array_cache["head_direction"][slice]
716+
history_data = self.get_history_arrays() # gets history dataframe as dictionary of arrays (only recomputing arrays from lists if necessary)
717+
time = history_data["t"][slice]
718+
trajectory = history_data["pos"][slice]
719+
head_direction = history_data["head_direction"][slice]
719720
else:
720721
# data has been passed in manually
721722
t_start, t_end = time[0], time[-1]
@@ -1068,9 +1069,7 @@ def get_history_slice(self, t_start=None, t_end=None, framerate=None):
10681069
• t_end: end time in seconds (default = self.history["t"][-1])
10691070
• framerate: frames per second (default = None --> step=0 so, just whatever the data frequency (1/Ag.dt) is)
10701071
"""
1071-
if self.history_array_cache["last_cache_time"] != self.t:
1072-
self.cache_history_as_arrays()
1073-
t = self.history_array_cache["t"]
1072+
t = self.get_history_arrays()["t"]
10741073
t_start = t_start or t[0]
10751074
startid = np.nanargmin(np.abs(t - (t_start)))
10761075
t_end = t_end or t[-1]
@@ -1081,14 +1080,14 @@ def get_history_slice(self, t_start=None, t_end=None, framerate=None):
10811080
skiprate = max(1, int((1 / framerate) / self.dt))
10821081

10831082
return slice(startid, endid, skiprate)
1084-
1085-
def cache_history_as_arrays(self):
1086-
"""Converts anything in the current history dictionary into a numpy array along with the time this cache was made. This is useful for speeding up animating functions which require slicing the history data but repeatedly converting to arrays is expensive. This is called automatically by the plot_trajectory function if the history data has not been cached yet.
1087-
TODO This should probably be improved, right now it will convert and cache _all_ history data, even if only some of it is needed."""
1088-
self.history_array_cache = {}
1089-
self.history_array_cache["last_cache_time"] = self.t
1090-
for key in self.history.keys():
1091-
try: #will skip if for any reason this key cannot be converted to an array, so you can still save random stuff into the history dict without breaking this function
1092-
self.history_array_cache[key] = np.array(self.history[key])
1093-
except: pass
1094-
return
1083+
1084+
def get_history_arrays(self):
1085+
"""Returns the history dataframe as a dictionary of numpy arrays (as opposed to lists). This getter-function only updates the self._history_arrays if the Agent/Neuron has updates since the last time it was called. This avoids expensive repeated conversion of lists to arrays during animations."""
1086+
if (self._last_history_array_cache_time != self.t):
1087+
self._history_arrays = {}
1088+
self._last_history_array_cache_time = self.t
1089+
for key in self.history.keys():
1090+
try: #will skip if for any reason this key cannot be converted to an array, so you can still save random stuff into the history dict without breaking this function
1091+
self._history_arrays[key] = np.array(self.history[key])
1092+
except: pass
1093+
return self._history_arrays

ratinabox/Neurons.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ def __init__(self, Agent, params={}):
124124
self.history["firingrate"] = []
125125
self.history["spikes"] = []
126126

127+
self._last_history_array_cache_time = None
128+
self._history_arrays = {} # this dictionary is the same as self.history except the data is in arrays not lists BUT it should only be accessed via its getter-function `self.get_history_arrays()`. This is because the lists are only converted to arrays when they are accessed, not on every step, so as to save time.
129+
127130
self.colormap = "inferno" # default colormap for plotting ratemaps
128131

129132
if ratinabox.verbose is True:
@@ -347,6 +350,7 @@ def plot_rate_map(
347350
"""
348351
#Set kwargs (TODO make lots of params accessible here as kwargs)
349352
spikes_color = kwargs.get("spikes_color", self.color) or "C1"
353+
bin_size = kwargs.get("bin_size", 0.04) #only relevant if you are plotting by method="history"
350354

351355

352356
# GET DATA
@@ -367,25 +371,25 @@ def plot_rate_map(
367371
method = "history"
368372

369373
if method == "history" or spikes == True:
370-
t = np.array(self.history["t"])
374+
history_data = self.get_history_arrays() # converts lists to arrays (if this wasn't just done) and returns them in a dict same as self.history but with arrays not lists
375+
t = history_data["t"]
371376
# times to plot
372377
if len(t) == 0:
373-
print(
374-
"Can't plot rate map by method='history' since there is no available data to plot. "
375-
)
378+
print("Can't plot rate map by method='history', nor plot spikes, since there is no available data to plot. ")
376379
return
377380
t_end = t_end or t[-1]
378381
position_data_agent = kwargs.get("position_data_agent", self.Agent) # In rare cases you may like to plot this cells rate/spike data against the position of a diffferent Agent. This kwarg enables that.
382+
position_agent_history_data = position_data_agent.get_history_arrays()
379383
slice = position_data_agent.get_history_slice(t_start, t_end)
380-
pos = np.array(position_data_agent.history["pos"])[slice]
384+
pos = position_agent_history_data["pos"][slice]
381385
t = t[slice]
382386

383387
if method == "history":
384-
rate_timeseries = np.array(self.history["firingrate"])[slice].T
388+
rate_timeseries = history_data["firingrate"][slice].T
385389
if len(rate_timeseries) == 0:
386390
print("No historical data with which to calculate ratemap.")
387391
if spikes == True:
388-
spike_data = np.array(self.history["spikes"])[slice].T
392+
spike_data = history_data["spikes"][slice].T
389393
if len(spike_data) == 0:
390394
print("No historical data with which to plot spikes.")
391395
if method == "ratemaps_provided":
@@ -468,21 +472,33 @@ def plot_rate_map(
468472
)
469473
im = ax_.imshow(rate_map, extent=ex, zorder=0, cmap=self.colormap)
470474
elif method == "history":
475+
bin_size = kwargs.get("bin_size", 0.05)
471476
rate_timeseries_ = rate_timeseries[chosen_neurons[i], :]
472-
rate_map = utils.bin_data_for_histogramming(
477+
rate_map, zero_bins = utils.bin_data_for_histogramming(
473478
data=pos,
474479
extent=ex,
475-
dx=0.05,
480+
dx=bin_size,
476481
weights=rate_timeseries_,
477482
norm_by_bincount=True,
483+
return_zero_bins=True,
478484
)
485+
#rather than just "nan-ing" the regions where no data was observed we'll plot ontop a "mask" overlay which blocks with a grey square regions where no data was observed. The benefit of this technique is it still allows us to use "bicubic" interpolation which is much smoother than the default "nearest" interpolation.
486+
binary_colors = [(0,0,0,0),ratinabox.LIGHTGREY] #transparent if theres data, grey if there isn't
487+
binary_cmap = matplotlib.colors.ListedColormap(binary_colors)
479488
im = ax_.imshow(
480489
rate_map,
481490
extent=ex,
482491
cmap=self.colormap,
483492
interpolation="bicubic",
484493
zorder=0,
485494
)
495+
no_data_mask = ax_.imshow(
496+
zero_bins,
497+
extent=ex,
498+
cmap=binary_cmap,
499+
interpolation="nearest",
500+
zorder=0.001,
501+
)
486502
ims.append(im)
487503
vmin, vmax = (
488504
min(vmin, np.min(rate_map)),
@@ -749,7 +765,18 @@ def return_list_of_neurons(self, chosen_neurons="all"):
749765
chosen_neurons = list(chosen_neurons.astype(int))
750766

751767
return chosen_neurons
752-
768+
769+
def get_history_arrays(self):
770+
"""Returns the history dataframe as a dictionary of numpy arrays (as opposed to lists). This getter-function only updates the self._history_arrays if the Agent/Neuron has updates since the last time it was called. This avoids expensive repeated conversion of lists to arrays during animations."""
771+
if (self._last_history_array_cache_time != self.Agent.t):
772+
self._history_arrays = {}
773+
self._last_history_array_cache_time = self.Agent.t
774+
for key in self.history.keys():
775+
try: #will skip if for any reason this key cannot be converted to an array, so you can still save random stuff into the history dict without breaking this function
776+
self._history_arrays[key] = np.array(self.history[key])
777+
except: pass
778+
return self._history_arrays
779+
753780

754781
"""Specific subclasses """
755782

0 commit comments

Comments
 (0)