diff --git a/ratinabox/Environment.py b/ratinabox/Environment.py index 8f5b24d..3605704 100644 --- a/ratinabox/Environment.py +++ b/ratinabox/Environment.py @@ -609,7 +609,7 @@ def sample_positions(self, n=10, method="uniform_jitter"): n_remaining = n - n_uniformly_distributed if n_remaining > 0: # sample remaining from available positions with further jittering (delta = delta/2) - positions_remaining = np.array([positions[i] for i in np.random.choice(range(len(positions)),n_remaining, replace=False)]) + positions_remaining = np.array([positions[i] for i in np.random.choice(range(len(positions)),n_remaining, replace=True)]) delta /= 2 positions_remaining += np.random.uniform( -0.45 * delta, 0.45 * delta, positions_remaining.shape diff --git a/ratinabox/Neurons.py b/ratinabox/Neurons.py index 027561b..2de99a6 100644 --- a/ratinabox/Neurons.py +++ b/ratinabox/Neurons.py @@ -345,6 +345,10 @@ def plot_rate_map( Returns: fig, ax """ + #Set kwargs (TODO make lots of params accessible here as kwargs) + spikes_color = kwargs.get("spikes_color", self.color) or "C1" + + # GET DATA if method[:11] == "groundtruth": try: @@ -509,7 +513,7 @@ def plot_rate_map( linewidth=0, alpha=0.7, zorder=1.2, - color=(self.color if self.color is not None else "C1"), + color=spikes_color, ) # PLOT 1D @@ -557,7 +561,7 @@ def plot_rate_map( ax.scatter( pos_where_spiked, h, - color=(self.color if self.color is not None else "C1"), + color=spikes_color, alpha=0.5, s=2, linewidth=0, @@ -587,6 +591,7 @@ def plot_angular_rate_map(self, chosen_neurons="all", fig=None, ax=None, autosav subplot_kw={"projection": "polar"}, ) + # get rate maps at all head directions and all positions # the object will end up having shape (n_neurons, n_positions, n_headdirections) rm = np.zeros_like(self.get_state(evaluate_at='all',head_direction=np.array([1,0]))) @@ -611,6 +616,13 @@ def plot_angular_rate_map(self, chosen_neurons="all", fig=None, ax=None, autosav ax[i].tick_params(pad=-20) ax[i].set_xticklabels(["E", "N", "W", "S"]) + + + for i, ax_ in enumerate(axes): + _, ax_ = self.Agent.Environment.plot_environment( + fig, ax_, autosave=False, **kwargs + ) + ratinabox.utils.save_figure(fig, self.name + "_angularratemaps", save=autosave) return fig, ax diff --git a/ratinabox/utils.py b/ratinabox/utils.py index 17ace1b..b72df8f 100644 --- a/ratinabox/utils.py +++ b/ratinabox/utils.py @@ -293,7 +293,7 @@ def rotate(vector, theta): """rotates a vector shape (2,) anticlockwise by angle theta . Args: vector (array): the 2d vector - theta (flaot): the rotation angle + theta (flaot): the rotation angle, radians """ R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) vector_new = np.matmul(R, vector)