Skip to content

Commit

Permalink
Decoder Testing + Updates
Browse files Browse the repository at this point in the history
- functions/decoder_tuning.py: Updated function calls and tests of no-taste augmentation and other parameters in decoding.
- functions/decoding_funcs.py: Updated visualization of decoded intervals to include the full decode bin (using the parameter from the parameter dictionary under bayes_params/seg_e_len_time
- functions/dependent_bayes_analysis.py: added/renamed parameters, updated function calls
- functions/dependent_bayes_deviation_analysis.py: added changed parameter calls + began work on making functional
- functions/dependent_decoding_funcs.py: Added no-taste augmentation, added start to deviation decode code
- params/analysis_params_template.json: renamed parameters and separated taste segment decoding time bin from other segment decoding time bin
- utils/test_support.py: continued testing updates.
  • Loading branch information
hfgem committed Oct 15, 2024
1 parent 49484ba commit b416c3d
Show file tree
Hide file tree
Showing 7 changed files with 431 additions and 322 deletions.
179 changes: 108 additions & 71 deletions functions/decoder_tuning.py

Large diffs are not rendered by default.

44 changes: 32 additions & 12 deletions functions/decoding_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,10 @@ def taste_fr_dist_zscore(num_neur, num_cp, tastant_spike_times, segment_spike_ti

def plot_decoded(fr_dist, num_tastes, num_neur, segment_spike_times, tastant_spike_times,
start_dig_in_times, end_dig_in_times, post_taste_dt, pre_taste_dt,
cp_raster_inds, bin_dt, dig_in_names, segment_times,
cp_raster_inds, z_bin_dt, dig_in_names, segment_times,
segment_names, taste_num_deliv, taste_select_epoch,
save_dir, max_decode, max_hz, seg_stat_bin,
neuron_count_thresh, trial_start_frac=0,
neuron_count_thresh, e_len_dt, trial_start_frac=0,
epochs_to_analyze=[], segments_to_analyze=[],
decode_prob_cutoff=0.95):
"""Function to plot the periods when something other than no taste is
Expand All @@ -475,7 +475,8 @@ def plot_decoded(fr_dist, num_tastes, num_neur, segment_spike_times, tastant_spi
(num_cp, num_segments, num_tastes))
epoch_seg_taste_times_best = np.zeros((num_cp, num_segments, num_tastes))
epoch_seg_lengths = np.zeros((num_cp, num_segments, num_tastes))
half_bin_z_dt = np.floor(bin_dt/2).astype('int')
half_bin_z_dt = np.floor(z_bin_dt/2).astype('int')
half_bin_decode_dt = np.floor(e_len_dt/2).astype('int')

if len(epochs_to_analyze) == 0:
epochs_to_analyze = np.arange(num_cp)
Expand Down Expand Up @@ -565,11 +566,25 @@ def plot_decoded(fr_dist, num_tastes, num_neur, segment_spike_times, tastant_spi
# Store binary decoding results
decoded_taste_bin = np.zeros((num_tastes, len(decoded_taste_max)))
for t_i in range(num_tastes):
decoded_taste_bin[t_i, np.where(
decoded_taste_max == t_i)[0]] = 1
times_decoded_taste_max = np.where(decoded_taste_max == t_i)[0]
#Now spread these times based on the full decoding bin
for diff_i in np.arange(-1*half_bin_decode_dt,half_bin_decode_dt):
times_decoded_shifted = times_decoded_taste_max + diff_i
times_decoded_shifted = times_decoded_shifted[np.where((times_decoded_shifted>0)*(times_decoded_shifted<len(decoded_taste_max)))[0]]
decoded_taste_bin[t_i, times_decoded_shifted] = 1
# To ensure starts and ends of bins align
decoded_taste_bin[:, 0] = 0
decoded_taste_bin[:, -1] = 0

#Test for periods that are overlapping and remove from decoded_taste_bin
summed_decode = np.sum(decoded_taste_bin,0)
overlap_bin = (summed_decode > 1).astype('int')
overlap_diff = np.diff(overlap_bin)
overlap_starts = np.where(overlap_diff == 1)[0]+1
overlap_ends = np.where(overlap_diff == -1)[0]+1
for o_i in range(len(overlap_starts)):
decoded_taste_bin[:,overlap_starts[o_i]:overlap_ends[o_i]] = 0
decoded_taste_bin[-1,overlap_starts[o_i]:overlap_ends[o_i]] = 1

# For each taste (except none) calculate start and end times of decoded intervals and plot
all_taste_fr_vecs = []
Expand Down Expand Up @@ -670,20 +685,25 @@ def plot_decoded(fr_dist, num_tastes, num_neur, segment_spike_times, tastant_spi
decoded_taste_prob = seg_decode_epoch_prob[t_i, :]
decoded_taste[decoded_taste_prob < decode_prob_cutoff] = 0
diff_decoded_taste = np.diff(decoded_taste)
start_decoded = np.where(diff_decoded_taste == 1)[0] + 1
end_decoded = np.where(diff_decoded_taste == -1)[0] + 1
start_decoded = np.where(diff_decoded_taste == 1)[0] + 1 - half_bin_decode_dt
end_decoded = np.where(diff_decoded_taste == -1)[0] + 1 + half_bin_decode_dt
num_decoded = len(start_decoded)
num_neur_decoded = np.zeros(num_decoded)
prob_decoded = np.zeros(num_decoded)
for nd_i in range(num_decoded):
d_start = start_decoded[nd_i]
if d_start < 0:
d_start = 0
d_end = end_decoded[nd_i]
if d_end > seg_len:
d_end = seg_len
d_len = d_end-d_start
for n_i in range(num_neur):
if len(np.where(segment_spike_times_s_i_bin[n_i, d_start:d_end])[0]) > 0:
num_neur_decoded[nd_i] += 1
prob_decoded[nd_i] = np.mean(
seg_decode_epoch_prob[t_i, d_start:d_end])
if d_len > 0:
for n_i in range(num_neur):
if len(np.where(segment_spike_times_s_i_bin[n_i, d_start:d_end])[0]) > 0:
num_neur_decoded[nd_i] += 1
prob_decoded[nd_i] = np.mean(
seg_decode_epoch_prob[t_i, d_start:d_end])

# Save the percent taste decoded matching threshold
epoch_seg_taste_times[e_i, s_i, t_i] = np.sum(decoded_taste)
Expand Down
48 changes: 11 additions & 37 deletions functions/dependent_bayes_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ def gather_variables(self,):
self.dig_in_names = self.data_dict['dig_in_names']
self.num_tastes = len(self.dig_in_names)
#Bayes Params/Variables
self.skip_time = self.metadata['params_dict']['bayes_params']['skip_time']
self.skip_dt = np.ceil(self.skip_time*1000).astype('int')
self.e_skip_time = self.metadata['params_dict']['bayes_params']['e_skip_time']
self.e_skip_dt = np.ceil(self.e_skip_time*1000).astype('int')
self.fr_bins = self.metadata['params_dict']['bayes_params']['fr_bins']
self.e_len_time = self.metadata['params_dict']['bayes_params']['e_len_time']
self.e_len_dt = np.ceil(self.e_len_time*1000).astype('int')
self.taste_e_len_time = self.metadata['params_dict']['bayes_params']['taste_e_len_time']
self.taste_e_len_dt = np.ceil(self.e_len_time*1000).astype('int')
self.seg_e_len_time = self.metadata['params_dict']['bayes_params']['seg_e_len_time']
self.seg_e_len_dt = np.ceil(self.seg_e_len_time*1000).astype('int')
self.bayes_fr_bins = self.metadata['params_dict']['bayes_params']['fr_bins']
self.neuron_count_thresh = self.metadata['params_dict']['bayes_params']['neuron_count_thresh']
self.max_decode = self.metadata['params_dict']['bayes_params']['max_decode']
self.seg_stat_bin = self.metadata['params_dict']['bayes_params']['seg_stat_bin']
Expand All @@ -93,7 +94,7 @@ def gather_variables(self,):
def pull_fr_dist(self,):
print("\tPulling FR Distributions")
tastant_fr_dist_pop, taste_num_deliv, max_hz_pop = ddf.taste_fr_dist(self.num_neur, self.tastant_spike_times,
self.pop_taste_cp_raster_inds, self.fr_bins,
self.pop_taste_cp_raster_inds, self.bayes_fr_bins,
self.start_dig_in_times, self.pre_taste_dt,
self.post_taste_dt, self.trial_start_frac)
self.tastant_fr_dist_pop = tastant_fr_dist_pop
Expand All @@ -102,10 +103,9 @@ def pull_fr_dist(self,):
tastant_fr_dist_z_pop, taste_num_deliv, max_hz_z_pop, min_hz_z_pop = ddf.taste_fr_dist_zscore(self.num_neur, self.tastant_spike_times,
self.segment_spike_times, self.segment_names,
self.segment_times, self.pop_taste_cp_raster_inds,
self.fr_bins, self.start_dig_in_times, self.pre_taste_dt,
self.bayes_fr_bins, self.start_dig_in_times, self.pre_taste_dt,
self.post_taste_dt, self.bin_dt, self.trial_start_frac)
self.tastant_fr_dist_z_pop = tastant_fr_dist_z_pop
self.taste_num_deliv = taste_num_deliv
self.max_hz_z_pop = max_hz_z_pop
self.min_hz_z_pop = min_hz_z_pop

Expand All @@ -121,7 +121,7 @@ def run_decoder_tuning(self,):
self.tastant_spike_times, self.cur_dist,
self.pop_taste_cp_raster_inds, self.pre_taste_dt, self.post_taste_dt,
self.epochs_to_analyze, self.select_neur, self.e_skip_dt,
self.e_len_dt, self.main_decode_dir)
self.taste_e_len_dt, self.main_decode_dir)


def decode_all_neurons(self,):
Expand All @@ -137,27 +137,14 @@ def decode_all_neurons(self,):
if os.path.isdir(self.decode_dir) == False:
os.mkdir(self.decode_dir)
ddf.decode_epochs(self.cur_dist, self.segment_spike_times, self.post_taste_dt,
self.e_skip_dt, self.e_len_dt, self.dig_in_names,
self.e_skip_dt, self.seg_e_len_dt, self.dig_in_names,
self.segment_times, self.segment_names, self.start_dig_in_times,
self.taste_num_deliv, self.select_neur, self.max_hz_pop,
self.decode_dir, self.neuron_count_thresh, self.trial_start_frac,
False, self.epochs_to_analyze, self.segments_to_analyze)

self.plot_decode_results()

#Run nb decoder over rest intervals
self.decode_dir = self.main_decode_dir + 'nb/'
if os.path.isdir(self.decode_dir) == False:
os.mkdir(self.decode_dir)
ddf.decode_epochs_nb(self.cur_dist, self.segment_spike_times, self.post_taste_dt,
self.e_skip_dt, self.e_len_dt, self.dig_in_names,
self.segment_times, self.segment_names, self.start_dig_in_times,
self.taste_num_deliv, self.select_neur, self.max_hz_pop,
self.decode_dir, self.neuron_count_thresh, self.trial_start_frac,
self.epochs_to_analyze, self.segments_to_analyze)

self.plot_decode_results()

def decode_selective_neurons(self,):
print("\tDecoding taste selective neurons")

Expand Down Expand Up @@ -190,7 +177,7 @@ def decode_all_neurons_zscored(self,):
os.mkdir(self.decode_dir)

ddf.decode_epochs(self.cur_dist, self.segment_spike_times, self.post_taste_dt,
self.e_skip_dt, self.e_len_dt, self.dig_in_names,
self.e_skip_dt, self.seg_e_len_dt, self.dig_in_names,
self.segment_times, self.segment_names, self.start_dig_in_times,
self.taste_num_deliv, self.select_neur, self.max_hz_z_pop,
self.decode_dir, self.neuron_count_thresh, self.trial_start_frac,
Expand All @@ -199,19 +186,6 @@ def decode_all_neurons_zscored(self,):

self.plot_decode_results()

#Run nb decoder over rest intervals
self.decode_dir = self.main_decode_dir + 'nb/'
if os.path.isdir(self.decode_dir) == False:
os.mkdir(self.decode_dir)
ddf.decode_epochs_nb_zscore(self.cur_dist, self.segment_spike_times, self.post_taste_dt,
self.e_skip_dt, self.e_len_dt, self.dig_in_names, self.segment_times,
self.bin_dt, self.segment_names, self.start_dig_in_times,
self.taste_num_deliv, self.select_neur, self.max_hz_z_pop,
self.decode_dir, self.neuron_count_thresh, self.trial_start_frac,
self.epochs_to_analyze,self.segments_to_analyze)

self.plot_decode_results()

def decode_selective_neurons_zscored(self,):
print("\tDecoding taste selective neurons")

Expand All @@ -223,7 +197,7 @@ def decode_selective_neurons_zscored(self,):
self.cur_dist = self.tastant_fr_dist_z_pop

ddf.decode_epochs_zscore(self.cur_dist, self.segment_spike_times, self.post_taste_dt,
self.e_skip_dt, self.e_len_dt, self.dig_in_names, self.segment_times,
self.e_skip_dt, self.seg_e_len_dt, self.dig_in_names, self.segment_times,
self.bin_dt, self.segment_names, self.start_dig_in_times,
self.taste_num_deliv, self.select_neur, self.max_hz_z_pop,
self.decode_dir, self.neuron_count_thresh, self.trial_start_frac,
Expand Down
28 changes: 14 additions & 14 deletions functions/dependent_bayes_deviation_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ def gather_variables(self,):
self.min_dev_size = self.metadata['params_dict']['min_dev_size']
# Decoding Params/Variables
self.e_skip_time = self.metadata['params_dict']['bayes_params']['e_skip_time']
self.e_len_time = self.metadata['params_dict']['bayes_params']['e_len_time']
self.e_skip_dt = np.ceil(self.e_skip_time*1000).astype('int')
self.e_len_dt = np.ceil(self.e_len_time*1000).astype('int')
self.fr_bins = self.metadata['params_dict']['bayes_params']['fr_bins']
self.taste_e_len_time = self.metadata['params_dict']['bayes_params']['taste_e_len_time']
self.taste_e_len_dt = np.ceil(self.e_len_time*1000).astype('int')
self.seg_e_len_time = self.metadata['params_dict']['bayes_params']['seg_e_len_time']
self.seg_e_len_dt = np.ceil(self.seg_e_len_time*1000).astype('int')
self.bayes_fr_bins = self.metadata['params_dict']['bayes_params']['fr_bins']
self.neuron_count_thresh = self.metadata['params_dict']['bayes_params']['neuron_count_thresh']
self.max_decode = self.metadata['params_dict']['bayes_params']['max_decode']
self.seg_stat_bin = self.metadata['params_dict']['bayes_params']['seg_stat_bin']
Expand Down Expand Up @@ -118,20 +120,18 @@ def import_deviations(self,):

def pull_fr_dist(self,):
print("\tPulling FR Distributions")
tastant_fr_dist_pop, taste_num_deliv, max_hz_pop = ddf.taste_fr_dist(self.num_neur,
self.num_cp, self.tastant_spike_times,
self.pop_taste_cp_raster_inds,
self.start_dig_in_times, self.pre_taste_dt,
self.post_taste_dt, self.trial_start_frac)
tastant_fr_dist_pop, taste_num_deliv, max_hz_pop = ddf.taste_fr_dist(self.num_neur, self.tastant_spike_times,
self.pop_taste_cp_raster_inds, self.bayes_fr_bins,
self.start_dig_in_times, self.pre_taste_dt,
self.post_taste_dt, self.trial_start_frac)
self.tastant_fr_dist_pop = tastant_fr_dist_pop
self.taste_num_deliv = taste_num_deliv
self.max_hz_pop = max_hz_pop
tastant_fr_dist_z_pop, _, max_hz_z_pop, min_hz_z_pop = ddf.taste_fr_dist_zscore(self.num_neur,
self.num_cp, self.tastant_spike_times,
self.segment_spike_times, self.segment_names,
self.segment_times, self.pop_taste_cp_raster_inds,
self.start_dig_in_times, self.pre_taste_dt,
self.post_taste_dt, self.bin_dt, self.trial_start_frac)
tastant_fr_dist_z_pop, taste_num_deliv, max_hz_z_pop, min_hz_z_pop = ddf.taste_fr_dist_zscore(self.num_neur, self.tastant_spike_times,
self.segment_spike_times, self.segment_names,
self.segment_times, self.pop_taste_cp_raster_inds,
self.bayes_fr_bins, self.start_dig_in_times, self.pre_taste_dt,
self.post_taste_dt, self.bin_dt, self.trial_start_frac)
self.tastant_fr_dist_z_pop = tastant_fr_dist_z_pop
self.max_hz_z_pop = max_hz_z_pop
self.min_hz_z_pop = min_hz_z_pop
Expand Down
Loading

0 comments on commit b416c3d

Please sign in to comment.