From c7c2f5bc7225dd5aa03b4c97cece28ca5b818d90 Mon Sep 17 00:00:00 2001 From: Philipp Hornauer Date: Thu, 10 Oct 2024 15:59:47 +0200 Subject: [PATCH] Updated and ran pre-commit hook --- .gitignore | 2 +- .pre-commit-config.yaml | 21 + .../spike_sorting-checkpoint.py | 363 ++++++++------ axon_tracking/quantification.py | 59 ++- axon_tracking/skeletonization.py | 441 +++++++++++------- axon_tracking/spike_sorting.py | 371 +++++++++------ axon_tracking/template_extraction.py | 407 ++++++++++------ axon_tracking/visualization.py | 428 ++++++++++------- 8 files changed, 1332 insertions(+), 760 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.gitignore b/.gitignore index 48eb5e5..5b00553 100644 --- a/.gitignore +++ b/.gitignore @@ -141,4 +141,4 @@ dmypy.json # Neuron *.c -*.o \ No newline at end of file +*.o diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..3530f8b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,21 @@ +# .pre-commit-config.yaml +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.1.0 # this is optional, use `pre-commit autoupdate` to get the latest rev! + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-added-large-files + +- repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + +- repo: https://github.com/adamchainz/blacken-docs + rev: v1.12.1 # replace with latest tag on GitHub + hooks: + - id: blacken-docs + additional_dependencies: + - black==24.10.0 diff --git a/axon_tracking/.ipynb_checkpoints/spike_sorting-checkpoint.py b/axon_tracking/.ipynb_checkpoints/spike_sorting-checkpoint.py index 8b7e452..1039946 100644 --- a/axon_tracking/.ipynb_checkpoints/spike_sorting-checkpoint.py +++ b/axon_tracking/.ipynb_checkpoints/spike_sorting-checkpoint.py @@ -5,9 +5,17 @@ from tqdm import tqdm from glob import glob -def sort_recording_list(path_list, save_path_changes, sorter, sorter_params = dict(), clear_files=True, verbose=True): + +def sort_recording_list( + path_list, + save_path_changes, + sorter, + sorter_params=dict(), + clear_files=True, + verbose=True, +): """ - Function that iterates over a list of axon scans, finds common electrodes, concatenates and spike sorts the recording slices. + Function that iterates over a list of axon scans, finds common electrodes, concatenates and spike sorts the recording slices. Arguments ---------- @@ -27,30 +35,41 @@ def sort_recording_list(path_list, save_path_changes, sorter, sorter_params = di Returns ---------- - sorting_list: list of sorting objects + sorting_list: list of sorting objects Specific type depends on the sorter. 
- + """ - + sorting_list = [] - + for rec_path in tqdm(path_list, desc="Sorting recordings"): - + h5 = h5py.File(rec_path) - #Check that all wells are recorded throughout all recordings (should not fail) - stream_ids = list(h5['wells'].keys()) + # Check that all wells are recorded throughout all recordings (should not fail) + stream_ids = list(h5["wells"].keys()) save_root = convert_rec_path_to_save_path(rec_path, save_path_changes) - + for stream_id in tqdm(stream_ids, desc="Sorting wells"): - sorter_output_file = Path(os.path.join(save_root, stream_id, 'sorter_output', 'amplitudes.npy')) + sorter_output_file = Path( + os.path.join(save_root, stream_id, "sorter_output", "amplitudes.npy") + ) if not os.path.exists(sorter_output_file): - multirecording, common_el = concatenate_recording_slices(rec_path, stream_id) - sorting = clean_sorting(multirecording, save_root, stream_id, sorter, sorter_params, clear_files=clear_files, verbose=verbose) + multirecording, common_el = concatenate_recording_slices( + rec_path, stream_id + ) + sorting = clean_sorting( + multirecording, + save_root, + stream_id, + sorter, + sorter_params, + clear_files=clear_files, + verbose=verbose, + ) sorting_list.append(sorting) - - return sorting_list + return sorting_list def convert_rec_path_to_save_path(rec_path, save_path_changes): @@ -64,32 +83,33 @@ def convert_rec_path_to_save_path(rec_path, save_path_changes): save_path_changes: dict Dictionary containing keys 'pos' and 'vals' that indicate the changes to be made to the rec_path. Refer to the inidices after splitting the path by '/'. - + Returns ---------- save_path: str Root save path. Well ID will be appended during the sorting. """ - path_parts = rec_path.split('/') - for x,y in zip(save_path_changes['pos'], save_path_changes['vals']): + path_parts = rec_path.split("/") + for x, y in zip(save_path_changes["pos"], save_path_changes["vals"]): path_parts[x] = y - + save_path = os.path.join(*path_parts) - + return save_path + def find_common_electrodes(rec_path, stream_id): """ Function that returns the common electrodes of the successive axon scan recordings. - + Arguments ---------- rec_path: str Path to the axon scan file. stream_id: str Well ID in the format "well***"; Well 1 would be "well001", Well 20 would be "well020" - + Returns ---------- rec_names: list @@ -97,36 +117,36 @@ def find_common_electrodes(rec_path, stream_id): common_el: list List of electrodes that are present in all axon scan recordings. 
""" - - assert(os.path.exists(rec_path)) - + + assert os.path.exists(rec_path) + h5 = h5py.File(rec_path) - rec_names = list(h5['wells'][stream_id].keys()) + rec_names = list(h5["wells"][stream_id].keys()) pos = dict() - x, y = np.full([1,26400], np.nan), np.full([1,26400], np.nan) - + x, y = np.full([1, 26400], np.nan), np.full([1, 26400], np.nan) for i, rec_name in enumerate(rec_names): - #rec_name = 'rec' + '%0*d' % (4, rec_id) - rec = si.MaxwellRecordingExtractor(rec_path, stream_id=stream_id, rec_name=rec_name) + # rec_name = 'rec' + '%0*d' % (4, rec_id) + rec = si.MaxwellRecordingExtractor( + rec_path, stream_id=stream_id, rec_name=rec_name + ) rec_el = rec.get_property("contact_vector")["electrode"] - x[:,rec_el] = rec.get_property("contact_vector")['x'] - y[:,rec_el] = rec.get_property("contact_vector")['y'] - + x[:, rec_el] = rec.get_property("contact_vector")["x"] + y[:, rec_el] = rec.get_property("contact_vector")["y"] + if i == 0: common_el = rec_el else: common_el = list(set(common_el).intersection(rec_el)) - pos = {'x': x[0], 'y': y[0]} - - return rec_names, common_el, pos + pos = {"x": x[0], "y": y[0]} + return rec_names, common_el, pos def concatenate_recording_slices(rec_path, stream_id, center=True): """ - Function that centers and concatenates the recordings of an axon scan for all common electrodes. + Function that centers and concatenates the recordings of an axon scan for all common electrodes. Arguments ---------- @@ -134,7 +154,7 @@ def concatenate_recording_slices(rec_path, stream_id, center=True): Path to the axon scan file. stream_id: str Well ID in the format "well***"; Well 1 would be "well001", Well 20 would be "well020" - + Returns ---------- multirecording: ConcatenatedRecordingSlice @@ -143,56 +163,76 @@ def concatenate_recording_slices(rec_path, stream_id, center=True): rec_names, common_el, pos = find_common_electrodes(rec_path, stream_id) if len(rec_names) == 1: - rec = si.MaxwellRecordingExtractor(rec_path, stream_id=stream_id, rec_name=rec_names[0]) + rec = si.MaxwellRecordingExtractor( + rec_path, stream_id=stream_id, rec_name=rec_names[0] + ) return rec else: rec_list = [] - for rec_name in rec_names: - #rec_name = 'rec' + '%0*d' % (4, r) - rec = si.MaxwellRecordingExtractor(rec_path, stream_id=stream_id, rec_name=rec_name) - - ch_id = rec.get_property("contact_vector")['device_channel_indices'] + for rec_name in rec_names: + # rec_name = 'rec' + '%0*d' % (4, r) + rec = si.MaxwellRecordingExtractor( + rec_path, stream_id=stream_id, rec_name=rec_name + ) + + ch_id = rec.get_property("contact_vector")["device_channel_indices"] rec_el = rec.get_property("contact_vector")["electrode"] - + chan_idx = [np.where(rec_el == el)[0][0] for el in common_el] sel_channels = rec.get_channel_ids()[chan_idx] if center: - chunk_size = np.min([10000, rec.get_num_samples()]) - 100 #Fallback for ultra short recordings (too little activity) - rec = si.center(rec,chunk_size=chunk_size) - - rec_list.append(rec.channel_slice(sel_channels, renamed_channel_ids=list(range(len(chan_idx))))) - + chunk_size = ( + np.min([10000, rec.get_num_samples()]) - 100 + ) # Fallback for ultra short recordings (too little activity) + rec = si.center(rec, chunk_size=chunk_size) + + rec_list.append( + rec.channel_slice( + sel_channels, renamed_channel_ids=list(range(len(chan_idx))) + ) + ) + multirecording = si.concatenate_recordings(rec_list) - + return multirecording, common_el, pos + def intersect_and_concatenate_recording_list(rec_list): - assert(len(rec_list) > 1) + assert len(rec_list) > 
1 rec_el_list = [] sliced_rec_list = [] - for i,rec in enumerate(rec_list): + for i, rec in enumerate(rec_list): rec_el = rec.get_property("contact_vector")["electrode"] rec_el_list.append(rec_el) if i == 0: common_el = rec_el else: common_el = list(set(common_el).intersection(rec_el)) - + for i, els in enumerate(rec_el_list): chan_idx = [np.where(els == el)[0][0] for el in common_el] channel_ids = rec_list[i].get_channel_ids()[chan_idx] - slice_rec = rec_list[i].channel_slice(channel_ids, renamed_channel_ids=list(range(len(channel_ids)))) - sliced_rec_list.append(slice_rec.astype('float32')) - + slice_rec = rec_list[i].channel_slice( + channel_ids, renamed_channel_ids=list(range(len(channel_ids))) + ) + sliced_rec_list.append(slice_rec.astype("float32")) + concatenated = si.concatenate_recordings(sliced_rec_list) return concatenated - -def clean_sorting(rec, save_root, stream_id, sorter, sorter_params = dict(), clear_files=True, verbose=True): +def clean_sorting( + rec, + save_root, + stream_id, + sorter, + sorter_params=dict(), + clear_files=True, + verbose=True, +): """ Function that creates output folder if it does not exist, sorts the recording using the specified sorter - and clears up large files afterwards. + and clears up large files afterwards. Arguments ---------- @@ -213,39 +253,49 @@ def clean_sorting(rec, save_root, stream_id, sorter, sorter_params = dict(), cle Returns ---------- - sorting: Sorting object + sorting: Sorting object Specific type depends on the sorter. """ - + output_folder = Path(os.path.join(save_root, stream_id)) - sorter_output_file = os.path.join(output_folder, 'sorter_output', 'amplitudes.npy') + sorter_output_file = os.path.join(output_folder, "sorter_output", "amplitudes.npy") sorting = [] # Creates output folder if sorting has not yet been done if os.path.exists(sorter_output_file): return sorting - elif (rec.get_total_duration() < 30): - full_output_folder = Path(os.path.join(output_folder, 'sorter_output')) + elif rec.get_total_duration() < 30: + full_output_folder = Path(os.path.join(output_folder, "sorter_output")) full_output_folder.mkdir(parents=True, exist_ok=True) - np.save(sorter_output_file, np.empty(0)) #Empty file to indicate a failed sorting for future loops + np.save( + sorter_output_file, np.empty(0) + ) # Empty file to indicate a failed sorting for future loops return sorting else: - #output_folder.mkdir(parents=True, exist_ok=True) - raw_file = os.path.join(output_folder, 'sorter_output', 'recording.dat') - wh_file = os.path.join(output_folder, 'sorter_output', 'temp_wh.dat') + # output_folder.mkdir(parents=True, exist_ok=True) + raw_file = os.path.join(output_folder, "sorter_output", "recording.dat") + wh_file = os.path.join(output_folder, "sorter_output", "temp_wh.dat") if verbose: - print(f"DURATION: {rec.get_num_frames() / rec.get_sampling_frequency()} s -- " - f"NUM. CHANNELS: {rec.get_num_channels()}") + print( + f"DURATION: {rec.get_num_frames() / rec.get_sampling_frequency()} s -- " + f"NUM. CHANNELS: {rec.get_num_channels()}" + ) # We use try/catch to not break loops when iterating over several sortings (e.g. 
when not all wells were recorded) try: t_start_sort = time.time() - sorting = si.run_sorter(sorter, rec, output_folder=output_folder, verbose=verbose, remove_existing_folder=True, - **sorter_params) + sorting = si.run_sorter( + sorter, + rec, + output_folder=output_folder, + verbose=verbose, + remove_existing_folder=True, + **sorter_params, + ) if verbose: print(f"\n\nSpike sorting elapsed time {time.time() - t_start_sort} s") - - #Making sure we clean up the largest temporary files + + # Making sure we clean up the largest temporary files if clear_files & os.path.exists(wh_file): os.remove(wh_file) if clear_files & os.path.exists(raw_file): @@ -256,9 +306,10 @@ def clean_sorting(rec, save_root, stream_id, sorter, sorter_params = dict(), cle os.remove(wh_file) if clear_files & os.path.exists(raw_file): os.remove(raw_file) - + return sorting + def generate_rec_list(path_parts): """ Function that takes a list of strings (path parts) and finds all recordings matching the path pattern, and returns the stream ids for the first recordings. @@ -276,42 +327,45 @@ def generate_rec_list(path_parts): List of stream_ids (wells) recorded from the first recording. """ path_pattern = os.path.join(*path_parts) - path_list = glob(path_pattern) + path_list = glob(path_pattern) h5 = h5py.File(path_list[-1]) - stream_ids = list(h5['wells'].keys()) + stream_ids = list(h5["wells"].keys()) path_list.sort() - - return path_list, stream_ids + + return path_list, stream_ids + def concatenate_recording_list(path_list, stream_id): well_recording_list = [] - for rec_path in path_list: #Iterate over recordings to be concatenated - try: # If not all wells were recorded, should be the only cause for an error - rec = si.MaxwellRecordingExtractor(rec_path,stream_id=stream_id) + for rec_path in path_list: # Iterate over recordings to be concatenated + try: # If not all wells were recorded, should be the only cause for an error + rec = si.MaxwellRecordingExtractor(rec_path, stream_id=stream_id) well_recording_list.append(rec) except Exception: continue - + if len(well_recording_list) == len(path_list): multirecording = si.concatenate_recordings(well_recording_list) else: - raise ValueError('Could not load all recordings!') - + raise ValueError("Could not load all recordings!") + saturated_count = find_saturated_channels(well_recording_list) - clean_multirecording = multirecording.remove_channels(multirecording.get_channel_ids()[saturated_count>0]) - - + clean_multirecording = multirecording.remove_channels( + multirecording.get_channel_ids()[saturated_count > 0] + ) + return clean_multirecording + def cut_concatenated_recording(concat_rec, cutout=np.inf): - rec_list = concat_rec._kwargs['recording_list'] + rec_list = concat_rec._kwargs["recording_list"] sliced_list = [] for rec in rec_list: duration = rec.get_total_duration() if cutout < duration: end_frame = rec.get_num_frames() start_frame = end_frame - cutout * rec.get_sampling_frequency() - sliced_rec = rec.frame_slice(start_frame,end_frame) + sliced_rec = rec.frame_slice(start_frame, end_frame) sliced_list.append(sliced_rec) else: sliced_list.append(rec) @@ -319,7 +373,8 @@ def cut_concatenated_recording(concat_rec, cutout=np.inf): concat_sliced = si.concatenate_recordings(sliced_list) return concat_sliced -def split_concatenated_sorting(sorting_path, path_suffix='sorter_output'): + +def split_concatenated_sorting(sorting_path, path_suffix="sorter_output"): """ Function that takes the path of concatenated sorting and returns a SegmentSorting based on the durations of 
the individual recordings. @@ -336,103 +391,135 @@ def split_concatenated_sorting(sorting_path, path_suffix='sorter_output'): """ sorting_output = os.path.join(sorting_path, path_suffix) sorting = si.KiloSortSortingExtractor(sorting_output) - recording_path = os.path.join(sorting_path, 'spikeinterface_recording.json') + recording_path = os.path.join(sorting_path, "spikeinterface_recording.json") concat_rec = si.load_extractor(recording_path, base_folder=True) cleaned_sorting = si.remove_excess_spikes(sorting, concat_rec) cleaned_sorting.register_recording(concat_rec) segment_sorting = si.SplitSegmentSorting(cleaned_sorting, concat_rec) - + return segment_sorting, concat_rec -def save_split_sorting(seg_sorting, subfolder='segment_', keep_unit_ids=None, cutout=np.inf): + +def save_split_sorting( + seg_sorting, subfolder="segment_", keep_unit_ids=None, cutout=np.inf +): N_segments = seg_sorting.get_num_segments() if len(seg_sorting.get_unit_ids()) > 0: for seg_id in range(N_segments): seg = si.SelectSegmentSorting(seg_sorting, seg_id) if keep_unit_ids is not None: - seg = seg.select_units(keep_unit_ids) # ,renamed_unit_ids=list(range(len(keep_unit_ids))) - + seg = seg.select_units( + keep_unit_ids + ) # ,renamed_unit_ids=list(range(len(keep_unit_ids))) + spikes = seg.to_spike_vector() - duration = np.ceil(spikes['sample_index'].max()/seg.get_sampling_frequency()) - + duration = np.ceil( + spikes["sample_index"].max() / seg.get_sampling_frequency() + ) + if cutout < duration: - end_frame = spikes['sample_index'].max() + 1 + end_frame = spikes["sample_index"].max() + 1 start_frame = end_frame - cutout * seg.get_sampling_frequency() seg = seg.frame_slice(start_frame, end_frame) - - #spike_vector = seg.to_spike_vector(concatenated=True) #Removes original unit IDs - save_path = os.path.join(seg_sorting._annotations['phy_folder'], subfolder + str(seg_id)) + + # spike_vector = seg.to_spike_vector(concatenated=True) #Removes original unit IDs + save_path = os.path.join( + seg_sorting._annotations["phy_folder"], subfolder + str(seg_id) + ) Path(save_path).mkdir(exist_ok=True) - spike_times_path = os.path.join(save_path, 'spike_times.npy') - spike_templates_path = os.path.join(save_path, 'spike_templates.npy') - template_mat_path = os.path.join(seg_sorting._annotations['phy_folder'], 'qc_output','templates.npy') + spike_times_path = os.path.join(save_path, "spike_times.npy") + spike_templates_path = os.path.join(save_path, "spike_templates.npy") + template_mat_path = os.path.join( + seg_sorting._annotations["phy_folder"], "qc_output", "templates.npy" + ) if not os.path.exists(template_mat_path): - template_mat_path = os.path.join(seg_sorting._annotations['phy_folder'], 'templates.npy') #In case bc output was not exported - - channel_pos_path = os.path.join(seg_sorting._annotations['phy_folder'], 'channel_positions.npy') - params_pos_path = os.path.join(seg_sorting._annotations['phy_folder'], 'params.py') - np.save(spike_times_path, seg.get_all_spike_trains()[0][0])#spike_vector['sample_index']) - np.save(spike_templates_path, seg.get_all_spike_trains()[0][1])#spike_vector['unit_index']) + template_mat_path = os.path.join( + seg_sorting._annotations["phy_folder"], "templates.npy" + ) # In case bc output was not exported + + channel_pos_path = os.path.join( + seg_sorting._annotations["phy_folder"], "channel_positions.npy" + ) + params_pos_path = os.path.join( + seg_sorting._annotations["phy_folder"], "params.py" + ) + np.save( + spike_times_path, seg.get_all_spike_trains()[0][0] + ) # 
spike_vector['sample_index']) + np.save( + spike_templates_path, seg.get_all_spike_trains()[0][1] + ) # spike_vector['unit_index']) shutil.copy(template_mat_path, save_path) shutil.copy(channel_pos_path, save_path) shutil.copy(params_pos_path, save_path) + def find_saturated_channels(rec_list, threshold=0): """ Function that creates output folder if it does not exist, sorts the recording using the specified sorter - and clears up large files afterwards. + and clears up large files afterwards. Arguments ---------- rec_list: List of MaxwellRecordingExtractor objects. List of (potentially to be concatenated) recordings to be checked for saturated channels. threshold: float - Maximum ratio of saturated signal for the channel to still be accepted as non-saturated. + Maximum ratio of saturated signal for the channel to still be accepted as non-saturated. Returns ---------- saturated_count: np.array - Number of recordings in which the saturation threshold was crossed (channel was considered to be saturated). Values go from 0 to len(rec_list). + Number of recordings in which the saturation threshold was crossed (channel was considered to be saturated). Values go from 0 to len(rec_list). """ saturated_count = np.zeros((rec_list[0].get_num_channels())) - + for i in range(0, len(rec_list)): - random_data = si.get_random_data_chunks(rec_list[i], num_chunks_per_segment = int((rec_list[i].get_total_duration()/60))) - saturated = (np.sum((random_data == 0).astype("int16") + (random_data == 1023).astype("int16"),axis=0)) / random_data.shape[0] + random_data = si.get_random_data_chunks( + rec_list[i], + num_chunks_per_segment=int((rec_list[i].get_total_duration() / 60)), + ) + saturated = ( + np.sum( + (random_data == 0).astype("int16") + + (random_data == 1023).astype("int16"), + axis=0, + ) + ) / random_data.shape[0] saturated_count += saturated > threshold return saturated_count + def get_stream_ids(rec_path): h5 = h5py.File(rec_path) - stream_ids = list(h5['wells'].keys()) + stream_ids = list(h5["wells"].keys()) return stream_ids + def get_recording_path(sort_or_rec): start_dict = sort_or_rec - while 'file_path' not in start_dict._kwargs.keys(): - if '_recording' in vars(start_dict) and start_dict._recording is not None: + while "file_path" not in start_dict._kwargs.keys(): + if "_recording" in vars(start_dict) and start_dict._recording is not None: start_dict = start_dict._recording - elif 'sorting' in start_dict._kwargs.keys(): - start_dict = start_dict._kwargs['sorting'] - elif 'recording' in start_dict._kwargs.keys(): - start_dict = start_dict._kwargs['recording'] - elif 'recording_or_recording_list' in start_dict._kwargs.keys(): - start_dict = start_dict._kwargs['recording_or_recording_list'] - elif 'parent_recording' in start_dict._kwargs.keys(): - start_dict = start_dict._kwargs['parent_recording'] - elif 'recording_list' in start_dict._kwargs.keys(): - start_dict = start_dict._kwargs['recording_list'] + elif "sorting" in start_dict._kwargs.keys(): + start_dict = start_dict._kwargs["sorting"] + elif "recording" in start_dict._kwargs.keys(): + start_dict = start_dict._kwargs["recording"] + elif "recording_or_recording_list" in start_dict._kwargs.keys(): + start_dict = start_dict._kwargs["recording_or_recording_list"] + elif "parent_recording" in start_dict._kwargs.keys(): + start_dict = start_dict._kwargs["parent_recording"] + elif "recording_list" in start_dict._kwargs.keys(): + start_dict = start_dict._kwargs["recording_list"] else: - print('Could not find recording path') + print("Could not 
find recording path") file_path = [] break try: start_dict = start_dict[0] - + except Exception as e: - continue - - file_path = start_dict._kwargs['file_path'] + continue + + file_path = start_dict._kwargs["file_path"] - - return file_path \ No newline at end of file + return file_path diff --git a/axon_tracking/quantification.py b/axon_tracking/quantification.py index bfbcb73..0b187c6 100644 --- a/axon_tracking/quantification.py +++ b/axon_tracking/quantification.py @@ -3,25 +3,28 @@ from scipy.spatial.distance import pdist, cdist import scipy.stats as stats import warnings + sys.path.append("/home/phornauer/Git/axon_tracking/") import axon_tracking.skeletonization as skel -def get_sliding_window_velocity(path_list,params,window_size=5,min_r2=0.95): +def get_sliding_window_velocity(path_list, params, window_size=5, min_r2=0.95): warnings.simplefilter("ignore") all_vels = [] - num_vals = [] #count number of values for each velocity calc + num_vals = [] # count number of values for each velocity calc target_point_list = [] for path in path_list: vels, r2s = [], [] - for i in range(path.shape[0]-window_size+1): - target_points = path[i+1:i+window_size,:] - if not any(np.array_equal(target_points,tp) for tp in target_point_list): - dists = cdist(path[i:i+1,:2],target_points[:,:2]) / 1000 - time = (target_points[:,2] - path[i,2]) / (params['sampling_rate']/1000) + for i in range(path.shape[0] - window_size + 1): + target_points = path[i + 1 : i + window_size, :] + if not any(np.array_equal(target_points, tp) for tp in target_point_list): + dists = cdist(path[i : i + 1, :2], target_points[:, :2]) / 1000 + time = (target_points[:, 2] - path[i, 2]) / ( + params["sampling_rate"] / 1000 + ) regressor = sklearn.linear_model.LinearRegression(fit_intercept=False) - vel_x = time.reshape(-1,1) - vel_y = dists.reshape(-1,1) + vel_x = time.reshape(-1, 1) + vel_y = dists.reshape(-1, 1) regressor.fit(vel_x, vel_y) y_pred = regressor.predict(vel_x) r2 = sklearn.metrics.r2_score(vel_y, y_pred) @@ -29,54 +32,64 @@ def get_sliding_window_velocity(path_list,params,window_size=5,min_r2=0.95): vels.append(regressor.coef_[0][0]) r2s.append(r2) target_point_list.append(target_points) - - filtered_vel = np.mean(np.abs([vels[x] for x in range(len(vels)) if r2s[x] > min_r2])) + + filtered_vel = np.mean( + np.abs([vels[x] for x in range(len(vels)) if r2s[x] > min_r2]) + ) if not np.isnan(filtered_vel): all_vels.append(filtered_vel) num_vals.append(len(vels)) - - mean_vel = np.average(all_vels,weights=num_vals) - variance = np.average((all_vels - mean_vel)**2, weights=num_vals) + + mean_vel = np.average(all_vels, weights=num_vals) + variance = np.average((all_vels - mean_vel) ** 2, weights=num_vals) std_vel = np.sqrt(variance) return mean_vel, std_vel + def get_simple_template_size(branches): template_size = np.unique(np.concatenate(branches)).shape[0] return template_size + def get_branch_point_count(skeleton): branch_points = len(skeleton.branches()) return branch_points + def get_branch_point_dists(skeleton): ais = get_ais(skeleton) - branch_points = skeleton.vertices[skeleton.branches(),:2] - dists = cdist(ais,branch_points) / 1000 + branch_points = skeleton.vertices[skeleton.branches(), :2] + dists = cdist(ais, branch_points) / 1000 return dists + def get_branch_lengths(branches): lengths = [] for branch in branches: - branch_diff = np.diff(branch,axis=0) - dist = np.sqrt(branch_diff[:,0]**2 + branch_diff[:,1]**2) / 1000 + branch_diff = np.diff(branch, axis=0) + dist = np.sqrt(branch_diff[:, 0] ** 2 + 
branch_diff[:, 1] ** 2) / 1000 lengths.append(np.sum(dist)) return lengths + def get_longest_path(skeleton): longest_path = np.max([len(x) for x in skeleton.paths()]) return longest_path + def get_projection_dists(skeleton): ais = get_ais(skeleton) - terminals = skeleton.vertices[skeleton.terminals(),:2] - dists = cdist(ais,terminals) / 1000 + terminals = skeleton.vertices[skeleton.terminals(), :2] + dists = cdist(ais, terminals) / 1000 return dists + def get_ais(skeleton): - ais_idx = np.argmin(skeleton.vertices[:,2]) - ais = skeleton.vertices[ais_idx:ais_idx+1,:2] + ais_idx = np.argmin(skeleton.vertices[:, 2]) + ais = skeleton.vertices[ais_idx : ais_idx + 1, :2] return ais + def get_terminal_count(skeleton): terminal_count = len(skeleton.terminals()) - return terminal_count \ No newline at end of file + return terminal_count diff --git a/axon_tracking/skeletonization.py b/axon_tracking/skeletonization.py index d886a87..6918e4b 100755 --- a/axon_tracking/skeletonization.py +++ b/axon_tracking/skeletonization.py @@ -8,100 +8,132 @@ import scipy.stats as stats from scipy.spatial.distance import pdist, cdist -def full_skeletonization(root_path, stream_id, template_id, params, skel_params, qc_params): - template, template_save_file, noise = load_template_file(root_path, stream_id, template_id) - if np.mean(noise) > params['max_noise_level']: + +def full_skeletonization( + root_path, stream_id, template_id, params, skel_params, qc_params +): + template, template_save_file, noise = load_template_file( + root_path, stream_id, template_id + ) + if np.mean(noise) > params["max_noise_level"]: return [], [] - + temp_diff = np.diff(template) - capped_template, target_coor= localize_neurons(temp_diff, ms_cutout=params['ms_cutout']) - tmp_filt = nd.gaussian_filter(capped_template,sigma=1) - interp_temp = interpolate_template(tmp_filt,spacing=params['upsample']) + capped_template, target_coor = localize_neurons( + temp_diff, ms_cutout=params["ms_cutout"] + ) + tmp_filt = nd.gaussian_filter(capped_template, sigma=1) + interp_temp = interpolate_template(tmp_filt, spacing=params["upsample"]) th_template = threshold_template(interp_temp, noise, target_coor, params) - - t_cap = [0, th_template.shape[2]] #in samples - skels = skeletonize(th_template[:,:,t_cap[0]:t_cap[1]].astype('bool'), **skel_params) + + t_cap = [0, th_template.shape[2]] # in samples + skels = skeletonize( + th_template[:, :, t_cap[0] : t_cap[1]].astype("bool"), **skel_params + ) skeleton = kimimaro.join_close_components(skels[1], radius=5) skeleton = kimimaro.postprocess(skeleton, tick_threshold=5, dust_threshold=10) - + all_branches = branches_from_paths(skeleton) - scaled_qc_list, r2s, vels, lengths = perform_path_qc(all_branches, params,**qc_params) - + scaled_qc_list, r2s, vels, lengths = perform_path_qc( + all_branches, params, **qc_params + ) + qc_skel_list = [cv.Skeleton.from_path(x) for x in scaled_qc_list] qc_skeleton = cv.Skeleton.simple_merge(qc_skel_list) - qc_skeleton = kimimaro.postprocess(qc_skeleton, dust_threshold=10, tick_threshold=10) - qc_skeleton = kimimaro.join_close_components(qc_skeleton,radius=200) + qc_skeleton = kimimaro.postprocess( + qc_skeleton, dust_threshold=10, tick_threshold=10 + ) + qc_skeleton = kimimaro.join_close_components(qc_skeleton, radius=200) qc_skeleton = kimimaro.postprocess(qc_skeleton, dust_threshold=0, tick_threshold=0) - + all_branches = branches_from_paths(qc_skeleton) - - scaled_qc_list, r2s, full_vels, lengths = perform_path_qc(all_branches, params,**qc_params) - + + scaled_qc_list, r2s, 
full_vels, lengths = perform_path_qc( + all_branches, params, **qc_params + ) + return qc_skeleton, scaled_qc_list def load_template_file(root_path, stream_id, template_id): - template_save_file = os.path.join(root_path, stream_id, 'sorter_output', 'templates', str(template_id) + '.npy') #'sorter_output', - noise_save_file = os.path.join(root_path, stream_id, 'templates', str(template_id) + '_noise.npy') - template = np.load(template_save_file).astype('float64') - if False: #os.path.exists(noise_save_file): - noise = np.load(noise_save_file).astype('float64') + template_save_file = os.path.join( + root_path, stream_id, "sorter_output", "templates", str(template_id) + ".npy" + ) #'sorter_output', + noise_save_file = os.path.join( + root_path, stream_id, "templates", str(template_id) + "_noise.npy" + ) + template = np.load(template_save_file).astype("float64") + if False: # os.path.exists(noise_save_file): + noise = np.load(noise_save_file).astype("float64") else: - #print('No noise file found, inferring from template') + # print('No noise file found, inferring from template') noise = generate_noise_matrix(template) return template, template_save_file, noise -def localize_neurons(input_mat, ms_cutout, min_distance=5, threshold_rel=0.1, num_peaks=3, buffer_frames=2, ms_peak_cutout = 0.5): - #local_max = peak_local_max(np.abs(input_mat), min_distance=min_distance, threshold_rel=threshold_rel, num_peaks=num_peaks) - - #cutout_ratio = (ms_cutout[0]/np.sum(ms_cutout)) - #peak_range = [(cutout_ratio-(ms_peak_cutout/np.sum(ms_cutout))),(cutout_ratio+(ms_peak_cutout/np.sum(ms_cutout)))] - #peak_range = np.round(np.array(peak_range)*input_mat.shape[2]) +def localize_neurons( + input_mat, + ms_cutout, + min_distance=5, + threshold_rel=0.1, + num_peaks=3, + buffer_frames=2, + ms_peak_cutout=0.5, +): + + # local_max = peak_local_max(np.abs(input_mat), min_distance=min_distance, threshold_rel=threshold_rel, num_peaks=num_peaks) + + # cutout_ratio = (ms_cutout[0]/np.sum(ms_cutout)) + # peak_range = [(cutout_ratio-(ms_peak_cutout/np.sum(ms_cutout))),(cutout_ratio+(ms_peak_cutout/np.sum(ms_cutout)))] + # peak_range = np.round(np.array(peak_range)*input_mat.shape[2]) - #target_coor = local_max[(local_max[:,2] >= peak_range[0]) & (local_max[:,2] <= peak_range[1]),:].astype("int16") - - #if len(target_coor) > 0: - target_coor = list(np.unravel_index(np.argmax(-input_mat),input_mat.shape)) - capped_matrix = input_mat[:,:,(target_coor[2] - buffer_frames):] + # target_coor = local_max[(local_max[:,2] >= peak_range[0]) & (local_max[:,2] <= peak_range[1]),:].astype("int16") + + # if len(target_coor) > 0: + target_coor = list(np.unravel_index(np.argmax(-input_mat), input_mat.shape)) + capped_matrix = input_mat[:, :, (target_coor[2] - buffer_frames) :] target_coor[2] = buffer_frames - #else: + # else: # capped_matrix = input_mat # target_coor=[[0, 0, 0]] - #post_coor = local_max[local_max[:,2] > peak_range[1],:].astype("int16") - #if len(post_coor) > 0: #Check if postsynaptic target was detected + # post_coor = local_max[local_max[:,2] > peak_range[1],:].astype("int16") + # if len(post_coor) > 0: #Check if postsynaptic target was detected # post_coor[0][2] = post_coor[0][2] - target_coor[0][2] - - - return capped_matrix, target_coor#, post_coor -def generate_noise_matrix(template, noise=[], mode='mad'): + return capped_matrix, target_coor # , post_coor + + +def generate_noise_matrix(template, noise=[], mode="mad"): if not noise: - if mode == 'mad': - noise = stats.median_abs_deviation(template,axis=2) - elif mode == 
'sd': - noise = np.std(template,axis=2) - + if mode == "mad": + noise = stats.median_abs_deviation(template, axis=2) + elif mode == "sd": + noise = np.std(template, axis=2) + noise_matrix = noise[:, :, np.newaxis] - + return noise_matrix + def threshold_template(template, noise, target_coor, params): - if params['noise_threshold']: - noise_th = template < (params['noise_threshold'] * noise[:,:,:template.shape[2]]) + if params["noise_threshold"]: + noise_th = template < ( + params["noise_threshold"] * noise[:, :, : template.shape[2]] + ) else: - noise_th = np.full_like(template,True) - abs_th = template < params['abs_threshold'] + noise_th = np.full_like(template, True) + abs_th = template < params["abs_threshold"] - #r = int((((template.shape[2] / params['sampling_rate']) * params['max_velocity']) * 1000000) / 17.5) - #velocity_th = cone(template.shape, r, apex=tuple(target_coor[0])) + # r = int((((template.shape[2] / params['sampling_rate']) * params['max_velocity']) * 1000000) / 17.5) + # velocity_th = cone(template.shape, r, apex=tuple(target_coor[0])) velocity_th = valid_latency_map(template, target_coor, params) th_template = noise_th * abs_th * velocity_th return th_template -def interp_max(x,spacing): + +def interp_max(x, spacing): if len(x) == 1: interp_max = spacing elif spacing == 1: @@ -110,53 +142,77 @@ def interp_max(x,spacing): interp_max = x[-1] return interp_max -def interpolate_template(template, spacing=[1,1,0.2], template_path = [], overwrite=False): + +def interpolate_template( + template, spacing=[1, 1, 0.2], template_path=[], overwrite=False +): if template_path: - split_path = template_path.split(sep='/') - split_path[-1] = 'interp_' + split_path[-1] - interp_tmp_path = '/'.join(split_path) + split_path = template_path.split(sep="/") + split_path[-1] = "interp_" + split_path[-1] + interp_tmp_path = "/".join(split_path) else: - interp_tmp_path = '' - + interp_tmp_path = "" + if os.path.exists(interp_tmp_path) and not overwrite: - interp_template = np.load(interp_tmp_path).astype('float64') + interp_template = np.load(interp_tmp_path).astype("float64") else: x, y, z = [np.arange(template.shape[k]) for k in range(3)] f = RegularGridInterpolator((x, y, z), template) - #new_grid = np.mgrid[0:x[-1]:spacing[0], 0:y[-1]:spacing[1], 0:z[-1]+1:spacing[2]] - new_grid = np.mgrid[0:interp_max(x,spacing[0]):spacing[0], 0:interp_max(y,spacing[1]):spacing[1], 0:interp_max(z,spacing[2]):spacing[2]] - - new_grid = np.moveaxis(new_grid, (0, 1, 2, 3), (3, 0, 1, 2)) # reorder axes for evaluation + # new_grid = np.mgrid[0:x[-1]:spacing[0], 0:y[-1]:spacing[1], 0:z[-1]+1:spacing[2]] + new_grid = np.mgrid[ + 0 : interp_max(x, spacing[0]) : spacing[0], + 0 : interp_max(y, spacing[1]) : spacing[1], + 0 : interp_max(z, spacing[2]) : spacing[2], + ] + + new_grid = np.moveaxis( + new_grid, (0, 1, 2, 3), (3, 0, 1, 2) + ) # reorder axes for evaluation interp_template = f(new_grid) if template_path: - np.save(interp_tmp_path, interp_template.astype('float32')) - + np.save(interp_tmp_path, interp_template.astype("float32")) + return interp_template + def valid_latency_map(template, start, params): - indices_array = np.indices(template.shape) * params['el_spacing'] #convert to (um) - distances = np.sqrt((indices_array[0] - start[0]*params['el_spacing'])**2 + (indices_array[1] - start[1]*params['el_spacing'])**2) / 1000000 #convert to (s) + indices_array = np.indices(template.shape) * params["el_spacing"] # convert to (um) + distances = ( + np.sqrt( + (indices_array[0] - start[0] * 
params["el_spacing"]) ** 2 + + (indices_array[1] - start[1] * params["el_spacing"]) ** 2 + ) + / 1000000 + ) # convert to (s) th_mat = np.zeros(distances.shape) for z in range(distances.shape[2]): - th_mat[:, :, z] = (params['max_velocity'] / params['sampling_rate']) * (z+2) + th_mat[:, :, z] = (params["max_velocity"] / params["sampling_rate"]) * (z + 2) passed = distances <= th_mat return passed + def cone(matrix_shape, r, apex=[]): if not apex: apex = np.empty(3) - apex[0], apex[1], apex[2] = np.floor(matrix_shape[0]/2), np.floor(matrix_shape[1]/2), 0 - - x, y, z = np.ogrid[:matrix_shape[0], :matrix_shape[1], :matrix_shape[2]] - cone_equation = (x - apex[0])**2 + (y - apex[1])**2 <= r**2 * (1 - (z - apex[2])/matrix_shape[2])**2 + apex[0], apex[1], apex[2] = ( + np.floor(matrix_shape[0] / 2), + np.floor(matrix_shape[1] / 2), + 0, + ) + + x, y, z = np.ogrid[: matrix_shape[0], : matrix_shape[1], : matrix_shape[2]] + cone_equation = (x - apex[0]) ** 2 + (y - apex[1]) ** 2 <= r**2 * ( + 1 - (z - apex[2]) / matrix_shape[2] + ) ** 2 cone_matrix = np.zeros(matrix_shape, dtype=bool) cone_matrix[cone_equation] = True return cone_matrix -def generate_dilation_structure(max_t, max_r, spacing=1/3, sampling_rate=20000): + +def generate_dilation_structure(max_t, max_r, spacing=1 / 3, sampling_rate=20000): """ max_t: numeric Maximum time [us] to detect a peak from a previous peak @@ -165,79 +221,112 @@ def generate_dilation_structure(max_t, max_r, spacing=1/3, sampling_rate=20000): spacing: numeric Spacing of the interpolation (if performed before the dilation) """ - el_dist = params['el_spacing'] - frame_time = (1000000/sampling_rate) * spacing #Assumes 20k sampling rate - t = np.ceil(max_t / frame_time).astype('int16') - r = np.ceil(max_r / el_dist).astype('int16') - d = (2*r + 1).astype('int16') - - cone_matrix = cone((d,d,t), r) - #x, y, z = np.ogrid[:d, :d, :t] - #cone_equation = (x - (r))**2 + (y - (r))**2 <= r**2 * (1 - (z - 0)/t)**2 - #cone_matrix = np.zeros((d,d,t), dtype=bool) - #cone_matrix[cone_equation] = True - - structure = cone_matrix[:,:,::-1] - structure_base = np.full((structure.shape[0],structure.shape[1],structure.shape[2]),False) - structure_init = np.full((structure.shape[0],structure.shape[1],1),False) - structure_init[r,r,0] = True - full_structure = np.concatenate((structure_base, structure_init, structure),axis=2) + el_dist = params["el_spacing"] + frame_time = (1000000 / sampling_rate) * spacing # Assumes 20k sampling rate + t = np.ceil(max_t / frame_time).astype("int16") + r = np.ceil(max_r / el_dist).astype("int16") + d = (2 * r + 1).astype("int16") + + cone_matrix = cone((d, d, t), r) + # x, y, z = np.ogrid[:d, :d, :t] + # cone_equation = (x - (r))**2 + (y - (r))**2 <= r**2 * (1 - (z - 0)/t)**2 + # cone_matrix = np.zeros((d,d,t), dtype=bool) + # cone_matrix[cone_equation] = True + + structure = cone_matrix[:, :, ::-1] + structure_base = np.full( + (structure.shape[0], structure.shape[1], structure.shape[2]), False + ) + structure_init = np.full((structure.shape[0], structure.shape[1], 1), False) + structure_init[r, r, 0] = True + full_structure = np.concatenate((structure_base, structure_init, structure), axis=2) return full_structure -def iterative_dilation(template, r_dilation=2, init_th=-10, min_th=-1, filter_footprint=(3,3,3), use_derivative=True): + +def iterative_dilation( + template, + r_dilation=2, + init_th=-10, + min_th=-1, + filter_footprint=(3, 3, 3), + use_derivative=True, +): if use_derivative: template = np.diff(template) structure = 
generate_dilation_structure(r_dilation) - m_init = template < init_th #Detection of initial seeds/definitive peaks - mask = template < min_th #Mask indicating potential peak locations + m_init = template < init_th # Detection of initial seeds/definitive peaks + mask = template < min_th # Mask indicating potential peak locations dilated = nd.binary_dilation(m_init, structure=structure, iterations=0, mask=mask) - #filtered = nd.median_filter(dilated,size=filter_size) - #filtered = nd.median_filter(dilated,footprint=ball(1)) + # filtered = nd.median_filter(dilated,size=filter_size) + # filtered = nd.median_filter(dilated,footprint=ball(1)) if filter_footprint is not None: - dilated = nd.median_filter(dilated,footprint=filter_footprint) + dilated = nd.median_filter(dilated, footprint=filter_footprint) return dilated -def skeletonize(input_matrix, scale=2, const=50, pdrf_exponent=4, pdrf_scale=10000, dust_threshold=0, anisotropy=(17.5,17.5,50.0), tick_threshold=10, n_jobs=16,ais=[]): + +def skeletonize( + input_matrix, + scale=2, + const=50, + pdrf_exponent=4, + pdrf_scale=10000, + dust_threshold=0, + anisotropy=(17.5, 17.5, 50.0), + tick_threshold=10, + n_jobs=16, + ais=[], +): skels = kimimaro.skeletonize( - input_matrix, - teasar_params={ - 'scale': scale, - 'const': const, # physical units - 'pdrf_exponent': pdrf_exponent, - 'pdrf_scale': pdrf_scale, - "soma_acceptance_threshold": 35, # physical units - "soma_detection_threshold": 20, # physical units - "soma_invalidation_const": 30, # physical units - "soma_invalidation_scale": 2, - - }, - dust_threshold = dust_threshold, - parallel = n_jobs) - #extra_targets_before=ais + input_matrix, + teasar_params={ + "scale": scale, + "const": const, # physical units + "pdrf_exponent": pdrf_exponent, + "pdrf_scale": pdrf_scale, + "soma_acceptance_threshold": 35, # physical units + "soma_detection_threshold": 20, # physical units + "soma_invalidation_const": 30, # physical units + "soma_invalidation_scale": 2, + }, + dust_threshold=dust_threshold, + parallel=n_jobs, + ) + # extra_targets_before=ais return skels -def perform_path_qc(paths, params, window_size=7, max_duplicate_ratio=0.5, min_r2=0.9, vel_range = [0.4, 1], min_length=10): + +def perform_path_qc( + paths, + params, + window_size=7, + max_duplicate_ratio=0.5, + min_r2=0.9, + vel_range=[0.4, 1], + min_length=10, +): if np.max(np.concatenate(paths)) < 220: - scaled_paths = scale_path_coordinates(paths,params) + scaled_paths = scale_path_coordinates(paths, params) else: scaled_paths = paths good_path_list, r2s, vels, lengths = [], [], [], [] path_list = [] for path in scaled_paths: if len(path) > window_size: - path, inflection_points = split_path(path,window_size=window_size,max_duplicate_ratio=max_duplicate_ratio) + path, inflection_points = split_path( + path, window_size=window_size, max_duplicate_ratio=max_duplicate_ratio + ) path = [p for p in path if len(p) > min_length] - if len(path) <1: + if len(path) < 1: continue if type(path) is not list: path = [path] good_path_list = good_path_list + path - #good_path_list.append(path) - #print(good_path_list) + # good_path_list.append(path) + # print(good_path_list) for path in good_path_list: if path.shape[0] > 3: vel, r2 = calculate_path_velocity(path, params) @@ -246,8 +335,12 @@ def perform_path_qc(paths, params, window_size=7, max_duplicate_ratio=0.5, min_r r2s.append(r2) vels.append(vel) lengths.append(len(path)) - - good_vel = np.where((np.array(r2s)>min_r2) & (np.array(vels)vel_range[0])) + + good_vel = np.where( + (np.array(r2s) > 
min_r2) + & (np.array(vels) < vel_range[1]) + & (np.array(vels) > vel_range[0]) + ) good_vel = good_vel | np.isnan(r2) qc_list = [good_path_list[x] for x in good_vel[0]] r2s = [r2s[x] for x in good_vel[0]] @@ -255,71 +348,91 @@ def perform_path_qc(paths, params, window_size=7, max_duplicate_ratio=0.5, min_r lengths = [lengths[x] for x in good_vel[0]] return qc_list, r2s, vels, lengths -def split_path(path, window_size = 7, max_duplicate_ratio=0.3): - smoothed_data = np.convolve(path[:,2], np.ones(window_size)/window_size,mode='valid') + +def split_path(path, window_size=7, max_duplicate_ratio=0.3): + smoothed_data = np.convolve( + path[:, 2], np.ones(window_size) / window_size, mode="valid" + ) gradient = np.gradient(smoothed_data) inflection_points = np.where(np.diff(np.sign(gradient)))[0] - inflection_points = (inflection_points + np.floor(window_size/2)).astype('int') - inflection_points = np.delete(inflection_points, [np.where(np.abs(np.diff(inflection_points)) < 2)]) + inflection_points = (inflection_points + np.floor(window_size / 2)).astype("int") + inflection_points = np.delete( + inflection_points, [np.where(np.abs(np.diff(inflection_points)) < 2)] + ) split_points = [0] + list(inflection_points) + [path.shape[0]] - #print(split_points) - split_paths = [path[split_points[s]:split_points[s+1],:] for s in range(len(split_points)-1)] - qc_paths = remove_circulating_paths(split_paths, max_duplicate_ratio=max_duplicate_ratio) - + # print(split_points) + split_paths = [ + path[split_points[s] : split_points[s + 1], :] + for s in range(len(split_points) - 1) + ] + qc_paths = remove_circulating_paths( + split_paths, max_duplicate_ratio=max_duplicate_ratio + ) + return qc_paths, inflection_points + def branches_from_paths(skeleton): branch_points = skeleton.branches() path_list = skeleton.paths() all_branches = [] for path in path_list: - path_idx, _ = np.where((path[:, None] == skeleton.vertices[branch_points]).all(axis=2)) + path_idx, _ = np.where( + (path[:, None] == skeleton.vertices[branch_points]).all(axis=2) + ) branch_idx = [] for p in path_idx: - #check_idx = range(p-1,p+2) - dists = pdist(path[p-1:p+2,:]) + # check_idx = range(p-1,p+2) + dists = pdist(path[p - 1 : p + 2, :]) branch_idx = [0] if dists[0] > dists[2]: branch_idx.append(p) else: - branch_idx.append(p+1) + branch_idx.append(p + 1) if len(branch_idx) > 0: branch_idx.append(path.shape[0]) - new_branches = [path[branch_idx[x]:branch_idx[x+1],:] for x in range(len(branch_idx)-1)] + new_branches = [ + path[branch_idx[x] : branch_idx[x + 1], :] + for x in range(len(branch_idx) - 1) + ] all_branches = all_branches + new_branches - + unique_branches = [] for arr in all_branches: -# Check if the array is not already in unique_arrays + # Check if the array is not already in unique_arrays branch_list = [] if not any(np.array_equal(arr, unique_arr) for unique_arr in unique_branches): - #unique_branches.append(arr) + # unique_branches.append(arr) branch_list.append(arr) unique_branches = unique_branches + branch_list - unique_branches = [x for x in unique_branches if x.shape[0]>1] + unique_branches = [x for x in unique_branches if x.shape[0] > 1] return unique_branches + def remove_circulating_paths(path_list, max_duplicate_ratio=0.3): """ - Removes paths that contain too many duplicate x,y coordinates, expressed as ratio. + Removes paths that contain too many duplicate x,y coordinates, expressed as ratio. 
""" indices_to_remove = [] for p, path in enumerate(path_list): - points = path[:,:2] - duplicate_ratio = 1 - (len(np.unique(points,axis=0))/len(points)) - - if (duplicate_ratio > max_duplicate_ratio): + points = path[:, :2] + duplicate_ratio = 1 - (len(np.unique(points, axis=0)) / len(points)) + + if duplicate_ratio > max_duplicate_ratio: indices_to_remove.append(p) - - clean_path_list = [path_list[i] for i in range(len(path_list)) if i not in indices_to_remove] - #print(len(path_list) - len(clean_path_list)) + + clean_path_list = [ + path_list[i] for i in range(len(path_list)) if i not in indices_to_remove + ] + # print(len(path_list) - len(clean_path_list)) return clean_path_list + def calculate_path_velocity(path, params): - path_diff = np.diff(path,axis=0) - #print(path_diff.shape) - dist = np.sqrt(path_diff[:,0]**2 + path_diff[:,1]**2) / 1000 - time = (np.cumsum(np.abs(path_diff[:,2])) / params['sampling_rate']) * 1000 + path_diff = np.diff(path, axis=0) + # print(path_diff.shape) + dist = np.sqrt(path_diff[:, 0] ** 2 + path_diff[:, 1] ** 2) / 1000 + time = (np.cumsum(np.abs(path_diff[:, 2])) / params["sampling_rate"]) * 1000 regressor = sklearn.linear_model.LinearRegression(fit_intercept=False) vel_y = np.cumsum(dist).reshape(-1, 1) vel_x = time.reshape(-1, 1) @@ -329,21 +442,27 @@ def calculate_path_velocity(path, params): return regressor.coef_[0][0], r2 - -def scale_path_coordinates(path_list,params): - scaled_paths = ([np.concatenate((path[:,:2]*params['el_spacing'], path[:,2:]),axis=1) for path in path_list]) +def scale_path_coordinates(path_list, params): + scaled_paths = [ + np.concatenate((path[:, :2] * params["el_spacing"], path[:, 2:]), axis=1) + for path in path_list + ] return scaled_paths -def unscale_path_coordinates(path_list,params): - unscaled_paths = ([np.concatenate((path[:,:2]/params['el_spacing'], path[:,2:]),axis=1) for path in path_list]) + +def unscale_path_coordinates(path_list, params): + unscaled_paths = [ + np.concatenate((path[:, :2] / params["el_spacing"], path[:, 2:]), axis=1) + for path in path_list + ] return unscaled_paths + def path_to_vertices(path_list, params, unscale=True): if unscale: path_list = unscale_path_coordinates(path_list, params) - + vertices = np.concatenate(path_list) sorted_indices = np.argsort(vertices[:, 2]) - sorted_vertices = vertices[sorted_indices,:] + sorted_vertices = vertices[sorted_indices, :] return sorted_vertices - \ No newline at end of file diff --git a/axon_tracking/spike_sorting.py b/axon_tracking/spike_sorting.py index 8a601ae..baced3d 100644 --- a/axon_tracking/spike_sorting.py +++ b/axon_tracking/spike_sorting.py @@ -5,9 +5,17 @@ from tqdm import tqdm from glob import glob -def sort_recording_list(path_list, save_path_changes, sorter, sorter_params = dict(), clear_files=True, verbose=True): + +def sort_recording_list( + path_list, + save_path_changes, + sorter, + sorter_params=dict(), + clear_files=True, + verbose=True, +): """ - Function that iterates over a list of axon scans, finds common electrodes, concatenates and spike sorts the recording slices. + Function that iterates over a list of axon scans, finds common electrodes, concatenates and spike sorts the recording slices. Arguments ---------- @@ -27,30 +35,41 @@ def sort_recording_list(path_list, save_path_changes, sorter, sorter_params = di Returns ---------- - sorting_list: list of sorting objects + sorting_list: list of sorting objects Specific type depends on the sorter. 
- + """ - + sorting_list = [] - + for rec_path in tqdm(path_list, desc="Sorting recordings"): - + h5 = h5py.File(rec_path) - #Check that all wells are recorded throughout all recordings (should not fail) - stream_ids = list(h5['wells'].keys()) + # Check that all wells are recorded throughout all recordings (should not fail) + stream_ids = list(h5["wells"].keys()) save_root = convert_rec_path_to_save_path(rec_path, save_path_changes) - + for stream_id in tqdm(stream_ids, desc="Sorting wells"): - sorter_output_file = Path(os.path.join(save_root, stream_id, 'sorter_output', 'amplitudes.npy')) + sorter_output_file = Path( + os.path.join(save_root, stream_id, "sorter_output", "amplitudes.npy") + ) if not os.path.exists(sorter_output_file): - multirecording, common_el, pos = concatenate_recording_slices(rec_path, stream_id) - sorting = clean_sorting(multirecording, save_root, stream_id, sorter, sorter_params, clear_files=clear_files, verbose=verbose) + multirecording, common_el, pos = concatenate_recording_slices( + rec_path, stream_id + ) + sorting = clean_sorting( + multirecording, + save_root, + stream_id, + sorter, + sorter_params, + clear_files=clear_files, + verbose=verbose, + ) sorting_list.append(sorting) - - return sorting_list + return sorting_list def convert_rec_path_to_save_path(rec_path, save_path_changes): @@ -64,32 +83,33 @@ def convert_rec_path_to_save_path(rec_path, save_path_changes): save_path_changes: dict Dictionary containing keys 'pos' and 'vals' that indicate the changes to be made to the rec_path. Refer to the inidices after splitting the path by '/'. - + Returns ---------- save_path: str Root save path. Well ID will be appended during the sorting. """ - path_parts = rec_path.split('/') - for x,y in zip(save_path_changes['pos'], save_path_changes['vals']): + path_parts = rec_path.split("/") + for x, y in zip(save_path_changes["pos"], save_path_changes["vals"]): path_parts[x] = y - + save_path = os.path.join(*path_parts) - + return save_path + def find_common_electrodes(rec_path, stream_id): """ Function that returns the common electrodes of the successive axon scan recordings. - + Arguments ---------- rec_path: str Path to the axon scan file. stream_id: str Well ID in the format "well***"; Well 1 would be "well001", Well 20 would be "well020" - + Returns ---------- rec_names: list @@ -97,36 +117,36 @@ def find_common_electrodes(rec_path, stream_id): common_el: list List of electrodes that are present in all axon scan recordings. 
""" - - assert(os.path.exists(rec_path)) - + + assert os.path.exists(rec_path) + h5 = h5py.File(rec_path) - rec_names = list(h5['wells'][stream_id].keys()) + rec_names = list(h5["wells"][stream_id].keys()) pos = dict() - x, y = np.full([1,26400], np.nan), np.full([1,26400], np.nan) - + x, y = np.full([1, 26400], np.nan), np.full([1, 26400], np.nan) for i, rec_name in enumerate(rec_names): - #rec_name = 'rec' + '%0*d' % (4, rec_id) - rec = si.MaxwellRecordingExtractor(rec_path, stream_id=stream_id, rec_name=rec_name) + # rec_name = 'rec' + '%0*d' % (4, rec_id) + rec = si.MaxwellRecordingExtractor( + rec_path, stream_id=stream_id, rec_name=rec_name + ) rec_el = rec.get_property("contact_vector")["electrode"] - x[:,rec_el] = rec.get_property("contact_vector")['x'] - y[:,rec_el] = rec.get_property("contact_vector")['y'] - + x[:, rec_el] = rec.get_property("contact_vector")["x"] + y[:, rec_el] = rec.get_property("contact_vector")["y"] + if i == 0: common_el = rec_el else: common_el = list(set(common_el).intersection(rec_el)) - pos = {'x': x[0], 'y': y[0]} - - return rec_names, common_el, pos + pos = {"x": x[0], "y": y[0]} + return rec_names, common_el, pos def concatenate_recording_slices(rec_path, stream_id, center=True): """ - Function that centers and concatenates the recordings of an axon scan for all common electrodes. + Function that centers and concatenates the recordings of an axon scan for all common electrodes. Arguments ---------- @@ -134,7 +154,7 @@ def concatenate_recording_slices(rec_path, stream_id, center=True): Path to the axon scan file. stream_id: str Well ID in the format "well***"; Well 1 would be "well001", Well 20 would be "well020" - + Returns ---------- multirecording: ConcatenatedRecordingSlice @@ -143,56 +163,76 @@ def concatenate_recording_slices(rec_path, stream_id, center=True): rec_names, common_el, pos = find_common_electrodes(rec_path, stream_id) if len(rec_names) == 1: - rec = si.MaxwellRecordingExtractor(rec_path, stream_id=stream_id, rec_name=rec_names[0]) + rec = si.MaxwellRecordingExtractor( + rec_path, stream_id=stream_id, rec_name=rec_names[0] + ) return rec else: rec_list = [] - for rec_name in rec_names: - #rec_name = 'rec' + '%0*d' % (4, r) - rec = si.MaxwellRecordingExtractor(rec_path, stream_id=stream_id, rec_name=rec_name) - - ch_id = rec.get_property("contact_vector")['device_channel_indices'] + for rec_name in rec_names: + # rec_name = 'rec' + '%0*d' % (4, r) + rec = si.MaxwellRecordingExtractor( + rec_path, stream_id=stream_id, rec_name=rec_name + ) + + ch_id = rec.get_property("contact_vector")["device_channel_indices"] rec_el = rec.get_property("contact_vector")["electrode"] - + chan_idx = [np.where(rec_el == el)[0][0] for el in common_el] sel_channels = rec.get_channel_ids()[chan_idx] if center: - chunk_size = np.min([10000, rec.get_num_samples()]) - 100 #Fallback for ultra short recordings (too little activity) - rec = si.center(rec,chunk_size=chunk_size) - - rec_list.append(rec.channel_slice(sel_channels, renamed_channel_ids=list(range(len(chan_idx))))) - + chunk_size = ( + np.min([10000, rec.get_num_samples()]) - 100 + ) # Fallback for ultra short recordings (too little activity) + rec = si.center(rec, chunk_size=chunk_size) + + rec_list.append( + rec.channel_slice( + sel_channels, renamed_channel_ids=list(range(len(chan_idx))) + ) + ) + multirecording = si.concatenate_recordings(rec_list) - + return multirecording, common_el, pos + def intersect_and_concatenate_recording_list(rec_list): - assert(len(rec_list) > 1) + assert len(rec_list) > 
1 rec_el_list = [] sliced_rec_list = [] - for i,rec in enumerate(rec_list): + for i, rec in enumerate(rec_list): rec_el = rec.get_property("contact_vector")["electrode"] rec_el_list.append(rec_el) if i == 0: common_el = rec_el else: common_el = list(set(common_el).intersection(rec_el)) - + for i, els in enumerate(rec_el_list): chan_idx = [np.where(els == el)[0][0] for el in common_el] channel_ids = rec_list[i].get_channel_ids()[chan_idx] - slice_rec = rec_list[i].channel_slice(channel_ids, renamed_channel_ids=list(range(len(channel_ids)))) - sliced_rec_list.append(slice_rec.astype('float32')) - + slice_rec = rec_list[i].channel_slice( + channel_ids, renamed_channel_ids=list(range(len(channel_ids))) + ) + sliced_rec_list.append(slice_rec.astype("float32")) + concatenated = si.concatenate_recordings(sliced_rec_list) return concatenated - -def clean_sorting(rec, save_root, stream_id, sorter, sorter_params = dict(), clear_files=True, verbose=True): +def clean_sorting( + rec, + save_root, + stream_id, + sorter, + sorter_params=dict(), + clear_files=True, + verbose=True, +): """ Function that creates output folder if it does not exist, sorts the recording using the specified sorter - and clears up large files afterwards. + and clears up large files afterwards. Arguments ---------- @@ -213,39 +253,49 @@ def clean_sorting(rec, save_root, stream_id, sorter, sorter_params = dict(), cle Returns ---------- - sorting: Sorting object + sorting: Sorting object Specific type depends on the sorter. """ - + output_folder = Path(os.path.join(save_root, stream_id)) - sorter_output_file = os.path.join(output_folder, 'sorter_output', 'amplitudes.npy') + sorter_output_file = os.path.join(output_folder, "sorter_output", "amplitudes.npy") sorting = [] # Creates output folder if sorting has not yet been done if os.path.exists(sorter_output_file): return sorting - elif (rec.get_total_duration() < 30): - full_output_folder = Path(os.path.join(output_folder, 'sorter_output')) + elif rec.get_total_duration() < 30: + full_output_folder = Path(os.path.join(output_folder, "sorter_output")) full_output_folder.mkdir(parents=True, exist_ok=True) - np.save(sorter_output_file, np.empty(0)) #Empty file to indicate a failed sorting for future loops + np.save( + sorter_output_file, np.empty(0) + ) # Empty file to indicate a failed sorting for future loops return sorting else: - #output_folder.mkdir(parents=True, exist_ok=True) - raw_file = os.path.join(output_folder, 'sorter_output', 'recording.dat') - wh_file = os.path.join(output_folder, 'sorter_output', 'temp_wh.dat') + # output_folder.mkdir(parents=True, exist_ok=True) + raw_file = os.path.join(output_folder, "sorter_output", "recording.dat") + wh_file = os.path.join(output_folder, "sorter_output", "temp_wh.dat") if verbose: - print(f"DURATION: {rec.get_num_frames() / rec.get_sampling_frequency()} s -- " - f"NUM. CHANNELS: {rec.get_num_channels()}") + print( + f"DURATION: {rec.get_num_frames() / rec.get_sampling_frequency()} s -- " + f"NUM. CHANNELS: {rec.get_num_channels()}" + ) # We use try/catch to not break loops when iterating over several sortings (e.g. 
when not all wells were recorded) try: t_start_sort = time.time() - sorting = si.run_sorter(sorter, rec, output_folder=output_folder, verbose=verbose, remove_existing_folder=True, - **sorter_params) + sorting = si.run_sorter( + sorter, + rec, + output_folder=output_folder, + verbose=verbose, + remove_existing_folder=True, + **sorter_params, + ) if verbose: print(f"\n\nSpike sorting elapsed time {time.time() - t_start_sort} s") - - #Making sure we clean up the largest temporary files + + # Making sure we clean up the largest temporary files if clear_files & os.path.exists(wh_file): os.remove(wh_file) if clear_files & os.path.exists(raw_file): @@ -256,9 +306,10 @@ def clean_sorting(rec, save_root, stream_id, sorter, sorter_params = dict(), cle os.remove(wh_file) if clear_files & os.path.exists(raw_file): os.remove(raw_file) - + return sorting + def generate_rec_list(path_parts): """ Function that takes a list of strings (path parts) and finds all recordings matching the path pattern, and returns the stream ids for the first recordings. @@ -276,42 +327,45 @@ def generate_rec_list(path_parts): List of stream_ids (wells) recorded from the first recording. """ path_pattern = os.path.join(*path_parts) - path_list = glob(path_pattern) + path_list = glob(path_pattern) h5 = h5py.File(path_list[-1]) - stream_ids = list(h5['wells'].keys()) + stream_ids = list(h5["wells"].keys()) path_list.sort() - - return path_list, stream_ids + + return path_list, stream_ids + def concatenate_recording_list(path_list, stream_id): well_recording_list = [] - for rec_path in path_list: #Iterate over recordings to be concatenated - try: # If not all wells were recorded, should be the only cause for an error - rec = si.MaxwellRecordingExtractor(rec_path,stream_id=stream_id) + for rec_path in path_list: # Iterate over recordings to be concatenated + try: # If not all wells were recorded, should be the only cause for an error + rec = si.MaxwellRecordingExtractor(rec_path, stream_id=stream_id) well_recording_list.append(rec) except Exception: continue - + if len(well_recording_list) == len(path_list): multirecording = si.concatenate_recordings(well_recording_list) else: - raise ValueError('Could not load all recordings!') - + raise ValueError("Could not load all recordings!") + saturated_count = find_saturated_channels(well_recording_list) - clean_multirecording = multirecording.remove_channels(multirecording.get_channel_ids()[saturated_count>0]) - - + clean_multirecording = multirecording.remove_channels( + multirecording.get_channel_ids()[saturated_count > 0] + ) + return clean_multirecording + def cut_concatenated_recording(concat_rec, cutout=np.inf): - rec_list = concat_rec._kwargs['recording_list'] + rec_list = concat_rec._kwargs["recording_list"] sliced_list = [] for rec in rec_list: duration = rec.get_total_duration() if cutout < duration: end_frame = rec.get_num_frames() start_frame = end_frame - cutout * rec.get_sampling_frequency() - sliced_rec = rec.frame_slice(start_frame,end_frame) + sliced_rec = rec.frame_slice(start_frame, end_frame) sliced_list.append(sliced_rec) else: sliced_list.append(rec) @@ -319,7 +373,8 @@ def cut_concatenated_recording(concat_rec, cutout=np.inf): concat_sliced = si.concatenate_recordings(sliced_list) return concat_sliced -def split_concatenated_sorting(sorting_path, path_suffix='sorter_output'): + +def split_concatenated_sorting(sorting_path, path_suffix="sorter_output"): """ Function that takes the path of concatenated sorting and returns a SegmentSorting based on the durations of 
the individual recordings. @@ -336,15 +391,18 @@ def split_concatenated_sorting(sorting_path, path_suffix='sorter_output'): """ sorting_output = os.path.join(sorting_path, path_suffix) sorting = si.KiloSortSortingExtractor(sorting_output) - recording_path = os.path.join(sorting_path, 'spikeinterface_recording.json') + recording_path = os.path.join(sorting_path, "spikeinterface_recording.json") concat_rec = si.load_extractor(recording_path, base_folder=True) cleaned_sorting = si.remove_excess_spikes(sorting, concat_rec) cleaned_sorting.register_recording(concat_rec) segment_sorting = si.SplitSegmentSorting(cleaned_sorting, concat_rec) - + return segment_sorting, concat_rec -def save_split_sorting(seg_sorting, subfolder='segment_', keep_unit_ids=None, cutout=[0, np.inf]): + +def save_split_sorting( + seg_sorting, subfolder="segment_", keep_unit_ids=None, cutout=[0, np.inf] +): """Saves the split sorting into subfolders for each segment in the phy format. Args: @@ -353,103 +411,132 @@ def save_split_sorting(seg_sorting, subfolder='segment_', keep_unit_ids=None, cu keep_unit_ids (list, optional): List of unit ids to be kept, e.g., as a QC result. Defaults to None, which uses all units. cutout (list or np.array, optional): Cutout in seconds to be kept (relevant for wash-in artefacts). Can be 2D if different cutouts should be used. Defaults to [0, np.inf], which uses the entire duration for each segment. """ - if len(cutout.shape) == 1: #If only one cutout is provided, we assume it applies to all segments - cutout = np.tile(cutout, (seg_sorting.get_num_segments(),1)) + if ( + len(cutout.shape) == 1 + ): # If only one cutout is provided, we assume it applies to all segments + cutout = np.tile(cutout, (seg_sorting.get_num_segments(), 1)) N_segments = seg_sorting.get_num_segments() if len(seg_sorting.get_unit_ids()) > 0: for seg_id in range(N_segments): seg = si.SelectSegmentSorting(seg_sorting, seg_id) if keep_unit_ids is not None: - seg = seg.select_units(np.squeeze(keep_unit_ids).tolist()) # ,renamed_unit_ids=list(range(len(keep_unit_ids))) - + seg = seg.select_units( + np.squeeze(keep_unit_ids).tolist() + ) # ,renamed_unit_ids=list(range(len(keep_unit_ids))) + spikes = seg.to_spike_vector() - #duration = np.ceil(spikes['sample_index'].max()/seg.get_sampling_frequency()) - + # duration = np.ceil(spikes['sample_index'].max()/seg.get_sampling_frequency()) + if cutout[seg_id][0] == 0 and cutout[seg_id][1] == np.inf: pass else: if cutout[seg_id][1] == np.inf: - end_frame = spikes['sample_index'].max() + 1 + end_frame = spikes["sample_index"].max() + 1 else: end_frame = cutout[seg_id][1] * seg.get_sampling_frequency() - + start_frame = cutout[seg_id][0] * seg.get_sampling_frequency() - + seg = seg.frame_slice(start_frame, end_frame) - - #spike_vector = seg.to_spike_vector(concatenated=True) #Removes original unit IDs - save_path = os.path.join(seg_sorting._annotations['phy_folder'], subfolder + str(seg_id)) + + # spike_vector = seg.to_spike_vector(concatenated=True) #Removes original unit IDs + save_path = os.path.join( + seg_sorting._annotations["phy_folder"], subfolder + str(seg_id) + ) Path(save_path).mkdir(exist_ok=True) - spike_times_path = os.path.join(save_path, 'spike_times.npy') - spike_templates_path = os.path.join(save_path, 'spike_templates.npy') - template_mat_path = os.path.join(seg_sorting._annotations['phy_folder'], 'qc_output','templates.npy') + spike_times_path = os.path.join(save_path, "spike_times.npy") + spike_templates_path = os.path.join(save_path, "spike_templates.npy") + 
template_mat_path = os.path.join( + seg_sorting._annotations["phy_folder"], "qc_output", "templates.npy" + ) if not os.path.exists(template_mat_path): - template_mat_path = os.path.join(seg_sorting._annotations['phy_folder'], 'templates.npy') #In case bc output was not exported - - channel_pos_path = os.path.join(seg_sorting._annotations['phy_folder'], 'channel_positions.npy') - params_pos_path = os.path.join(seg_sorting._annotations['phy_folder'], 'params.py') - np.save(spike_times_path, seg.get_all_spike_trains()[0][0])#spike_vector['sample_index']) - np.save(spike_templates_path, seg.get_all_spike_trains()[0][1])#spike_vector['unit_index']) + template_mat_path = os.path.join( + seg_sorting._annotations["phy_folder"], "templates.npy" + ) # In case bc output was not exported + + channel_pos_path = os.path.join( + seg_sorting._annotations["phy_folder"], "channel_positions.npy" + ) + params_pos_path = os.path.join( + seg_sorting._annotations["phy_folder"], "params.py" + ) + np.save( + spike_times_path, seg.get_all_spike_trains()[0][0] + ) # spike_vector['sample_index']) + np.save( + spike_templates_path, seg.get_all_spike_trains()[0][1] + ) # spike_vector['unit_index']) shutil.copy(template_mat_path, save_path) shutil.copy(channel_pos_path, save_path) shutil.copy(params_pos_path, save_path) + def find_saturated_channels(rec_list, threshold=0): """ Function that creates output folder if it does not exist, sorts the recording using the specified sorter - and clears up large files afterwards. + and clears up large files afterwards. Arguments ---------- rec_list: List of MaxwellRecordingExtractor objects. List of (potentially to be concatenated) recordings to be checked for saturated channels. threshold: float - Maximum ratio of saturated signal for the channel to still be accepted as non-saturated. + Maximum ratio of saturated signal for the channel to still be accepted as non-saturated. Returns ---------- saturated_count: np.array - Number of recordings in which the saturation threshold was crossed (channel was considered to be saturated). Values go from 0 to len(rec_list). + Number of recordings in which the saturation threshold was crossed (channel was considered to be saturated). Values go from 0 to len(rec_list). 
""" saturated_count = np.zeros((rec_list[0].get_num_channels())) - + for i in range(0, len(rec_list)): - random_data = si.get_random_data_chunks(rec_list[i], num_chunks_per_segment = int((rec_list[i].get_total_duration()/60))) - saturated = (np.sum((random_data == 0).astype("int16") + (random_data == 1023).astype("int16"),axis=0)) / random_data.shape[0] + random_data = si.get_random_data_chunks( + rec_list[i], + num_chunks_per_segment=int((rec_list[i].get_total_duration() / 60)), + ) + saturated = ( + np.sum( + (random_data == 0).astype("int16") + + (random_data == 1023).astype("int16"), + axis=0, + ) + ) / random_data.shape[0] saturated_count += saturated > threshold return saturated_count + def get_stream_ids(rec_path): h5 = h5py.File(rec_path) - stream_ids = list(h5['wells'].keys()) + stream_ids = list(h5["wells"].keys()) return stream_ids + def get_recording_path(sort_or_rec): start_dict = sort_or_rec - while 'file_path' not in start_dict._kwargs.keys(): - if '_recording' in vars(start_dict) and start_dict._recording is not None: + while "file_path" not in start_dict._kwargs.keys(): + if "_recording" in vars(start_dict) and start_dict._recording is not None: start_dict = start_dict._recording - elif 'sorting' in start_dict._kwargs.keys(): - start_dict = start_dict._kwargs['sorting'] - elif 'recording' in start_dict._kwargs.keys(): - start_dict = start_dict._kwargs['recording'] - elif 'recording_or_recording_list' in start_dict._kwargs.keys(): - start_dict = start_dict._kwargs['recording_or_recording_list'] - elif 'parent_recording' in start_dict._kwargs.keys(): - start_dict = start_dict._kwargs['parent_recording'] - elif 'recording_list' in start_dict._kwargs.keys(): - start_dict = start_dict._kwargs['recording_list'] + elif "sorting" in start_dict._kwargs.keys(): + start_dict = start_dict._kwargs["sorting"] + elif "recording" in start_dict._kwargs.keys(): + start_dict = start_dict._kwargs["recording"] + elif "recording_or_recording_list" in start_dict._kwargs.keys(): + start_dict = start_dict._kwargs["recording_or_recording_list"] + elif "parent_recording" in start_dict._kwargs.keys(): + start_dict = start_dict._kwargs["parent_recording"] + elif "recording_list" in start_dict._kwargs.keys(): + start_dict = start_dict._kwargs["recording_list"] else: - print('Could not find recording path') + print("Could not find recording path") file_path = [] break try: start_dict = start_dict[0] - + except Exception as e: - continue - - file_path = start_dict._kwargs['file_path'] + continue + + file_path = start_dict._kwargs["file_path"] - - return file_path \ No newline at end of file + return file_path diff --git a/axon_tracking/template_extraction.py b/axon_tracking/template_extraction.py index 44b827b..b704832 100755 --- a/axon_tracking/template_extraction.py +++ b/axon_tracking/template_extraction.py @@ -7,6 +7,7 @@ from axon_tracking import spike_sorting as ss + def extract_templates_from_sorting_dict(sorting_dict, qc_params={}, te_params={}): rec_list = list(sorting_dict.keys()) @@ -16,16 +17,25 @@ def extract_templates_from_sorting_dict(sorting_dict, qc_params={}, te_params={} for sorting_path in sorting_list: sorting = si.KiloSortSortingExtractor(sorting_path) - stream_id = [p for p in sorting_path.split('/') if p.startswith('well')][0] #Find out which well this belongs to + stream_id = [p for p in sorting_path.split("/") if p.startswith("well")][ + 0 + ] # Find out which well this belongs to print(stream_id) - #rec_names, common_el, pos = ss.find_common_electrodes(rec_path, stream_id) - 
multirecording, common_el, pos = ss.concatenate_recording_slices(rec_path, stream_id) + # rec_names, common_el, pos = ss.find_common_electrodes(rec_path, stream_id) + multirecording, common_el, pos = ss.concatenate_recording_slices( + rec_path, stream_id + ) cleaned_sorting = select_good_units(sorting, **qc_params) - cleaned_sorting = si.remove_excess_spikes(cleaned_sorting, multirecording) #Relevant if last spike time == recording_length + cleaned_sorting = si.remove_excess_spikes( + cleaned_sorting, multirecording + ) # Relevant if last spike time == recording_length cleaned_sorting.register_recording(multirecording) segment_sorting = si.SplitSegmentSorting(cleaned_sorting, multirecording) - extract_all_templates(stream_id, segment_sorting, sorting_path, pos, te_params) - + extract_all_templates( + stream_id, segment_sorting, sorting_path, pos, te_params + ) + + def extract_templates_from_sorting_list(sorting_list, qc_params={}, te_params={}): """Performs template extraction from a list of sorting paths. Does not require a recording path. @@ -34,263 +44,388 @@ def extract_templates_from_sorting_list(sorting_list, qc_params={}, te_params={} qc_params (dict, optional): Dict of quality control parameters. Defaults to {}. te_params (dict, optional): Dict of template extraction parameters. Defaults to {}. """ - + for sorting_path in tqdm(sorting_list): try: - output_path = os.path.join(sorting_path,'sorter_output') - sorting = si.KiloSortSortingExtractor(output_path) - json_path = os.path.join(sorting_path, 'spikeinterface_recording.json') - multirecording = si.load_extractor(json_path, base_folder=True) - rec_path = ss.get_recording_path(multirecording) - stream_id = [p for p in sorting_path.split('/') if p.startswith('well')][0] #Find out which well this belongs to - - rec_names, common_el, pos = ss.find_common_electrodes(rec_path, stream_id) - cleaned_sorting = select_good_units(sorting, **qc_params) - cleaned_sorting = si.remove_excess_spikes(cleaned_sorting, multirecording) #Relevant if last spike time == recording_length - cleaned_sorting.register_recording(multirecording) - segment_sorting = si.SplitSegmentSorting(cleaned_sorting, multirecording) - extract_all_templates(stream_id, segment_sorting, output_path, pos, te_params) + output_path = os.path.join(sorting_path, "sorter_output") + sorting = si.KiloSortSortingExtractor(output_path) + json_path = os.path.join(sorting_path, "spikeinterface_recording.json") + multirecording = si.load_extractor(json_path, base_folder=True) + rec_path = ss.get_recording_path(multirecording) + stream_id = [p for p in sorting_path.split("/") if p.startswith("well")][ + 0 + ] # Find out which well this belongs to + + rec_names, common_el, pos = ss.find_common_electrodes(rec_path, stream_id) + cleaned_sorting = select_good_units(sorting, **qc_params) + cleaned_sorting = si.remove_excess_spikes( + cleaned_sorting, multirecording + ) # Relevant if last spike time == recording_length + cleaned_sorting.register_recording(multirecording) + segment_sorting = si.SplitSegmentSorting(cleaned_sorting, multirecording) + extract_all_templates( + stream_id, segment_sorting, output_path, pos, te_params + ) except Exception as e: print(e) continue - -def extract_templates_from_concatenated_recording(root_path, stream_id, qc_params={}, te_params={}): + +def extract_templates_from_concatenated_recording( + root_path, stream_id, qc_params={}, te_params={} +): sorting_path = os.path.join(root_path, stream_id) seg_sorting, concat_rec = 
ss.split_concatenated_sorting(sorting_path) - + # Split axon tracking - ax_sorting = si.select_segment_sorting(seg_sorting,0) + ax_sorting = si.select_segment_sorting(seg_sorting, 0) ax_rec_path = ss.get_recording_path(ax_sorting) - ax_recording, common_el, pos = ss.concatenate_recording_slices(ax_rec_path, stream_id, center=False) + ax_recording, common_el, pos = ss.concatenate_recording_slices( + ax_rec_path, stream_id, center=False + ) cleaned_sorting = select_good_units(seg_sorting, **qc_params) - ax_sorting = si.select_segment_sorting(cleaned_sorting,0) + ax_sorting = si.select_segment_sorting(cleaned_sorting, 0) ax_sorting = si.remove_excess_spikes(ax_sorting, ax_recording) ax_sorting.register_recording(ax_recording) ax_split_sorting = si.SplitSegmentSorting(ax_sorting, ax_recording) # Split network recordings - nw_sorting = si.select_segment_sorting(cleaned_sorting,1) - nw_recording = concat_rec._kwargs['recording_list'][1]._kwargs['recording']._kwargs['parent_recording'] + nw_sorting = si.select_segment_sorting(cleaned_sorting, 1) + nw_recording = ( + concat_rec._kwargs["recording_list"][1] + ._kwargs["recording"] + ._kwargs["parent_recording"] + ) nw_sorting = si.remove_excess_spikes(nw_sorting, nw_recording) nw_sorting.register_recording(nw_recording) nw_split_sorting = si.SplitSegmentSorting(nw_sorting, nw_recording) # Save split sortings ss.save_split_sorting(nw_split_sorting) - + # Extract templates extract_all_templates(stream_id, ax_split_sorting, sorting_path, pos, te_params) + def get_assay_information(rec_path): h5 = h5py.File(rec_path) pre, post, well_id = -1, -1, 0 - while pre <= 0 or post <= 0: #some failed axon trackings give negative trigger_post values, so we try different wells - well_name = list(h5['wells'].keys())[well_id] - rec_name = list(h5['wells'][well_name].keys())[well_id] - sampling_rate = h5['wells'][well_name][rec_name]['settings']['sampling'][0] + while ( + pre <= 0 or post <= 0 + ): # some failed axon trackings give negative trigger_post values, so we try different wells + well_name = list(h5["wells"].keys())[well_id] + rec_name = list(h5["wells"][well_name].keys())[well_id] + sampling_rate = h5["wells"][well_name][rec_name]["settings"]["sampling"][0] try: - pre = h5['wells'][well_name][rec_name]['groups']['routed']['trigger_pre'][0] - post = h5['wells'][well_name][rec_name]['groups']['routed']['trigger_post'][0] + pre = h5["wells"][well_name][rec_name]["groups"]["routed"]["trigger_pre"][0] + post = h5["wells"][well_name][rec_name]["groups"]["routed"]["trigger_post"][ + 0 + ] except: break well_id += 1 - + cutout_samples = [pre, post] - #Workaround to accomodate waveform extraction from network recordings - if cutout_samples[0] < 0: + # Workaround to accomodate waveform extraction from network recordings + if cutout_samples[0] < 0: print("Network recording detected, using default [1.5, 5]") cutout_ms = np.array([1.5, 5]) - cutout_samples = cutout_ms * (sampling_rate/1000) + cutout_samples = cutout_ms * (sampling_rate / 1000) else: - cutout_ms = [x / (sampling_rate/1000) for x in cutout_samples] #convert cutout to ms - + cutout_ms = [ + x / (sampling_rate / 1000) for x in cutout_samples + ] # convert cutout to ms + return cutout_samples, cutout_ms + def find_files(save_root, file_name="templates.npy", folder_name="sorter_output"): - file_list = [root - for root, dirs, files in os.walk(save_root) - for dir in dirs - if dir == folder_name and os.path.exists(os.path.join(root,folder_name,file_name))] - #file_list = [os.path.join(file, folder_name) for 
file in file_list] + file_list = [ + root + for root, dirs, files in os.walk(save_root) + for dir in dirs + if dir == folder_name + and os.path.exists(os.path.join(root, folder_name, file_name)) + ] + # file_list = [os.path.join(file, folder_name) for file in file_list] return file_list + def find_successful_sortings(path_list, save_path_changes): sorting_dict = dict() for rec_path in path_list: save_root = ss.convert_rec_path_to_save_path(rec_path, save_path_changes) - - #Takes into account different sorting folder names, subfolder depth, well IDs etc. - sorting_files = [root - for root, dirs, files in os.walk(save_root) - for name in files - if name == "templates.npy"] + + # Takes into account different sorting folder names, subfolder depth, well IDs etc. + sorting_files = [ + root + for root, dirs, files in os.walk(save_root) + for name in files + if name == "templates.npy" + ] sorting_dict[rec_path] = sorting_files - + return sorting_dict - + def postprocess_sorting(): - #Maybe we will do some postprocessing before we use them + # Maybe we will do some postprocessing before we use them return + def select_good_units(sorting, min_n_spikes=1500, exclude_mua=True, use_bc=False): if exclude_mua: - ks_idx = sorting.get_property('KSLabel') == 'good' + ks_idx = sorting.get_property("KSLabel") == "good" else: - ks_idx = np.full((sorting.get_num_units(),), True, dtype='bool') + ks_idx = np.full((sorting.get_num_units(),), True, dtype="bool") - if use_bc and len(sorting.get_property('bc_unitType')) > 0: - bc_idx = sorting.get_property('bc_unitType') == 'GOOD' + if use_bc and len(sorting.get_property("bc_unitType")) > 0: + bc_idx = sorting.get_property("bc_unitType") == "GOOD" else: - #print('No bombcell output found') - bc_idx = np.full((sorting.get_num_units(),), True, dtype='bool') - - n_spikes = [len(sorting.get_unit_spike_train(x,segment_index=0)) for x in sorting.get_unit_ids()] + # print('No bombcell output found') + bc_idx = np.full((sorting.get_num_units(),), True, dtype="bool") + + n_spikes = [ + len(sorting.get_unit_spike_train(x, segment_index=0)) + for x in sorting.get_unit_ids() + ] - good_n_spikes_idx = np.array(n_spikes) > min_n_spikes good_idx = ks_idx & bc_idx & good_n_spikes_idx good_ids = sorting.get_unit_ids()[good_idx] cleaned_sorting = sorting.select_units(good_ids) - - return cleaned_sorting + return cleaned_sorting -def extract_waveforms(segment_sorting, stream_id, save_root, n_jobs, overwrite_wf, cutout, filter_flag): +def extract_waveforms( + segment_sorting, stream_id, save_root, n_jobs, overwrite_wf, cutout, filter_flag +): full_path = ss.get_recording_path(segment_sorting) - + h5 = h5py.File(full_path) - rec_names = list(h5['wells'][stream_id].keys()) - + rec_names = list(h5["wells"][stream_id].keys()) + for sel_idx, rec_name in enumerate(rec_names): - wf_path = os.path.join(save_root, 'waveforms', 'seg' + str(sel_idx)) + wf_path = os.path.join(save_root, "waveforms", "seg" + str(sel_idx)) if not os.path.exists(wf_path) or overwrite_wf: - rec = si.MaxwellRecordingExtractor(full_path,stream_id=stream_id,rec_name=rec_name) - chunk_size = np.min([10000, rec.get_num_samples()]) - 100 #Fallback for ultra short recordings (too little activity) + rec = si.MaxwellRecordingExtractor( + full_path, stream_id=stream_id, rec_name=rec_name + ) + chunk_size = ( + np.min([10000, rec.get_num_samples()]) - 100 + ) # Fallback for ultra short recordings (too little activity) if filter_flag: rec_centered = si.bandpass_filter(rec, freq_min=300, freq_max=6000) else: rec_centered = 
si.center(rec, chunk_size=chunk_size) - + seg_sort = si.SelectSegmentSorting(segment_sorting, sel_idx) seg_sort = si.remove_excess_spikes(seg_sort, rec_centered) seg_sort.register_recording(rec_centered) - - seg_we = si.WaveformExtractor.create(rec_centered, seg_sort, - wf_path, - allow_unfiltered=True, - remove_if_exists=True) - seg_we.set_params(ms_before=cutout[0], ms_after=cutout[1], return_scaled = True, max_spikes_per_unit=10000) + + seg_we = si.WaveformExtractor.create( + rec_centered, + seg_sort, + wf_path, + allow_unfiltered=True, + remove_if_exists=True, + ) + seg_we.set_params( + ms_before=cutout[0], + ms_after=cutout[1], + return_scaled=True, + max_spikes_per_unit=10000, + ) seg_we.run_extract_waveforms(n_jobs=n_jobs, progress_bar=False) -def align_waveforms(seg_we, sel_unit_id, cutout, ms_peak_cutout, upsample, align_cutout, rm_outliers, n_jobs, n_neighbors): - + +def align_waveforms( + seg_we, + sel_unit_id, + cutout, + ms_peak_cutout, + upsample, + align_cutout, + rm_outliers, + n_jobs, + n_neighbors, +): + sample_peak_cutout = ms_peak_cutout * upsample peak_idx = cutout[0] * upsample - peak_cutout = range(np.int16(peak_idx - sample_peak_cutout), np.int16(peak_idx + sample_peak_cutout)) + peak_cutout = range( + np.int16(peak_idx - sample_peak_cutout), np.int16(peak_idx + sample_peak_cutout) + ) wfs = seg_we.get_waveforms(sel_unit_id) - interp_wfs = sp.interpolate.pchip_interpolate(list(range(wfs.shape[1])), wfs, np.linspace(0,wfs.shape[1], num = wfs.shape[1]*upsample), axis=1) - interp_wfs = interp_wfs - np.median(interp_wfs, axis=1)[:,np.newaxis,:] + interp_wfs = sp.interpolate.pchip_interpolate( + list(range(wfs.shape[1])), + wfs, + np.linspace(0, wfs.shape[1], num=wfs.shape[1] * upsample), + axis=1, + ) + interp_wfs = interp_wfs - np.median(interp_wfs, axis=1)[:, np.newaxis, :] if align_cutout: - peak_el = [np.where(interp_wfs[w,peak_cutout,:] == np.nanmin(interp_wfs[w,peak_cutout,:]))[1][0] for w in range(interp_wfs.shape[0])] + peak_el = [ + np.where( + interp_wfs[w, peak_cutout, :] + == np.nanmin(interp_wfs[w, peak_cutout, :]) + )[1][0] + for w in range(interp_wfs.shape[0]) + ] ref_el, count = sp.stats.mode(peak_el, keepdims=False) - peak_shift = [np.where(interp_wfs[w,peak_cutout,ref_el] == np.nanmin(interp_wfs[w,peak_cutout,ref_el]))[0][0] for w in range(interp_wfs.shape[0])] - aligned_length = interp_wfs.shape[1] - 2*sample_peak_cutout - aligned_wfs = np.full([interp_wfs.shape[0], np.int16(aligned_length), interp_wfs.shape[2]], np.nan) + peak_shift = [ + np.where( + interp_wfs[w, peak_cutout, ref_el] + == np.nanmin(interp_wfs[w, peak_cutout, ref_el]) + )[0][0] + for w in range(interp_wfs.shape[0]) + ] + aligned_length = interp_wfs.shape[1] - 2 * sample_peak_cutout + aligned_wfs = np.full( + [interp_wfs.shape[0], np.int16(aligned_length), interp_wfs.shape[2]], np.nan + ) for w in range(interp_wfs.shape[0]): - aligned_wfs[w,:,:] = interp_wfs[w,peak_shift[w]:np.int16(peak_shift[w]+aligned_length),:] + aligned_wfs[w, :, :] = interp_wfs[ + w, peak_shift[w] : np.int16(peak_shift[w] + aligned_length), : + ] else: aligned_wfs = interp_wfs - if rm_outliers: - peak_el = [np.where(interp_wfs[w,peak_cutout,:] == np.nanmin(interp_wfs[w,peak_cutout,:]))[1][0] for w in range(interp_wfs.shape[0])] + peak_el = [ + np.where( + interp_wfs[w, peak_cutout, :] + == np.nanmin(interp_wfs[w, peak_cutout, :]) + )[1][0] + for w in range(interp_wfs.shape[0]) + ] ref_el, count = sp.stats.mode(peak_el, keepdims=False) aligned_wfs = remove_wf_outliers(aligned_wfs, ref_el, n_jobs, n_neighbors) 
aligned_template = np.median(aligned_wfs, axis=0) - + return aligned_template def remove_wf_outliers(aligned_wfs, ref_el, n_jobs, n_neighbors): clf = sk.neighbors.LocalOutlierFactor(n_jobs=n_jobs, n_neighbors=n_neighbors) - outlier_idx = clf.fit_predict(aligned_wfs[:,:,ref_el]) - #print(f'Detected {sum(outlier_idx==-1)} outliers') - outlier_rm = np.delete(aligned_wfs, outlier_idx==-1, axis=0) - + outlier_idx = clf.fit_predict(aligned_wfs[:, :, ref_el]) + # print(f'Detected {sum(outlier_idx==-1)} outliers') + outlier_rm = np.delete(aligned_wfs, outlier_idx == -1, axis=0) + return outlier_rm -def combine_templates(stream_id, segment_sorting, sel_unit_id, save_root, peak_cutout=2, align_cutout=True, upsample=2, - rm_outliers=True, n_jobs=16, n_neighbors=10, overwrite_wf=False, overwrite_tmp = True, filter_flag = False): - + +def combine_templates( + stream_id, + segment_sorting, + sel_unit_id, + save_root, + peak_cutout=2, + align_cutout=True, + upsample=2, + rm_outliers=True, + n_jobs=16, + n_neighbors=10, + overwrite_wf=False, + overwrite_tmp=True, + filter_flag=False, +): + full_path = ss.get_recording_path(segment_sorting) cutout_samples, cutout_ms = get_assay_information(full_path) - #extract_waveforms(segment_sorting, stream_id, save_root, n_jobs, overwrite_wf, cutout_ms, filter_flag) - + # extract_waveforms(segment_sorting, stream_id, save_root, n_jobs, overwrite_wf, cutout_ms, filter_flag) + if align_cutout: - wf_length = np.int16((sum(cutout_samples) - 2*peak_cutout) * upsample) #length of waveforms after adjusting for potential peak alignments - + wf_length = np.int16( + (sum(cutout_samples) - 2 * peak_cutout) * upsample + ) # length of waveforms after adjusting for potential peak alignments + else: wf_length = np.int16(sum(cutout_samples) * upsample) - + template_matrix = np.full([wf_length, 26400], np.nan) - #noise_levels = np.full([1,26400], np.nan) - + # noise_levels = np.full([1,26400], np.nan) + h5 = h5py.File(full_path) - rec_names = list(h5['wells'][stream_id].keys()) - + rec_names = list(h5["wells"][stream_id].keys()) + for sel_idx, rec_name in enumerate(rec_names): - rec = si.MaxwellRecordingExtractor(full_path,stream_id=stream_id,rec_name=rec_name) + rec = si.MaxwellRecordingExtractor( + full_path, stream_id=stream_id, rec_name=rec_name + ) els = rec.get_property("contact_vector")["electrode"] seg_sort = si.SelectSegmentSorting(segment_sorting, sel_idx) - seg_we = si.load_waveforms(os.path.join(save_root, 'waveforms', 'seg' + str(sel_idx)), sorting = seg_sort) - aligned_wfs = align_waveforms(seg_we, sel_unit_id, cutout_samples, peak_cutout, upsample, align_cutout, rm_outliers, n_jobs, n_neighbors) - template_matrix[:,els] = aligned_wfs #find way to average common electrodes - #noise_levels[:,els] = si.compute_noise_levels(seg_we) - - return template_matrix#, noise_levels + seg_we = si.load_waveforms( + os.path.join(save_root, "waveforms", "seg" + str(sel_idx)), sorting=seg_sort + ) + aligned_wfs = align_waveforms( + seg_we, + sel_unit_id, + cutout_samples, + peak_cutout, + upsample, + align_cutout, + rm_outliers, + n_jobs, + n_neighbors, + ) + template_matrix[:, els] = aligned_wfs # find way to average common electrodes + # noise_levels[:,els] = si.compute_noise_levels(seg_we) + + return template_matrix # , noise_levels + def convert_to_grid(template_matrix, pos): - clean_template = np.delete(template_matrix, np.isnan(pos['x']), axis = 1) - clean_x = pos['x'][~np.isnan(pos['x'])] - clean_y = pos['y'][~np.isnan(pos['y'])] + clean_template = np.delete(template_matrix, 
np.isnan(pos["x"]), axis=1) + clean_x = pos["x"][~np.isnan(pos["x"])] + clean_y = pos["y"][~np.isnan(pos["y"])] x_idx = np.int16(clean_x / 17.5) y_idx = np.int16(clean_y / 17.5) - grid = np.full([np.max(x_idx) + 1, np.max(y_idx) + 1, clean_template.shape[0]],0).astype('float32') + grid = np.full( + [np.max(x_idx) + 1, np.max(y_idx) + 1, clean_template.shape[0]], 0 + ).astype("float32") for i in range(len(y_idx)): - grid[x_idx[i],y_idx[i],:] = clean_template[:,i] - + grid[x_idx[i], y_idx[i], :] = clean_template[:, i] + return grid def extract_all_templates(stream_id, segment_sorting, save_root, pos, te_params): sel_unit_ids = segment_sorting.get_unit_ids() - template_save_path = os.path.join(save_root, 'templates') + template_save_path = os.path.join(save_root, "templates") if not os.path.exists(template_save_path): os.makedirs(template_save_path) full_path = ss.get_recording_path(segment_sorting) cutout_samples, cutout_ms = get_assay_information(full_path) - extract_waveforms(segment_sorting, stream_id, save_root, te_params['n_jobs'], te_params['overwrite_wf'], cutout_ms, te_params['filter_flag']) - - for sel_unit_id in tqdm(sel_unit_ids): - template_save_file = os.path.join(template_save_path, str(sel_unit_id) + '.npy') - #noise_save_file = os.path.join(template_save_path, str(sel_unit_id) + '_noise.npy') - - if not os.path.isfile(template_save_file) or te_params['overwrite_tmp']: + extract_waveforms( + segment_sorting, + stream_id, + save_root, + te_params["n_jobs"], + te_params["overwrite_wf"], + cutout_ms, + te_params["filter_flag"], + ) + + for sel_unit_id in tqdm(sel_unit_ids): + template_save_file = os.path.join(template_save_path, str(sel_unit_id) + ".npy") + # noise_save_file = os.path.join(template_save_path, str(sel_unit_id) + '_noise.npy') + + if not os.path.isfile(template_save_file) or te_params["overwrite_tmp"]: try: - template_matrix = combine_templates(stream_id, segment_sorting, sel_unit_id, save_root, **te_params) + template_matrix = combine_templates( + stream_id, segment_sorting, sel_unit_id, save_root, **te_params + ) grid = convert_to_grid(template_matrix, pos) np.save(template_save_file, grid) - #noise_levels = convert_to_grid(noise_levels, pos) - #np.save(noise_save_file, noise_levels) - - except Exception as e: - print(f'Unit {sel_unit_id} encountered the following error:\n {e}') + # noise_levels = convert_to_grid(noise_levels, pos) + # np.save(noise_save_file, noise_levels) + except Exception as e: + print(f"Unit {sel_unit_id} encountered the following error:\n {e}") diff --git a/axon_tracking/visualization.py b/axon_tracking/visualization.py index c2dc54e..2cf2da4 100644 --- a/axon_tracking/visualization.py +++ b/axon_tracking/visualization.py @@ -8,259 +8,369 @@ from skimage.morphology import ball from matplotlib.collections import LineCollection -def plot_velocity_qc(vels,r2s,fig_size=(6,2)): - fig, axes = plt.subplots(1,2,figsize=fig_size) + +def plot_velocity_qc(vels, r2s, fig_size=(6, 2)): + fig, axes = plt.subplots(1, 2, figsize=fig_size) plt.subplot(121) - plt.hist(vels,100,range=(0,2)) - plt.title('Velocity') - + plt.hist(vels, 100, range=(0, 2)) + plt.title("Velocity") + plt.subplot(122) - plt.hist(r2s,100,range=(0,1)) - plt.title('R2') + plt.hist(r2s, 100, range=(0, 1)) + plt.title("R2") plt.show() - -def plot_bare_skeleton(path_list, params, save_path = [], figsize=4, linewidth=2, cmap='copper'): - sorted_vertices = skel.path_to_vertices(path_list,params) - c_max = 1000 * (sorted_vertices[-1][2]) / params['sampling_rate'] - fig, ax = 
plt.subplots(figsize=(22*figsize,12*figsize),constrained_layout=True) + +def plot_bare_skeleton( + path_list, params, save_path=[], figsize=4, linewidth=2, cmap="copper" +): + sorted_vertices = skel.path_to_vertices(path_list, params) + c_max = 1000 * (sorted_vertices[-1][2]) / params["sampling_rate"] + + fig, ax = plt.subplots( + figsize=(22 * figsize, 12 * figsize), constrained_layout=True + ) for path in path_list: - x = path[:,0] - y = path[:,1] - cols = (path[:,2])/params['sampling_rate'] * 1000 + x = path[:, 0] + y = path[:, 1] + cols = (path[:, 2]) / params["sampling_rate"] * 1000 points = np.array([x, y]).T.reshape(-1, 1, 2) segments = np.concatenate([points[:-1], points[1:]], axis=1) - lc = LineCollection(segments, cmap=cmap,linewidths=linewidth) + lc = LineCollection(segments, cmap=cmap, linewidths=linewidth) lc.set_array(cols) lc.set_clim((0, np.ceil(c_max))) - + line = ax.add_collection(lc) ax.set_xlim([0, 440]) - ax.set_ylim([240,0]) - ax.axis('off') + ax.set_ylim([240, 0]) + ax.axis("off") if save_path: - plt.savefig(save_path,dpi=300, transparent=True) + plt.savefig(save_path, dpi=300, transparent=True) plt.close() - #plt.show() + # plt.show() return fig, ax -def plot_delay_skeleton(path_list, params, skel_params,figsize=4, plot_ais=True, plot_ais_connection=True, linewidth=2,font_size=24): - - #path_list = skel.scale_path_coordinates(path_list) - sorted_vertices = skel.path_to_vertices(path_list,params) - c_max = 1000 * (sorted_vertices[-1][2]) / params['sampling_rate'] - fig_ratio = np.ptp(sorted_vertices,axis=0)[0] / np.ptp(sorted_vertices,axis=0)[1] + 0.7 - fig, ax = plt.subplots(figsize=(22*figsize,12*figsize),constrained_layout=True) +def plot_delay_skeleton( + path_list, + params, + skel_params, + figsize=4, + plot_ais=True, + plot_ais_connection=True, + linewidth=2, + font_size=24, +): + + # path_list = skel.scale_path_coordinates(path_list) + sorted_vertices = skel.path_to_vertices(path_list, params) + c_max = 1000 * (sorted_vertices[-1][2]) / params["sampling_rate"] + fig_ratio = ( + np.ptp(sorted_vertices, axis=0)[0] / np.ptp(sorted_vertices, axis=0)[1] + 0.7 + ) + + fig, ax = plt.subplots( + figsize=(22 * figsize, 12 * figsize), constrained_layout=True + ) for path in path_list: - x = path[:,0] - y = path[:,1] - cols = (path[:,2])/params['sampling_rate'] * 1000 + x = path[:, 0] + y = path[:, 1] + cols = (path[:, 2]) / params["sampling_rate"] * 1000 points = np.array([x, y]).T.reshape(-1, 1, 2) segments = np.concatenate([points[:-1], points[1:]], axis=1) - lc = LineCollection(segments, cmap='copper',linewidths=linewidth) + lc = LineCollection(segments, cmap="copper", linewidths=linewidth) lc.set_array(cols) lc.set_clim((0, np.ceil(c_max))) - + line = ax.add_collection(lc) - - clb = fig.colorbar(line,ax=ax,ticks=[0, np.ceil(c_max)],shrink= 0.3)#,format=mticker.FixedFormatter(clb_ticks) - clb.set_label(label="Delay (ms)",size=font_size*figsize) - clb.ax.tick_params(labelsize=font_size*figsize,length=0) + + clb = fig.colorbar( + line, ax=ax, ticks=[0, np.ceil(c_max)], shrink=0.3 + ) # ,format=mticker.FixedFormatter(clb_ticks) + clb.set_label(label="Delay (ms)", size=font_size * figsize) + clb.ax.tick_params(labelsize=font_size * figsize, length=0) if plot_ais: - plt.scatter(skel_params['ais'][0][0],skel_params['ais'][0][1],s=50,color='k',zorder=10) + plt.scatter( + skel_params["ais"][0][0], + skel_params["ais"][0][1], + s=50, + color="k", + zorder=10, + ) if plot_ais_connection: - plt.plot([skel_params['ais'][0][0], sorted_vertices[0][0]], [skel_params['ais'][0][1], 
sorted_vertices[0][1]]) - #print(sorted_vertices[0][0]) - + plt.plot( + [skel_params["ais"][0][0], sorted_vertices[0][0]], + [skel_params["ais"][0][1], sorted_vertices[0][1]], + ) + # print(sorted_vertices[0][0]) + ax.autoscale_view() ax.set_ylim(ax.get_ylim()[::-1]) - #ax.set_xlim([0, 2200]) - #ax.set_xlabel("(μm)") - #ax.set_ylabel("(μm)") - #plt.show() + # ax.set_xlim([0, 2200]) + # ax.set_xlabel("(μm)") + # ax.set_ylabel("(μm)") + # plt.show() return fig, ax - -def plot_conduction_velocity(path_list,params,fig_size=(3,3)): - + + +def plot_conduction_velocity(path_list, params, fig_size=(3, 3)): + fig, axes = plt.subplots(figsize=fig_size) for path in path_list: - path_diff = np.diff(path,axis=0) - dist = np.sqrt(path_diff[:,0]**2 + path_diff[:,1]**2) / 1000 - time = (np.cumsum(np.abs(path_diff[:,2])) / params['sampling_rate']) * 1000 - plt.scatter(time,np.cumsum(dist),s=1) - plt.ylabel('Distance (mm)') - plt.xlabel('Time (ms)') + path_diff = np.diff(path, axis=0) + dist = np.sqrt(path_diff[:, 0] ** 2 + path_diff[:, 1] ** 2) / 1000 + time = (np.cumsum(np.abs(path_diff[:, 2])) / params["sampling_rate"]) * 1000 + plt.scatter(time, np.cumsum(dist), s=1) + plt.ylabel("Distance (mm)") + plt.xlabel("Time (ms)") plt.show() - -def plot_template_and_noise(template,noise,th_template): - fig, axes = plt.subplots(1,3,figsize=(21,7)) +def plot_template_and_noise(template, noise, th_template): + fig, axes = plt.subplots(1, 3, figsize=(21, 7)) plt.subplot(131) - plt.imshow(np.min(template,axis=2).T,vmin=-10,vmax=0) + plt.imshow(np.min(template, axis=2).T, vmin=-10, vmax=0) plt.colorbar(shrink=0.3) - + plt.subplot(132) plt.imshow(np.squeeze(noise).T) plt.colorbar(shrink=0.3) - + plt.subplot(133) - plt.imshow(np.sum(th_template,axis=2).T, vmin=0, vmax=20) + plt.imshow(np.sum(th_template, axis=2).T, vmin=0, vmax=20) plt.colorbar(shrink=0.3) - + plt.show() -def plot_template_overview(root_path, stream_id, params, n_cols = 3, vmin=-10,vmax=0, filename='overview', unit_ids=[], overwrite=False): - full_filename = os.path.join(root_path,stream_id,filename +'.png') + +def plot_template_overview( + root_path, + stream_id, + params, + n_cols=3, + vmin=-10, + vmax=0, + filename="overview", + unit_ids=[], + overwrite=False, +): + full_filename = os.path.join(root_path, stream_id, filename + ".png") if os.path.exists(full_filename) and not overwrite: from IPython.display import display, Image + display(Image(filename=full_filename)) else: - parent_dir = os.path.join(root_path,stream_id,'sorter_output/templates') + parent_dir = os.path.join(root_path, stream_id, "sorter_output/templates") files = os.listdir(parent_dir) - template_files = [f for f in files if '_' not in f] - ids = [float(t.split('.')[0]) for t in template_files] + template_files = [f for f in files if "_" not in f] + ids = [float(t.split(".")[0]) for t in template_files] if len(unit_ids) > 0: sort_idx = [ids.index(x) for x in unit_ids] else: sort_idx = np.argsort(ids) template_files = [template_files[i] for i in sort_idx] - + n_rows = int(np.ceil(len(template_files) / n_cols)) - fig, axes = plt.subplots(n_rows,n_cols,figsize=(14,3*n_rows)) + fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 3 * n_rows)) for i, template_file in enumerate(template_files): - template_path = os.path.join(root_path,stream_id,'sorter_output/templates',template_file) + template_path = os.path.join( + root_path, stream_id, "sorter_output/templates", template_file + ) template = np.load(template_path) temp_diff = np.diff(template) - - capped_template, target_coor = 
skel.localize_neurons(temp_diff, ms_cutout=params['ms_cutout']) - tmp_filt = nd.median_filter(capped_template,footprint=params['filter_footprint']) - plt.subplot(n_rows, n_cols, i+1) - plt.imshow(np.min(tmp_filt,axis=2).T,vmin=vmin,vmax=vmax) + + capped_template, target_coor = skel.localize_neurons( + temp_diff, ms_cutout=params["ms_cutout"] + ) + tmp_filt = nd.median_filter( + capped_template, footprint=params["filter_footprint"] + ) + plt.subplot(n_rows, n_cols, i + 1) + plt.imshow(np.min(tmp_filt, axis=2).T, vmin=vmin, vmax=vmax) plt.title(template_file) - + plt.savefig(full_filename) - + def plot_waveforms(template): - flat_tmp = np.reshape(template.transpose((2,0,1)),[template.shape[2],-1]) + flat_tmp = np.reshape(template.transpose((2, 0, 1)), [template.shape[2], -1]) plt.plot(flat_tmp) plt.show() -def plot_skeleton(skeleton, x_lim=[], y_lim=[], fig_size=5, marker_size=10,ais=[]): - x,y,z = [skeleton.vertices[:,x] for x in range(3)] - x_scaling = np.abs(((np.max(x) - np.min(x)) / (np.max(y) - np.min(y))) * 1.15)#[0] - clb_ticks = ['0',str(np.round(max(z)/20,decimals=1))] - - fig, axes = plt.subplots(figsize=(fig_size*x_scaling,fig_size)) - - im = plt.scatter(x,y,c=z,s=marker_size,cmap="coolwarm",marker='o',vmax=1) - - clb = plt.colorbar(ticks=[0, max(z)],format=mticker.FixedFormatter(clb_ticks),shrink= 0.5) - clb.set_label(label="Latency (ms)",size=6) - clb.ax.tick_params(labelsize=6,length=0) - #axes.set_axis_off() - if len(x_lim) > 1& len(y_lim) > 1: + +def plot_skeleton(skeleton, x_lim=[], y_lim=[], fig_size=5, marker_size=10, ais=[]): + x, y, z = [skeleton.vertices[:, x] for x in range(3)] + x_scaling = np.abs( + ((np.max(x) - np.min(x)) / (np.max(y) - np.min(y))) * 1.15 + ) # [0] + clb_ticks = ["0", str(np.round(max(z) / 20, decimals=1))] + + fig, axes = plt.subplots(figsize=(fig_size * x_scaling, fig_size)) + + im = plt.scatter(x, y, c=z, s=marker_size, cmap="coolwarm", marker="o", vmax=1) + + clb = plt.colorbar( + ticks=[0, max(z)], format=mticker.FixedFormatter(clb_ticks), shrink=0.5 + ) + clb.set_label(label="Latency (ms)", size=6) + clb.ax.tick_params(labelsize=6, length=0) + # axes.set_axis_off() + if len(x_lim) > 1 & len(y_lim) > 1: plt.xlim(x_lim) plt.ylim(y_lim) else: axes.autoscale_view() axes.set_ylim(axes.get_ylim()[::-1]) - + plt.clim([0, max(z)]) - if len(ais)>0: - plt.scatter(ais[0],ais[1],s=marker_size*20,color='k') + if len(ais) > 0: + plt.scatter(ais[0], ais[1], s=marker_size * 20, color="k") plt.show() - return x,y,z + return x, y, z -def generate_propagation_gif(template, params, cumulative=True, vertices = [], downsample=2, clim=[-10, 0], cmap="Greys", spacing=1, marker_size=10): - el_offset = params['el_spacing'] - interp_offset = el_offset*spacing - xticks = np.arange(0,3850,500) - yticks = np.arange(0,2200,500) - conv_xticks = xticks/interp_offset - conv_yticks = yticks/interp_offset +def generate_propagation_gif( + template, + params, + cumulative=True, + vertices=[], + downsample=2, + clim=[-10, 0], + cmap="Greys", + spacing=1, + marker_size=10, +): + el_offset = params["el_spacing"] + interp_offset = el_offset * spacing + xticks = np.arange(0, 3850, 500) + yticks = np.arange(0, 2200, 500) + conv_xticks = xticks / interp_offset + conv_yticks = yticks / interp_offset if len(vertices) > 0: - x,y,z = [vertices[:,x] for x in range(3)] + x, y, z = [vertices[:, x] for x in range(3)] - clb_ticks = ['0',str(np.round(max(z)/(params['sampling_rate']/1000),decimals=1))] + clb_ticks = [ + "0", + str(np.round(max(z) / (params["sampling_rate"] / 1000), decimals=1)), 
+ ] ims = [] - - fig = plt.figure() #added - for i in range(1,template.shape[2],downsample): + + fig = plt.figure() # added + for i in range(1, template.shape[2], downsample): ax = plt.axes() if cumulative: - plt_data = np.min(template[:,:,0:i].T, axis=0) + plt_data = np.min(template[:, :, 0:i].T, axis=0) else: - plt_data = template[:,:,i].T - #im = ax.imshow(plt_data, animated=True,vmin=clim[0],vmax=clim[1],cmap=cmap) - ax.imshow(plt_data, animated=True,vmin=clim[0],vmax=clim[1],cmap=cmap) #added + plt_data = template[:, :, i].T + # im = ax.imshow(plt_data, animated=True,vmin=clim[0],vmax=clim[1],cmap=cmap) + ax.imshow( + plt_data, animated=True, vmin=clim[0], vmax=clim[1], cmap=cmap + ) # added if len(vertices) > 0: - xi,yi,zi = x[z<=i], y[z<=i], z[z<=i] - scat = ax.scatter(xi,yi,c=zi,s=marker_size,cmap="coolwarm",vmin=np.min(z),vmax=np.max(z))#,alpha=0.5) - #if i == 1: + xi, yi, zi = x[z <= i], y[z <= i], z[z <= i] + scat = ax.scatter( + xi, + yi, + c=zi, + s=marker_size, + cmap="coolwarm", + vmin=np.min(z), + vmax=np.max(z), + ) # ,alpha=0.5) + # if i == 1: # clb = plt.colorbar(scat, ticks=[min(z), max(z)],format=mticker.FixedFormatter(clb_ticks),shrink= 0.5) # clb.set_label(label="Latency (ms)",size=6) # clb.ax.tick_params(labelsize=6,length=0) ax.set_xticks(conv_xticks) ax.set_xticklabels(xticks) - ax.set_xlabel(u"\u03bcm") - + ax.set_xlabel("\u03bcm") + ax.set_yticks(conv_yticks) ax.set_yticklabels(yticks) - ax.set_ylabel(u"\u03bcm") + ax.set_ylabel("\u03bcm") ims.append([ax]) - ani = anim.ArtistAnimation(fig, ims,interval=100) + ani = anim.ArtistAnimation(fig, ims, interval=100) return ani -def plot_filled_contour(capped_template,skeleton,params,radius=5,save_path=[], fig_size = 1, font_size=24): - interp_tmp = skel.interpolate_template(capped_template,spacing=params['upsample']) - interp_tmp = nd.gaussian_filter(interp_tmp,sigma=0.8) - sorted_vertices = skel.path_to_vertices(skeleton.paths(),params) + +def plot_filled_contour( + capped_template, skeleton, params, radius=5, save_path=[], fig_size=1, font_size=24 +): + interp_tmp = skel.interpolate_template(capped_template, spacing=params["upsample"]) + interp_tmp = nd.gaussian_filter(interp_tmp, sigma=0.8) + sorted_vertices = skel.path_to_vertices(skeleton.paths(), params) skel_mat = np.zeros(interp_tmp.shape) - skel_mat[tuple(sorted_vertices.astype('int').T)] = True - dil_mat = nd.binary_dilation(skel_mat,structure=ball(radius)) - th_data = (interp_tmp * dil_mat)#[:,:,t_cap[0]:t_cap[1]] - contour_data = np.abs(np.min(th_data,axis=2).T) - contourf_lines = np.append(np.floor(-np.max(contour_data)),np.linspace(-5, -0.1,15)) - #contourf_lines = np.append(np.geomspace(-np.max(contour_data),-3,10),np.linspace(-3, -0.1,15)) - fig, ax = plt.subplots(figsize=(22*fig_size,12*fig_size),constrained_layout=True) - plt.contourf(-contour_data,levels = contourf_lines,cmap="inferno",vmin=-5,vmax=-0.1)# ,linewidths = 0.2,vmax=20,vmin=2)hatches =[':'], - clb = plt.colorbar(ticks=[-np.max(contour_data), -0.1],format=mticker.FixedFormatter(['-100','-2']),shrink= 0.3) - clb.set_label(label=u"\u03bcV/ms",size=font_size) - clb.ax.tick_params(labelsize=font_size,length=0) + skel_mat[tuple(sorted_vertices.astype("int").T)] = True + dil_mat = nd.binary_dilation(skel_mat, structure=ball(radius)) + th_data = interp_tmp * dil_mat # [:,:,t_cap[0]:t_cap[1]] + contour_data = np.abs(np.min(th_data, axis=2).T) + contourf_lines = np.append( + np.floor(-np.max(contour_data)), np.linspace(-5, -0.1, 15) + ) + # contourf_lines = 
np.append(np.geomspace(-np.max(contour_data),-3,10),np.linspace(-3, -0.1,15)) + fig, ax = plt.subplots( + figsize=(22 * fig_size, 12 * fig_size), constrained_layout=True + ) + plt.contourf( + -contour_data, levels=contourf_lines, cmap="inferno", vmin=-5, vmax=-0.1 + ) # ,linewidths = 0.2,vmax=20,vmin=2)hatches =[':'], + clb = plt.colorbar( + ticks=[-np.max(contour_data), -0.1], + format=mticker.FixedFormatter(["-100", "-2"]), + shrink=0.3, + ) + clb.set_label(label="\u03bcV/ms", size=font_size) + clb.ax.tick_params(labelsize=font_size, length=0) ax.autoscale_view() ax.set_ylim(ax.get_ylim()[::-1]) - ax.axis('off') + ax.axis("off") if save_path: - plt.savefig(save_path,dpi=300, transparent=True) + plt.savefig(save_path, dpi=300, transparent=True) plt.close() else: plt.show() -def plot_delay_contour(capped_template,skeleton,params,skel_params, radius=5,save_path=[]): - interp_tmp = skel.interpolate_template(capped_template,spacing=params['upsample']) - interp_tmp = nd.gaussian_filter(interp_tmp,sigma=0.8) - sorted_vertices = skel.path_to_vertices(skeleton.paths(),params) - skel_mat = np.zeros(interp_tmp.shape) - skel_mat[tuple(sorted_vertices.astype('int').T)] = True - dil_mat = nd.binary_dilation(skel_mat,structure=ball(radius)) - th_data = (interp_tmp * dil_mat)#[:,:,t_cap[0]:t_cap[1]] - contour_data = np.abs(np.min(th_data,axis=2).T) - contour_lines = np.append(np.linspace(0.1,2,15),np.linspace(2.5,np.max(contour_data),20)) - #contour_lines = np.geomspace(0.1,np.max(contour_data),15) - - fig, ax = plot_delay_skeleton(skel.unscale_path_coordinates(skeleton.paths(),params), params, skel_params,figsize=1, plot_ais=False, plot_ais_connection=False, linewidth=4) - plt.contour(contour_data,levels = contour_lines,colors='k',linewidths = 0.2,alpha=0.8,zorder=0)#,vmax=20,vmin=2)hatches =[':'], - ax.autoscale_view() - ax.set_ylim([0,120]) - ax.set_ylim(ax.get_ylim()[::-1]) - ax.axis('off') - if save_path: - plt.savefig(save_path,dpi=300, transparent=True) - plt.close() - else: - plt.show() \ No newline at end of file + +def plot_delay_contour( + capped_template, skeleton, params, skel_params, radius=5, save_path=[] +): + interp_tmp = skel.interpolate_template(capped_template, spacing=params["upsample"]) + interp_tmp = nd.gaussian_filter(interp_tmp, sigma=0.8) + sorted_vertices = skel.path_to_vertices(skeleton.paths(), params) + skel_mat = np.zeros(interp_tmp.shape) + skel_mat[tuple(sorted_vertices.astype("int").T)] = True + dil_mat = nd.binary_dilation(skel_mat, structure=ball(radius)) + th_data = interp_tmp * dil_mat # [:,:,t_cap[0]:t_cap[1]] + contour_data = np.abs(np.min(th_data, axis=2).T) + contour_lines = np.append( + np.linspace(0.1, 2, 15), np.linspace(2.5, np.max(contour_data), 20) + ) + # contour_lines = np.geomspace(0.1,np.max(contour_data),15) + + fig, ax = plot_delay_skeleton( + skel.unscale_path_coordinates(skeleton.paths(), params), + params, + skel_params, + figsize=1, + plot_ais=False, + plot_ais_connection=False, + linewidth=4, + ) + plt.contour( + contour_data, + levels=contour_lines, + colors="k", + linewidths=0.2, + alpha=0.8, + zorder=0, + ) # ,vmax=20,vmin=2)hatches =[':'], + ax.autoscale_view() + ax.set_ylim([0, 120]) + ax.set_ylim(ax.get_ylim()[::-1]) + ax.axis("off") + if save_path: + plt.savefig(save_path, dpi=300, transparent=True) + plt.close() + else: + plt.show()
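For orientation, the sketch below shows one way the reformatted modules could be driven end to end. It is not part of the patch itself: the file layout, sorter name, and parameter values are assumptions, and only functions whose signatures appear in this diff are called.

# Illustrative pipeline sketch (assumed paths, sorter name, and parameter values).
from axon_tracking import spike_sorting as ss
from axon_tracking import template_extraction as te

# Assumed layout: axon scans under /data/recordings/<chip>/AxonTracking/<id>/data.raw.h5
path_list, stream_ids = ss.generate_rec_list(
    ["/data/recordings", "*", "AxonTracking", "00*", "data.raw.h5"]
)
# Replace path part index 2 ("recordings") with "sorted" to build the save root.
save_path_changes = {"pos": [2], "vals": ["sorted"]}

# Sort every well of every scan; the chosen sorter (e.g. Kilosort) must be installed separately.
ss.sort_recording_list(
    path_list,
    save_path_changes,
    sorter="kilosort2_5",
    sorter_params={},
    clear_files=True,
    verbose=True,
)

# Parameter keys mirror the keyword arguments used in template_extraction.py; values are placeholders.
qc_params = {"min_n_spikes": 1500, "exclude_mua": True, "use_bc": False}
te_params = {
    "n_jobs": 16,
    "overwrite_wf": False,
    "overwrite_tmp": True,
    "filter_flag": False,
}

# Completed sortings are marked by sorter_output/amplitudes.npy.
sorting_list = te.find_files("/data/sorted", file_name="amplitudes.npy")
te.extract_templates_from_sorting_list(
    sorting_list, qc_params=qc_params, te_params=te_params
)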