Skip to content

Commit

Permalink
big fixes and minor stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
TomGeorge1234 committed Feb 10, 2024
1 parent 8a8910a commit ad34741
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ratinabox/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def sample_positions(self, n=10, method="uniform_jitter"):
n_remaining = n - n_uniformly_distributed
if n_remaining > 0:
# sample remaining from available positions with further jittering (delta = delta/2)
positions_remaining = np.array([positions[i] for i in np.random.choice(range(len(positions)),n_remaining, replace=False)])
positions_remaining = np.array([positions[i] for i in np.random.choice(range(len(positions)),n_remaining, replace=True)])
delta /= 2
positions_remaining += np.random.uniform(
-0.45 * delta, 0.45 * delta, positions_remaining.shape
Expand Down
16 changes: 14 additions & 2 deletions ratinabox/Neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,10 @@ def plot_rate_map(
Returns:
fig, ax
"""
#Set kwargs (TODO make lots of params accessible here as kwargs)
spikes_color = kwargs.get("spikes_color", self.color) or "C1"


# GET DATA
if method[:11] == "groundtruth":
try:
Expand Down Expand Up @@ -509,7 +513,7 @@ def plot_rate_map(
linewidth=0,
alpha=0.7,
zorder=1.2,
color=(self.color if self.color is not None else "C1"),
color=spikes_color,
)

# PLOT 1D
Expand Down Expand Up @@ -557,7 +561,7 @@ def plot_rate_map(
ax.scatter(
pos_where_spiked,
h,
color=(self.color if self.color is not None else "C1"),
color=spikes_color,
alpha=0.5,
s=2,
linewidth=0,
Expand Down Expand Up @@ -587,6 +591,7 @@ def plot_angular_rate_map(self, chosen_neurons="all", fig=None, ax=None, autosav
subplot_kw={"projection": "polar"},
)


# get rate maps at all head directions and all positions
# the object will end up having shape (n_neurons, n_positions, n_headdirections)
rm = np.zeros_like(self.get_state(evaluate_at='all',head_direction=np.array([1,0])))
Expand All @@ -611,6 +616,13 @@ def plot_angular_rate_map(self, chosen_neurons="all", fig=None, ax=None, autosav
ax[i].tick_params(pad=-20)
ax[i].set_xticklabels(["E", "N", "W", "S"])



for i, ax_ in enumerate(axes):
_, ax_ = self.Agent.Environment.plot_environment(
fig, ax_, autosave=False, **kwargs
)

ratinabox.utils.save_figure(fig, self.name + "_angularratemaps", save=autosave)

return fig, ax
Expand Down
2 changes: 1 addition & 1 deletion ratinabox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def rotate(vector, theta):
"""rotates a vector shape (2,) anticlockwise by angle theta .
Args:
vector (array): the 2d vector
theta (flaot): the rotation angle
theta (flaot): the rotation angle, radians
"""
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
vector_new = np.matmul(R, vector)
Expand Down

0 comments on commit ad34741

Please sign in to comment.