From 530e3affa33fa87b544c8c3f4ae0f8628a4975c2 Mon Sep 17 00:00:00 2001 From: Philipp Hornauer Date: Mon, 2 Dec 2024 18:06:02 +0100 Subject: [PATCH] Updated to the new spikeinterface version 0.101 --- axon_tracking/skeletonization.py | 375 ++++++++++++------- axon_tracking/spike_sorting.py | 16 +- axon_tracking/template_extraction.py | 14 +- axon_tracking/utils.py | 104 +++++- axon_tracking/visualization.py | 48 +-- notebooks/part2_template-extraction.ipynb | 2 +- notebooks/part3_skeletonization.ipynb | 337 +++++++++++++---- notebooks/test_sorting_si101.ipynb | 431 +--------------------- 8 files changed, 672 insertions(+), 655 deletions(-) diff --git a/axon_tracking/skeletonization.py b/axon_tracking/skeletonization.py index 6918e4b..294cdfd 100755 --- a/axon_tracking/skeletonization.py +++ b/axon_tracking/skeletonization.py @@ -1,20 +1,22 @@ -import os, kimimaro, sklearn -import numpy as np +import os + import cloudvolume as cv +import kimimaro +import numpy as np import scipy.ndimage as nd -from skimage.morphology import disk, ball -from skimage.feature import peak_local_max -from scipy.interpolate import RegularGridInterpolator import scipy.stats as stats -from scipy.spatial.distance import pdist, cdist +from scipy.interpolate import RegularGridInterpolator +from scipy.spatial.distance import pdist +from sklearn.linear_model import LinearRegression +from sklearn.metrics import r2_score +from skimage.feature import peak_local_max +from skimage.morphology import ball +from axon_tracking import utils as ut -def full_skeletonization( - root_path, stream_id, template_id, params, skel_params, qc_params -): - template, template_save_file, noise = load_template_file( - root_path, stream_id, template_id - ) + +def full_skeletonization(root_path, template_id, params, skel_params, qc_params): + template, template_save_file, noise = load_template_file(root_path, template_id) if np.mean(noise) > params["max_noise_level"]: return [], [] @@ -55,85 +57,226 @@ def full_skeletonization( return qc_skeleton, scaled_qc_list -def load_template_file(root_path, stream_id, template_id): - template_save_file = os.path.join( - root_path, stream_id, "sorter_output", "templates", str(template_id) + ".npy" - ) #'sorter_output', - noise_save_file = os.path.join( - root_path, stream_id, "templates", str(template_id) + "_noise.npy" - ) +def load_template_file(root_path, template_id): + """ + Load a template file. + This function loads a template file from the specified root path and template ID. + The template file is expected to be in NumPy (.npy) format. + Args: + root_path (str): The root directory path where the template file is located. + template_id (int): The identifier for the template file. The file is expected + to be named as '.npy'. + Returns: + tuple: A tuple containing: + - template (numpy.ndarray): The loaded template as a NumPy array of type float64. + - template_save_file (str): The full path to the loaded template file. + """ + + template_save_file = os.path.join(root_path, str(template_id) + ".npy") template = np.load(template_save_file).astype("float64") - if False: # os.path.exists(noise_save_file): - noise = np.load(noise_save_file).astype("float64") - else: - # print('No noise file found, inferring from template') - noise = generate_noise_matrix(template) - return template, template_save_file, noise + return template, template_save_file -def localize_neurons( - input_mat, - ms_cutout, - min_distance=5, - threshold_rel=0.1, - num_peaks=3, - buffer_frames=2, - ms_peak_cutout=0.5, -): - - # local_max = peak_local_max(np.abs(input_mat), min_distance=min_distance, threshold_rel=threshold_rel, num_peaks=num_peaks) +def preprocess_template(template, params): + """ + Preprocess a template by localizing the AIS and thresholding the template. + This function preprocesses a given template by localizing the AIS and thresholding + the template based on the absolute signal amplitude, the maximum velocity, and the + noise of the electrode. + Args: + template (numpy.ndarray): The template as a NumPy array. + params (dict): A dictionary containing the parameters for preprocessing. + Returns: + numpy.ndarray: The interpolated template. + numpy.ndarray: The noise matrix. + numpy.ndarray: The thresholded (boolean) template. + """ + # Calculate the noise matrix based on the template + noise = generate_noise_matrix(template, mode="mad") + # Large median filter to remove noise + med_filt = nd.median_filter(template, footprint=ball(2)) + # First derivative to eliminate slow drift + temp_diff = np.diff(med_filt) + # Median filter to remove noise from the derivative + tmp_filt = nd.median_filter(temp_diff, footprint=ball(1)) + # Localize neurons based on the derivative + if params["ais_detection"] is not None: + tmp_filt, ais = localize_ais(tmp_filt, params) + else: + ais = np.array([]) - # cutout_ratio = (ms_cutout[0]/np.sum(ms_cutout)) - # peak_range = [(cutout_ratio-(ms_peak_cutout/np.sum(ms_cutout))),(cutout_ratio+(ms_peak_cutout/np.sum(ms_cutout)))] - # peak_range = np.round(np.array(peak_range)*input_mat.shape[2]) + # Interpolate the template in x, y, and z (time) + interp_temp = interpolate_template(tmp_filt, spacing=params["upsample"]) - # target_coor = local_max[(local_max[:,2] >= peak_range[0]) & (local_max[:,2] <= peak_range[1]),:].astype("int16") + # Generate a noise matrix based on the template + interp_noise = interpolate_template(noise, spacing=params["upsample"]) - # if len(target_coor) > 0: - target_coor = list(np.unravel_index(np.argmax(-input_mat), input_mat.shape)) - capped_matrix = input_mat[:, :, (target_coor[2] - buffer_frames) :] - target_coor[2] = buffer_frames - # else: - # capped_matrix = input_mat - # target_coor=[[0, 0, 0]] + return interp_temp, interp_noise, ais - # post_coor = local_max[local_max[:,2] > peak_range[1],:].astype("int16") - # if len(post_coor) > 0: #Check if postsynaptic target was detected - # post_coor[0][2] = post_coor[0][2] - target_coor[0][2] - return capped_matrix, target_coor # , post_coor +def localize_ais( + input_mat, + params, + min_distance=5, + threshold_rel=0.1, +): + """Localize the AIS in a given template. + This function localizes AIS in a given input matrix using the `peak_local_max` + function from the `skimage` library to detect the peaks. + Args: + input_mat (numpy.ndarray): The input matrix as a NumPy array. + params (dict): A dictionary containing the parameters for localization. + min_distance (int): The minimum distance between peaks. + threshold_rel (float): The relative threshold for peaks. + Returns: + numpy.ndarray: The capped matrix containing the localized neurons. + numpy.ndarray: The coordinates of the AIS peak. + """ + peak_idx = params["ms_cutout"][0] * params["sampling_rate"] / 1000 + peak_cutout = ( + params["upsample"][2] * params["sampling_rate"] / 1000 + ) # Cutout in ms before and after the peak + local_max = peak_local_max( + np.abs(input_mat), min_distance=5, threshold_rel=0.1, num_peaks=10 + ) + if ( + params["ais_detection"] == "dev" + ): # Search for the minimum deviation from expected peak time + ais = local_max[np.argmin(local_max[:, 2] - peak_idx), :] + elif params["ais_detection"] == "time": # Search for first peak + ais = local_max[np.argmin(local_max[:, 2]), :] + elif params["ais_detection"] == "amp": # Search for highest peak + ais = local_max[0, :] + else: + raise ValueError("Invalid search mode") + + # Cap the matrix at the AIS peak time and add a buffer + ais_peak = np.min([ais[2], params["buffer_frames"]]) # Prevent negative indices + capped_matrix = input_mat[:, :, ais_peak:] + ais[2] = ais_peak + return capped_matrix, ais + + +def generate_noise_matrix(template, mode="mad"): + """Generate a noise matrix based on the template. + This function generates a noise matrix based on the template. The noise matrix is + calculated as the median absolute deviation (MAD) or standard deviation (SD) of the + template along the third axis. + Args: + template (numpy.ndarray): The template as a NumPy array + noise (numpy.ndarray): The noise matrix to use. If not provided, the noise matrix + is generated based on the template. + mode (str): The mode to use for noise calculation. Can be either 'mad' for median + absolute deviation or 'sd' for standard deviation. + Returns: + numpy.ndarray: The generated noise matrix. + """ -def generate_noise_matrix(template, noise=[], mode="mad"): - if not noise: - if mode == "mad": - noise = stats.median_abs_deviation(template, axis=2) - elif mode == "sd": - noise = np.std(template, axis=2) + if mode == "mad": + noise = stats.median_abs_deviation(template, axis=2) + elif mode == "sd": + noise = np.std(template, axis=2) - noise_matrix = noise[:, :, np.newaxis] + noise_matrix = noise[:, :, np.newaxis] # For compatibility with the template shape return noise_matrix def threshold_template(template, noise, target_coor, params): - if params["noise_threshold"]: + """Generate a thresholded template based on the absolute signal amplitude, the + maximum velocity, and the noise of the electrode (taken from noise matrix). + Args: + template (numpy.ndarray): The template as a NumPy array. + noise (numpy.ndarray): The noise matrix as a NumPy array. + target_coor (list): The target coordinates for the template. + params (dict): A dictionary containing the parameters for thresholding. + Returns: + numpy.ndarray: The boolean thresholded template. + """ + + # mad_noise = generate_noise_matrix(template, mode="mad") + # sd_noise = generate_noise_matrix(template, mode="sd") + if params["noise_threshold"] is not None: noise_th = template < ( params["noise_threshold"] * noise[:, :, : template.shape[2]] ) + print("Noise thresholding") else: noise_th = np.full_like(template, True) abs_th = template < params["abs_threshold"] - - # r = int((((template.shape[2] / params['sampling_rate']) * params['max_velocity']) * 1000000) / 17.5) - # velocity_th = cone(template.shape, r, apex=tuple(target_coor[0])) velocity_th = valid_latency_map(template, target_coor, params) - th_template = noise_th * abs_th * velocity_th + th_template = noise_th * abs_th * velocity_th # * ((sd_noise / mad_noise) > 1) return th_template +def valid_latency_map(template, start, params): + """Generate a boolean matrix indicating the valid latency map based on the maximum + velocity. + Args: + template (numpy.ndarray): The template as a NumPy array. + start (list): The starting coordinates (AIS) for velocity calculations. + params (dict): A dictionary containing the parameters for thresholding. + Returns: + numpy.ndarray: The boolean matrix indicating the valid latency map. + """ + + indices_array = np.indices(template.shape) * params["el_spacing"] # convert to (um) + # Calculate the distance from the start point + distances = ( + np.sqrt( + (indices_array[0] - start[0] * params["el_spacing"]) ** 2 + + (indices_array[1] - start[1] * params["el_spacing"]) ** 2 + ) + / 1000000 + ) # distances in m + delay_mat = np.zeros(distances.shape) # Initialize delay matrix + for z in range(distances.shape[2]): + # Calculate the delay matrix based on the distance from the start point in + # the z-axis + delay_mat[:, :, z] = (1 / params["sampling_rate"]) * np.abs((z - start[2])) + + # Set the delay matrix for the starting point to non-zero + delay_mat[:, :, start[2]] = 1 / params["sampling_rate"] + + # Check which velocity is within the maximum velocity + passed = (distances / delay_mat) < params["max_velocity"] # Boolean matrix + return passed + + +def restore_sparse_template(template, spacing=(0.5, 0.5, 0.5)): + """Remove empty electrodes and interpolate the sparse template to fill in the gaps. + Args: + template (numpy.ndarray): The template as a NumPy array. + spacing (tuple): The spacing for interpolation. + Returns: + numpy.ndarray: The interpolated template. + """ + + # Find where the actual template begins (first non-zero electrode) + x, y, z = np.nonzero(template) + # Restrict template to the smallest bounding box + true_template = template[np.min(x) :, np.min(x) :, :] + # Remove empty electrodes (due to sparseness) + del_x = np.delete( + true_template, np.nonzero(np.sum(sel_test, axis=(0, 2)) == 0), axis=1 + ) + del_y = np.delete(del_x, np.nonzero(np.sum(sel_test, axis=(1, 2)) == 0), axis=0) + # Interpolate the template to fill in the gaps + interp_tmp = interpolate_template(del_y, spacing) + return interp_tmp + + def interp_max(x, spacing): + """Calculate the maximum value for interpolation. + Args: + x (numpy.ndarray): The input array. + spacing (int): The spacing for interpolation. + Returns: + int: The maximum value for interpolation. + """ + if len(x) == 1: interp_max = spacing elif spacing == 1: @@ -146,6 +289,16 @@ def interp_max(x, spacing): def interpolate_template( template, spacing=[1, 1, 0.2], template_path=[], overwrite=False ): + """Generate an interpolated template based on the input template and spacing. + Args: + template (numpy.ndarray): The template as a NumPy array. + spacing (list): The spacing for interpolation. + template_path (str): The path to save the interpolated file to (optional). + overwrite (bool): Whether to overwrite the existing interpolated file. + Returns: + numpy.ndarray: The interpolated template. + """ + if template_path: split_path = template_path.split(sep="/") split_path[-1] = "interp_" + split_path[-1] @@ -159,7 +312,6 @@ def interpolate_template( else: x, y, z = [np.arange(template.shape[k]) for k in range(3)] f = RegularGridInterpolator((x, y, z), template) - # new_grid = np.mgrid[0:x[-1]:spacing[0], 0:y[-1]:spacing[1], 0:z[-1]+1:spacing[2]] new_grid = np.mgrid[ 0 : interp_max(x, spacing[0]) : spacing[0], 0 : interp_max(y, spacing[1]) : spacing[1], @@ -177,23 +329,6 @@ def interpolate_template( return interp_template -def valid_latency_map(template, start, params): - indices_array = np.indices(template.shape) * params["el_spacing"] # convert to (um) - distances = ( - np.sqrt( - (indices_array[0] - start[0] * params["el_spacing"]) ** 2 - + (indices_array[1] - start[1] * params["el_spacing"]) ** 2 - ) - / 1000000 - ) # convert to (s) - th_mat = np.zeros(distances.shape) - for z in range(distances.shape[2]): - th_mat[:, :, z] = (params["max_velocity"] / params["sampling_rate"]) * (z + 2) - - passed = distances <= th_mat - return passed - - def cone(matrix_shape, r, apex=[]): if not apex: apex = np.empty(3) @@ -212,26 +347,24 @@ def cone(matrix_shape, r, apex=[]): return cone_matrix -def generate_dilation_structure(max_t, max_r, spacing=1 / 3, sampling_rate=20000): - """ - max_t: numeric - Maximum time [us] to detect a peak from a previous peak - max_r: numeric - Maximum deviation [um] from the peak within max_t to detect the next peak - spacing: numeric - Spacing of the interpolation (if performed before the dilation) +def generate_dilation_structure(max_t, params): + """Generate a cone-shaped dilation structure based on the maximum time and parameters. + Args: + max_t (int): The maximum time. + params (dict): A dictionary containing the parameters for dilation. + Returns: + numpy.ndarray: The generated dilation structure. + """ - el_dist = params["el_spacing"] - frame_time = (1000000 / sampling_rate) * spacing # Assumes 20k sampling rate - t = np.ceil(max_t / frame_time).astype("int16") - r = np.ceil(max_r / el_dist).astype("int16") - d = (2 * r + 1).astype("int16") - - cone_matrix = cone((d, d, t), r) - # x, y, z = np.ogrid[:d, :d, :t] - # cone_equation = (x - (r))**2 + (y - (r))**2 <= r**2 * (1 - (z - 0)/t)**2 - # cone_matrix = np.zeros((d,d,t), dtype=bool) - # cone_matrix[cone_equation] = True + frame_time = 1000000 / params["sampling_rate"] # in us + t = np.ceil(max_t / frame_time).astype("int16") # in samples + max_r = ( + params["max_velocity"] * max_t + ) # Infer the maximum radius from the maximum velocity + r = np.ceil(max_r / params["el_spacing"]).astype("int16") # in grid units + d = (2 * r + 1).astype("int16") # maximum diameter of the cone + + cone_matrix = cone((d, d, t), r) # Generate the cone matrix structure = cone_matrix[:, :, ::-1] structure_base = np.full( @@ -239,21 +372,14 @@ def generate_dilation_structure(max_t, max_r, spacing=1 / 3, sampling_rate=20000 ) structure_init = np.full((structure.shape[0], structure.shape[1], 1), False) structure_init[r, r, 0] = True - full_structure = np.concatenate((structure_base, structure_init, structure), axis=2) + full_structure = np.concatenate( + (structure_base, structure_init, structure), axis=2 + ) # Ensure that dilation only occurs in the future return full_structure -def iterative_dilation( - template, - r_dilation=2, - init_th=-10, - min_th=-1, - filter_footprint=(3, 3, 3), - use_derivative=True, -): - if use_derivative: - template = np.diff(template) +def iterative_dilation(template, params): structure = generate_dilation_structure(r_dilation) m_init = template < init_th # Detection of initial seeds/definitive peaks @@ -308,10 +434,8 @@ def perform_path_qc( vel_range=[0.4, 1], min_length=10, ): - if np.max(np.concatenate(paths)) < 220: - scaled_paths = scale_path_coordinates(paths, params) - else: - scaled_paths = paths + + scaled_paths = ut.convert_coor_scale(paths, params, scale="um") good_path_list, r2s, vels, lengths = [], [], [], [] path_list = [] for path in scaled_paths: @@ -430,34 +554,19 @@ def remove_circulating_paths(path_list, max_duplicate_ratio=0.3): def calculate_path_velocity(path, params): path_diff = np.diff(path, axis=0) + # print(path_diff) # print(path_diff.shape) dist = np.sqrt(path_diff[:, 0] ** 2 + path_diff[:, 1] ** 2) / 1000 time = (np.cumsum(np.abs(path_diff[:, 2])) / params["sampling_rate"]) * 1000 - regressor = sklearn.linear_model.LinearRegression(fit_intercept=False) + regressor = LinearRegression(fit_intercept=True) vel_y = np.cumsum(dist).reshape(-1, 1) vel_x = time.reshape(-1, 1) regressor.fit(vel_x, vel_y) y_pred = regressor.predict(vel_x) - r2 = sklearn.metrics.r2_score(vel_y, y_pred) + r2 = r2_score(vel_y, y_pred) return regressor.coef_[0][0], r2 -def scale_path_coordinates(path_list, params): - scaled_paths = [ - np.concatenate((path[:, :2] * params["el_spacing"], path[:, 2:]), axis=1) - for path in path_list - ] - return scaled_paths - - -def unscale_path_coordinates(path_list, params): - unscaled_paths = [ - np.concatenate((path[:, :2] / params["el_spacing"], path[:, 2:]), axis=1) - for path in path_list - ] - return unscaled_paths - - def path_to_vertices(path_list, params, unscale=True): if unscale: path_list = unscale_path_coordinates(path_list, params) diff --git a/axon_tracking/spike_sorting.py b/axon_tracking/spike_sorting.py index baced3d..7c92c73 100644 --- a/axon_tracking/spike_sorting.py +++ b/axon_tracking/spike_sorting.py @@ -55,11 +55,19 @@ def sort_recording_list( os.path.join(save_root, stream_id, "sorter_output", "amplitudes.npy") ) if not os.path.exists(sorter_output_file): - multirecording, common_el, pos = concatenate_recording_slices( - rec_path, stream_id - ) + # Check if axon scan or network recording + rec_names = list(h5["wells"][stream_id].keys()) + if len(rec_names) > 1: + recording, common_el, pos = concatenate_recording_slices( + rec_path, stream_id + ) + else: + recording = si.MaxwellRecordingExtractor( + rec_path, stream_id=stream_id, rec_name=rec_names[0] + ) + sorting = clean_sorting( - multirecording, + recording, save_root, stream_id, sorter, diff --git a/axon_tracking/template_extraction.py b/axon_tracking/template_extraction.py index 0686d2b..0889a0d 100755 --- a/axon_tracking/template_extraction.py +++ b/axon_tracking/template_extraction.py @@ -75,12 +75,12 @@ def preprocess_sorting(sorting_path, qc_params): # Load sorting sorting = si.KiloSortSortingExtractor(sorting_path) - # Load recording + # Find recording path json_path = os.path.join( Path(sorting_path).parent.absolute(), "spikeinterface_recording.json" ) - # Concatenate recordings to use for splitting the sorting + # Load concatenated recording to use for splitting the sorting multirecording = si.load_extractor(json_path, base_folder=True) # Clean sorting (perform quality control) @@ -90,10 +90,11 @@ def preprocess_sorting(sorting_path, qc_params): ) # Relevant if last spike time == recording_length cleaned_sorting.register_recording(multirecording) - # Split sorting into segments - segment_sorting = si.SplitSegmentSorting(cleaned_sorting, multirecording) + # Split sorting into segments if it is a ConcatenateSegmentRecording + if isinstance(multirecording, si.ConcatenateSegmentRecording): + cleaned_sorting = si.SplitSegmentSorting(cleaned_sorting, multirecording) - return segment_sorting + return cleaned_sorting def perform_si_qc(sorting, recording, qc_params): @@ -195,8 +196,7 @@ def extract_all_templates(segment_sorting, te_params): # Find cutout for waveform extraction cutout_samples, cutout_ms = ut.get_cutout_info(full_path) # Find out which well this belongs to - sorting_path = segment_sorting._annotations["phy_folder"] - stream_id = [p for p in sorting_path.split("/") if p.startswith("well")][0] + stream_id = ut.get_sorting_stream_id(segment_sorting) # Find electrode positions rec_names, _, pos = ss.find_common_electrodes(full_path, stream_id) diff --git a/axon_tracking/utils.py b/axon_tracking/utils.py index a777254..3f6447b 100644 --- a/axon_tracking/utils.py +++ b/axon_tracking/utils.py @@ -21,12 +21,36 @@ def infer_stream_id(sorting_path): return stream_id +def get_sorting_stream_id(sorting): + """ + Get the stream ID of the sorting object. + + Args: + sorting (SortingExtractor): Sorting object. + + Returns: + stream_id (str): Stream ID. + """ + try: + stream_id = sorting._kwargs["recording_or_recording_list"][ + 0 + ]._parent_recording._parent_recording.stream_id + except: + stream_id = sorting._recording._kwargs["recording_list"][ + 0 + ]._parent._parent.stream_id + + return stream_id + + def get_cutout_info(rec_path): """Extracts pre and post trigger cutout samples and their corresponding times in milliseconds from a given recording file. + Args: rec_path (str): Path to the HDF5 recording file. + Returns: tuple: A tuple containing: - cutout_samples (list): List of pre and post trigger cutout samples. - cutout_ms (list): List of pre and post trigger cutout times in milliseconds. @@ -54,7 +78,7 @@ def get_cutout_info(rec_path): if cutout_samples[0] < 0: print("Network recording detected, using default [1.5, 5]") cutout_ms = np.array([1.5, 5]) - cutout_samples = cutout_ms * (sampling_rate / 1000) + cutout_samples = (cutout_ms * (sampling_rate / 1000)).astype(int) else: cutout_ms = [ x / (sampling_rate / 1000) for x in cutout_samples @@ -107,3 +131,81 @@ def find_files(save_root, file_name="templates.npy", folder_name="sorter_output" ] return file_list + + +def generate_max_template(template, peak="pos", absolute_value=False): + """ + Generates a maximum template from a given template of shape (x*y*t). + + Args: + template (np.ndarray): Template to generate the maximum template from. + peak (str): Peak to use for the maximum template. Can be "pos" or "neg" or "both". + absolute_value (bool): Whether to return the absolute value. + + Returns: + np.ndarray: Maximum template of shape (x*y*1). + """ + if peak == "pos": + max_template = np.max(template, axis=2) + elif peak == "neg": + max_template = np.min(template, axis=2) + elif peak == "both": + max_template = np.max(template, axis=2) + min_template = np.min(template, axis=2) + diff_idx = np.where(np.abs(min_template) > max_template) + max_template[diff_idx] = min_template[diff_idx] + else: + raise ValueError("Invalid peak value. Must be 'pos', 'neg' or 'both'.") + + if absolute_value: + max_template = np.abs(max_template) + + return max_template + + +def check_if_scaled(path_list, params): + """ + Checks if the coordinates are scaled to um. + + Args: + path_list (list of np.ndarrays): Coordinates to check. + + Returns: + bool: True if the coordinates are scaled, False otherwise. + """ + coors = np.concatenate(path_list)[:, :2] # Concatenate all coordinates + + # Check if all coordinates are multiples of the electrode spacing + return np.all(np.remainder(coors, params["el_spacing"]) == 0) + + +def convert_coor_scale(path_list, params, scale="um"): + """ + Converts the coordinates to a different scale ('um' or 'el'). + + Args: + path_list (list of np.ndarrays): Coordinates to convert. + params (dict): Parameters dictionary containing the electrode spacing. + scale (str): Scale to convert to. Can be "um" or "el". + + Returns: + list of np.ndarrays: List of converted coordinates. + """ + if scale == "um": + if check_if_scaled(path_list, params): # Check if already scaled to um + return path_list + scale_factor = params["el_spacing"] + elif scale == "el": + if not check_if_scaled( + path_list, params + ): # Check if already in electrode spacing + return path_list + scale_factor = 1 / params["el_spacing"] + else: + raise ValueError("Invalid scale value. Must be 'um' or 'el'.") + + rescaled = [ + np.concatenate((path[:, :2] * params["el_spacing"], path[:, 2:]), axis=1) + for path in path_list + ] # Rescale the coordinates without changing the z values + return rescaled diff --git a/axon_tracking/visualization.py b/axon_tracking/visualization.py index f67156d..760d57b 100644 --- a/axon_tracking/visualization.py +++ b/axon_tracking/visualization.py @@ -10,6 +10,7 @@ from skimage.morphology import ball from axon_tracking import skeletonization as skel +from axon_tracking import utils as ut def plot_velocity_qc(vels, r2s, fig_size=(6, 2)): @@ -59,14 +60,14 @@ def plot_delay_skeleton( params, skel_params, figsize=4, - plot_ais=True, + ais=None, plot_ais_connection=True, linewidth=2, font_size=24, ): # path_list = skel.scale_path_coordinates(path_list) - sorted_vertices = skel.path_to_vertices(path_list, params) + sorted_vertices = skel.path_to_vertices(path_list, params, unscale=False) c_max = 1000 * (sorted_vertices[-1][2]) / params["sampling_rate"] fig_ratio = ( np.ptp(sorted_vertices, axis=0)[0] / np.ptp(sorted_vertices, axis=0)[1] + 0.7 @@ -92,25 +93,26 @@ def plot_delay_skeleton( ) # ,format=mticker.FixedFormatter(clb_ticks) clb.set_label(label="Delay (ms)", size=font_size * figsize) clb.ax.tick_params(labelsize=font_size * figsize, length=0) - if plot_ais: + if ais is not None: plt.scatter( - skel_params["ais"][0][0], - skel_params["ais"][0][1], + ais[0], + ais[1], s=50, color="k", zorder=10, ) - if plot_ais_connection: + if plot_ais_connection and ais is not None: plt.plot( - [skel_params["ais"][0][0], sorted_vertices[0][0]], - [skel_params["ais"][0][1], sorted_vertices[0][1]], + [ais[0], sorted_vertices[0][0]], [ais[1], sorted_vertices[0][1]], c="k" ) - # print(sorted_vertices[0][0]) + print(sorted_vertices[0][0]) ax.autoscale_view() + + # ax.set_xlim([0, 440]) + # ax.set_ylim([0, 240]) ax.set_ylim(ax.get_ylim()[::-1]) - # ax.set_xlim([0, 2200]) # ax.set_xlabel("(μm)") # ax.set_ylabel("(μm)") # plt.show() @@ -130,10 +132,11 @@ def plot_conduction_velocity(path_list, params, fig_size=(3, 3)): plt.show() -def plot_template_and_noise(template, noise, th_template): +def plot_template_and_noise(template, noise, th_template, vrange=[-3, 3]): + max_template = ut.generate_max_template(template, peak="both", absolute_value=False) fig, axes = plt.subplots(1, 3, figsize=(21, 7)) plt.subplot(131) - plt.imshow(np.min(template, axis=2).T, vmin=-10, vmax=0) + plt.imshow(max_template.T, vmin=vrange[0], vmax=vrange[1], cmap="coolwarm") plt.colorbar(shrink=0.3) plt.subplot(132) @@ -156,11 +159,12 @@ def plot_template_overview( unit_ids=None, overwrite=False, ): - full_filename = os.path.join(root_path, filename + ".png") + full_filename = os.path.join(root_path, "templates", filename + ".png") if os.path.exists(full_filename) and not overwrite: display(Image(filename=full_filename)) else: - files = os.listdir(root_path) + template_folder = os.path.join(root_path, "templates") + files = os.listdir(template_folder) template_files = [f for f in files if "_" not in f] ids = [float(t.split(".")[0]) for t in template_files] if unit_ids is not None: @@ -172,7 +176,7 @@ def plot_template_overview( n_rows = int(np.ceil(len(template_files) / n_cols)) fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 3 * n_rows)) for i, template_file in enumerate(template_files): - template_path = os.path.join(root_path, template_file) + template_path = os.path.join(template_folder, template_file) template = np.load(template_path) temp_diff = np.diff(template) tmp_filt = nd.gaussian_filter(temp_diff, sigma=1) @@ -223,7 +227,7 @@ def plot_skeleton(skeleton, x_lim=[], y_lim=[], fig_size=5, marker_size=10, ais= def generate_propagation_gif( template, params, - cumulative=True, + cumulative=False, vertices=[], downsample=2, clim=[-10, 0], @@ -234,16 +238,16 @@ def generate_propagation_gif( el_offset = params["el_spacing"] interp_offset = el_offset * spacing xticks = np.arange(0, 3850, 500) - yticks = np.arange(0, 2200, 500) + yticks = np.arange(0, 2100, 500) conv_xticks = xticks / interp_offset conv_yticks = yticks / interp_offset if len(vertices) > 0: x, y, z = [vertices[:, x] for x in range(3)] - clb_ticks = [ - "0", - str(np.round(max(z) / (params["sampling_rate"] / 1000), decimals=1)), - ] + clb_ticks = [ + "0", + str(np.round(max(z) / (params["sampling_rate"] / 1000), decimals=1)), + ] ims = [] fig = plt.figure() # added @@ -289,7 +293,7 @@ def plot_filled_contour( ): interp_tmp = skel.interpolate_template(capped_template, spacing=params["upsample"]) interp_tmp = nd.gaussian_filter(interp_tmp, sigma=0.8) - sorted_vertices = skel.path_to_vertices(skeleton.paths(), params) + sorted_vertices = skel.path_to_vertices(skeleton.paths(), params, unscale=False) skel_mat = np.zeros(interp_tmp.shape) skel_mat[tuple(sorted_vertices.astype("int").T)] = True dil_mat = nd.binary_dilation(skel_mat, structure=ball(radius)) diff --git a/notebooks/part2_template-extraction.ipynb b/notebooks/part2_template-extraction.ipynb index c0328e1..a1d4b30 100755 --- a/notebooks/part2_template-extraction.ipynb +++ b/notebooks/part2_template-extraction.ipynb @@ -43,7 +43,7 @@ "qc_params = dict()\n", "qc_params['min_n_spikes'] = 1500 #Minimum number of spikes to be detected for a unit for template extraction to take place\n", "qc_params['exclude_mua'] = True #Exclude units that were labelled multi unit activity by kilosort\n", - "qc_params['use_bc'] = False #Use bombcell for QC\n", + "qc_params['use_bc'] = True #Use bombcell for QC\n", "qc_params['use_si'] = True #Use spikeinterface for QC\n", "qc_params[\"auto_merge\"] = True #Automatically merge units (spikeinterface implementation)\n", "qc_params[\"remove_redundant\"] = True #Remove redundant units (spikeinterface implementation)" diff --git a/notebooks/part3_skeletonization.ipynb b/notebooks/part3_skeletonization.ipynb index 448b633..94763aa 100755 --- a/notebooks/part3_skeletonization.ipynb +++ b/notebooks/part3_skeletonization.ipynb @@ -33,73 +33,125 @@ "sys.path.append(\"/home/phornauer/Git/axon_tracking/\")\n", "import axon_tracking.skeletonization as skel\n", "import axon_tracking.template_extraction as te\n", - "import axon_tracking.visualization as vis" + "import axon_tracking.visualization as vis\n", + "import axon_tracking.utils as ut" ] }, { "cell_type": "code", "execution_count": null, - "id": "d355a47c-2fe1-4cfb-aa7d-95dfb40e823a", + "id": "c325b7dd-9de3-42c2-8f3a-59c9231ed695", + "metadata": {}, + "outputs": [], + "source": [ + "root_path = '/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Torsten_2/241021/T002523/AxonTracking'\n", + "stream_id = 'well005'\n", + "suffix = \"sorter_output\" \n", + "overview_path = os.path.join(root_path, stream_id, suffix) \n", + "vis.plot_template_overview(overview_path, n_cols = 3, vmin=-10, vmax=0, filename='overview', unit_ids= None, overwrite=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92b63dfd", "metadata": {}, "outputs": [], "source": [ "params = dict()\n", - "params['noise_threshold'] = -1.5\n", - "params['abs_threshold'] = -0.2\n", - "params['upsample'] = [0.5,0.5,0.5]\n", - "params['sampling_rate'] = 20000/params['upsample'][2] # [Hz]\n", - "params['ms_cutout'] = [1.5, 5.0]\n", - "params['filter_footprint'] = ball(1)#np.ones((2,2,3))\n", - "params['max_velocity'] = 5 # [m/s]\n", - "params['max_noise_level'] = 0.25\n", - "params['el_spacing'] = 17.5 * params['upsample'][2]" + "params['noise_threshold'] = -0.5 # Noise threshold as factor of the noise level (None to skip)\n", + "params['abs_threshold'] = -0.15 # Absolut threshold in [uV]/sample\n", + "params['upsample'] = [0.5,0.5,0.25] # Spacing for interpolation in [x,y,z] direction\n", + "params['sampling_rate'] = 10000/params['upsample'][2] # [Hz]\n", + "params['ms_cutout'] = [3.0, 5.0] # Cutout in ms before and after the spike\n", + "#params['filter_footprint'] = ball(1)\n", + "params['max_velocity'] = 0.8 # [m/s]\n", + "params['max_noise_level'] = 0.25 # Mean of the noise level/matrix\n", + "params['el_spacing'] = 17.5 * params['upsample'][0] # Electrode spacing in um\n", + "params['ais_detection'] = \"dev\" # \"time\" or \"dev\" or \"amp\" or None\n", + "params['buffer_frames'] = 10 # Buffer frames for the ais cutout" ] }, { "cell_type": "code", "execution_count": null, - "id": "c325b7dd-9de3-42c2-8f3a-59c9231ed695", + "id": "9c7d393c-5bc7-4ad0-afa8-9801d58efdb3", "metadata": {}, "outputs": [], "source": [ - "root_path = '/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Chemogenetics/Large_dose_range/concatenated/'\n", - "#root_path = '/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Chemogenetics_2/Week_3/concatenated/'\n", - "stream_id = 'well007'" + "root_path = '/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Torsten_2/241021/T002523/AxonTracking'\n", + "stream_id = 'well011'\n", + "suffix = \"sorter_output\" # Typically the sorter_output folder\n", + "template_folder = \"templates\" # Subfolder containing the templates\n", + "template_id = 566\n", + "full_path = os.path.join(root_path, stream_id, suffix, template_folder) \n", + "template, template_save_file = skel.load_template_file(full_path, template_id)" ] }, { "cell_type": "code", "execution_count": null, - "id": "144d37a1-7304-4621-af47-4d20d3a06c40", + "id": "1a4b5e03", "metadata": {}, "outputs": [], "source": [ - "vis.plot_template_overview(root_path, stream_id, params, n_cols = 3, vmin=-10, vmax=0, filename='test',unit_ids= [],overwrite=False)" + "interp_temp, interp_noise, ais = skel.preprocess_template(template, params)" ] }, { "cell_type": "code", "execution_count": null, - "id": "9c7d393c-5bc7-4ad0-afa8-9801d58efdb3", + "id": "04e1b8df", "metadata": {}, "outputs": [], "source": [ - "root_path = '/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Chemogenetics/Large_dose_range/concatenated/'\n", - "#root_path = '/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Chemogenetics_2/Week_3/concatenated/'\n", - "stream_id = 'well014'\n", - "template_id = 274\n", - "template, template_save_file, noise = skel.load_template_file(root_path, stream_id, template_id)\n", - "temp_diff = np.diff(template)\n", - "capped_template, target_coor= skel.localize_neurons(temp_diff, ms_cutout=params['ms_cutout'])\n", - "\n", - "#tmp_filt = nd.median_filter(capped_template,footprint=params['filter_footprint'])\n", - "tmp_filt = nd.gaussian_filter(capped_template,sigma=1)\n", - "\n", - "interp_temp = skel.interpolate_template(tmp_filt,spacing=params['upsample'],overwrite=False)\n", - "noise = skel.interpolate_template(noise,spacing=params['upsample'][:2] +[1])\n", - "th_template = skel.threshold_template(interp_temp, noise, target_coor, params)\n", - "\n", - "vis.plot_template_and_noise(tmp_filt,noise,th_template)" + "plt.hist(interp_noise.flatten(), bins=100)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b489155", + "metadata": {}, + "outputs": [], + "source": [ + "params['max_velocity'] = 0.8 # [m/s]\n", + "params['abs_threshold'] = -0.1\n", + "params['noise_threshold'] = -1\n", + "th_template = skel.threshold_template(interp_temp, interp_noise, ais, params)\n", + "vis.plot_template_and_noise(interp_temp, interp_noise, th_template, vrange=[-0.5, 0.5])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf0a9034", + "metadata": {}, + "outputs": [], + "source": [ + "struct = skel.generate_dilation_structure(500,params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b320dce7", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(np.sum(struct,axis=0),clim=[0,1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f2a0f7d", + "metadata": {}, + "outputs": [], + "source": [ + "ani = vis.generate_propagation_gif(interp_temp, params = params, cumulative=False, downsample=4, clim=[-1,-0.01], cmap=\"Greys\", marker_size=0.1,spacing=1)\n", + "HTML(ani.to_jshtml())" ] }, { @@ -110,10 +162,10 @@ "outputs": [], "source": [ "skel_params = dict()\n", - "skel_params['scale'] = 1#1\n", - "skel_params['const'] = 2#3\n", + "skel_params['scale'] = 2#1\n", + "skel_params['const'] = 10#3\n", "skel_params['dust_threshold'] = 10\n", - "skel_params['anisotropy'] = (10.0, 10.0, 10.0)\n", + "skel_params['anisotropy'] = (10,10,10)#(params[\"el_spacing\"], params[\"el_spacing\"], 0.0001)\n", "skel_params['tick_threshold'] = 10\n", "skel_params['n_jobs'] = 16" ] @@ -128,11 +180,11 @@ "t_cap = [0, th_template.shape[2]] #in samples\n", "skels = skel.skeletonize(th_template[:,:,t_cap[0]:t_cap[1]].astype('bool'), **skel_params)\n", "\n", - "skeleton = kimimaro.join_close_components(skels[1], radius=10)\n", + "skeleton = kimimaro.join_close_components(skels[1], radius=50)\n", "\n", "skeleton = kimimaro.postprocess(skeleton, tick_threshold=5, dust_threshold=10)\n", "\n", - "skeleton = kimimaro.join_close_components(skeleton, radius=20)\n", + "#skeleton = kimimaro.join_close_components(skeleton, radius=20)\n", "\n", "#skeleton = kimimaro.postprocess(skeleton, tick_threshold=5, dust_threshold=100)\n", "\n", @@ -140,6 +192,96 @@ "skel_paths = cv.Skeleton.paths(skeleton)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b229274", + "metadata": {}, + "outputs": [], + "source": [ + "skeleton = skeleton.average_smoothing(n=4,check_boundary=False) #Interferes with scaling" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e563c07", + "metadata": {}, + "outputs": [], + "source": [ + "skel_paths = cv.Skeleton.paths(skeleton)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e04572a", + "metadata": {}, + "outputs": [], + "source": [ + "skel_paths_scaled = ut.convert_coor_scale(skel_paths, params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fcb4095b", + "metadata": {}, + "outputs": [], + "source": [ + "coors = np.concatenate(skel_paths)[:,:2] #Concatenate all coordinates" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f631054", + "metadata": {}, + "outputs": [], + "source": [ + "coors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4531d57", + "metadata": {}, + "outputs": [], + "source": [ + "coors_scaled = np.concatenate(skel_paths_scaled)[:,:2] #Concatenate all coordinates" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbb2d5c8", + "metadata": {}, + "outputs": [], + "source": [ + "np.unique(np.remainder(coors_scaled,params[\"el_spacing\"]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b0d3c7c", + "metadata": {}, + "outputs": [], + "source": [ + "params[\"el_spacing\"]/ 0.02185059" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "776b7e6f", + "metadata": {}, + "outputs": [], + "source": [ + "np.unique(np.remainder(coors,params[\"el_spacing\"]))" + ] + }, { "cell_type": "code", "execution_count": null, @@ -147,7 +289,7 @@ "metadata": {}, "outputs": [], "source": [ - "vis.plot_delay_skeleton(skel_paths, params, skel_params,figsize=4, plot_ais=False, plot_ais_connection=False)" + "vis.plot_delay_skeleton(skel_paths, params, skel_params,figsize=0.5, ais=None, plot_ais_connection=False)" ] }, { @@ -170,8 +312,8 @@ "qc_params = dict()\n", "qc_params['window_size'] = 5\n", "qc_params['max_duplicate_ratio'] = 0.3\n", - "qc_params['min_r2'] = 0.8\n", - "qc_params['vel_range'] = [0.2, 0.8]\n", + "qc_params['min_r2'] = 0.9\n", + "qc_params['vel_range'] = [0.3, 0.8]\n", "qc_params['min_length'] = 1" ] }, @@ -189,114 +331,171 @@ { "cell_type": "code", "execution_count": null, - "id": "f2b1e7f4-4101-4372-88f2-be2c69bedffb", + "id": "8c50b302-5e7f-4563-95d8-687de979b786", "metadata": {}, "outputs": [], "source": [ - "vis.plot_conduction_velocity(scaled_qc_list,params)" + "vis.plot_velocity_qc(vels,r2s,fig_size=(6,2))" ] }, { "cell_type": "code", "execution_count": null, - "id": "8c50b302-5e7f-4563-95d8-687de979b786", + "id": "85ff23a2-4159-4c0f-8dcc-635441ee5d0c", "metadata": {}, "outputs": [], "source": [ - "vis.plot_velocity_qc(vels,r2s,fig_size=(6,2))" + "vis.plot_delay_skeleton(scaled_qc_list, params, skel_params, figsize=0.5, ais=None, plot_ais_connection=False)" ] }, { "cell_type": "code", "execution_count": null, - "id": "85ff23a2-4159-4c0f-8dcc-635441ee5d0c", + "id": "525d2080-0f21-46be-8ee8-ca827183928d", "metadata": {}, "outputs": [], "source": [ - "vis.plot_delay_skeleton(scaled_qc_list, params, skel_params,figsize=4, plot_ais=False, plot_ais_connection=False)" + "test_skels = [skeleton.from_path(x) for x in scaled_qc_list]" ] }, { "cell_type": "code", "execution_count": null, - "id": "525d2080-0f21-46be-8ee8-ca827183928d", + "id": "e8194193", "metadata": {}, "outputs": [], "source": [ - "test_skels = [skeleton.from_path(x) for x in scaled_qc_list]" + "branch_points = skeleton.branches()\n", + "path_list = skeleton.paths()\n", + "all_branches = []\n", + "for path in path_list:\n", + " path_idx, _ = np.where(\n", + " (path[:, None] == skeleton.vertices[branch_points]).all(axis=2)\n", + " )\n", + " branch_idx = []" ] }, { "cell_type": "code", "execution_count": null, - "id": "ea55d511-7a15-4fc6-91b1-6648d8d52475", + "id": "18b6e9d0", "metadata": {}, "outputs": [], "source": [ - "test_skeleton = skeleton.simple_merge(test_skels)\n", - "test_skeleton = kimimaro.join_close_components(test_skeleton,radius=100)\n", - "test_skeleton = kimimaro.postprocess(test_skeleton, dust_threshold=500, tick_threshold=10)\n", - "test_skeleton = test_skeleton.consolidate()" + "path_idx, _ = np.where((path_list[4][:, None] == skeleton.vertices[branch_points]).all(axis=2))\n", + "path_idx" ] }, { "cell_type": "code", "execution_count": null, - "id": "3f073e13-8457-4741-a1f6-37431ac54f61", + "id": "67355788", "metadata": {}, "outputs": [], "source": [ - "vis.plot_delay_skeleton(test_skeleton.paths(), params, skel_params,figsize=4, plot_ais=False, plot_ais_connection=False)" + "(path[:, None] == skeleton.vertices[branch_points]).all(axis=2)" ] }, { "cell_type": "code", "execution_count": null, - "id": "5a6d9626-5b6d-4b38-97d1-6b92c3221371", + "id": "9b9a3821", "metadata": {}, "outputs": [], "source": [ - "all_branches = skel.branches_from_paths(test_skeleton)\n", - "scaled_qc_list, r2s, vels, lengths = skel.perform_path_qc(all_branches, params,**qc_params)" + "skeleton.vertices[branch_points]" ] }, { "cell_type": "code", "execution_count": null, - "id": "50108db9-0a55-41ea-bd98-b1dd5b013a57", + "id": "b74a971c", "metadata": {}, "outputs": [], "source": [ - "vis.plot_velocity_qc(vels,r2s,fig_size=(6,2))" + "from scipy.spatial.distance import pdist" ] }, { "cell_type": "code", "execution_count": null, - "id": "ae2187a9-b682-4d97-a721-cf88f7be5bb8", + "id": "22e90832", "metadata": {}, "outputs": [], "source": [ - "test_skeleton.vertices" + "for p in path_idx:\n", + " # check_idx = range(p-1,p+2)\n", + " dists = pdist(path[p - 1 : p + 2, :])\n", + " branch_idx = [0]\n", + " if dists[0] > dists[2]:\n", + " branch_idx.append(p)\n", + " else:\n", + " branch_idx.append(p + 1)" ] }, { "cell_type": "code", "execution_count": null, - "id": "d5c00b8f-0c10-40bf-b00e-cad97812a2ec", + "id": "8fe33696", + "metadata": {}, + "outputs": [], + "source": [ + "dists" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea55d511-7a15-4fc6-91b1-6648d8d52475", "metadata": {}, "outputs": [], "source": [ - "ani = vis.generate_propagation_gif(interp_temp, cumulative=False, vertices=skeleton.vertices, downsample=1, clim=[-2, -0.15], cmap=\"Greys\", marker_size=0.1)" + "test_skeleton = skeleton.simple_merge(test_skels)\n", + "test_skeleton = kimimaro.join_close_components(test_skeleton,radius=100)\n", + "test_skeleton = kimimaro.postprocess(test_skeleton, dust_threshold=500, tick_threshold=10)\n", + "test_skeleton = test_skeleton.consolidate()" ] }, { "cell_type": "code", "execution_count": null, - "id": "ceb7adb7-3118-4c40-b097-5df64a8aef1e", + "id": "3f073e13-8457-4741-a1f6-37431ac54f61", + "metadata": {}, + "outputs": [], + "source": [ + "ais = np.array(target_coor) * params[\"el_spacing\"]\n", + "vis.plot_delay_skeleton(test_skeleton.paths(), params, skel_params,figsize=0.5, ais=ais, plot_ais_connection=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a6d9626-5b6d-4b38-97d1-6b92c3221371", + "metadata": {}, + "outputs": [], + "source": [ + "all_branches = skel.branches_from_paths(test_skeleton)\n", + "scaled_qc_list, r2s, vels, lengths = skel.perform_path_qc(all_branches, params,**qc_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50108db9-0a55-41ea-bd98-b1dd5b013a57", + "metadata": {}, + "outputs": [], + "source": [ + "vis.plot_velocity_qc(vels,r2s,fig_size=(6,2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5c00b8f-0c10-40bf-b00e-cad97812a2ec", "metadata": {}, "outputs": [], "source": [ + "ani = vis.generate_propagation_gif(interp_temp, params = params, cumulative=False, vertices=skeleton.vertices, downsample=2, clim=[-2, -0.1], cmap=\"Greys\", marker_size=0.1)\n", "HTML(ani.to_jshtml())" ] }, @@ -517,7 +716,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "si101", "language": "python", "name": "python3" }, @@ -531,7 +730,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/notebooks/test_sorting_si101.ipynb b/notebooks/test_sorting_si101.ipynb index 614e670..488296d 100644 --- a/notebooks/test_sorting_si101.ipynb +++ b/notebooks/test_sorting_si101.ipynb @@ -57,25 +57,7 @@ "metadata": {}, "outputs": [], "source": [ - "sorting_list = ['/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Torsten_2/241010/T002523/AxonTracking/well001/sorter_output']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "si.set_global_job_kwargs(n_jobs=te_params[\"n_jobs\"], progress_bar=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "segment_sorting = te.preprocess_sorting(sorting_list[0], qc_params)" + "sorting_list = ['/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Mea1k/phornauer/Slices/230711/Rec5/well000/sorter_output']" ] }, { @@ -93,8 +75,7 @@ "metadata": {}, "outputs": [], "source": [ - "test = np.load(\n", - " '/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Torsten_2/241010/T002523/AxonTracking/well000/sorter_output/templates/18.npy')" + "sorting = si.KiloSortSortingExtractor(sorting_list[0])" ] }, { @@ -103,7 +84,7 @@ "metadata": {}, "outputs": [], "source": [ - "plt.imshow(np.min(test, axis=2).T, aspect='auto')" + "rec_path = ss.get_recording_path(sorting)" ] }, { @@ -112,11 +93,7 @@ "metadata": {}, "outputs": [], "source": [ - "merge_unit_groups = si.get_potential_auto_merge(\n", - " full_analyzer,\n", - " resolve_graph=True\n", - " )\n", - "analyzer_merged = full_analyzer.merge_units(merge_unit_groups=merge_unit_groups)" + "from axon_tracking import visualization as vis #Separate as it is not needed for the template extraction" ] }, { @@ -125,108 +102,10 @@ "metadata": {}, "outputs": [], "source": [ - "segment_sorting._annotations['phy_folder']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for sorting_path in tqdm(sorting_list):\n", - " output_path = os.path.join(sorting_path, \"sorter_output\")\n", - " sorting = si.KiloSortSortingExtractor(output_path)\n", - " json_path = os.path.join(sorting_path, \"spikeinterface_recording.json\")\n", - " multirecording = si.load_extractor(json_path, base_folder=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rec_path = ss.get_recording_path(multirecording)\n", - "stream_id = [p for p in sorting_path.split(\"/\") \n", - " if p.startswith(\"well\")][0] # Find out which well this belongs to\n", + "root_path = sorting_path\n", + "plot_path = os.path.join(root_path, 'templates')\n", "\n", - "rec_names, common_el, pos = ss.find_common_electrodes(rec_path, stream_id)\n", - "cleaned_sorting = te.select_good_units(sorting, **qc_params)\n", - "cleaned_sorting = si.remove_excess_spikes(\n", - " cleaned_sorting, multirecording\n", - ") # Relevant if last spike time == recording_length\n", - "cleaned_sorting.register_recording(multirecording)\n", - "segment_sorting = si.SplitSegmentSorting(cleaned_sorting, multirecording)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "stream_id = [p for p in output_path.split(\"/\") \n", - " if p.startswith(\"well\")][0] # Find out which well this belongs to" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "full_path = ss.get_recording_path(segment_sorting)\n", - "cutout_samples, cutout_ms = te.get_assay_information(full_path)\n", - "overwrite_wf = te_params[\"overwrite_wf\"]\n", - "cutout = cutout_ms\n", - "n_jobs = te_params[\"n_jobs\"]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "si.set_global_job_kwargs(n_jobs=n_jobs, progress_bar=False)\n", - "full_analyzer = si.create_sorting_analyzer(\n", - " sorting=cleaned_sorting,\n", - " recording=multirecording\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "full_analyzer.compute([\"random_spikes\",\n", - " \"waveforms\",\n", - " \"templates\",\n", - " \"spike_amplitudes\",\n", - " \"unit_locations\",\n", - " \"template_similarity\",\n", - " \"correlograms\"],\n", - " extension_params={\n", - " \"random_spikes\":{\"max_spikes_per_unit\":900},\n", - " \"waveforms\":{\"ms_before\":cutout[0], \"ms_after\":cutout[1]}\n", - " }\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "merge_unit_groups = si.get_potential_auto_merge(\n", - " full_analyzer,\n", - " resolve_graph=True\n", - " )\n", - "analyzer_merged = full_analyzer.merge_units(merge_unit_groups=merge_unit_groups)" + "vis.plot_template_overview(plot_path,vmax=5,overwrite=True)" ] }, { @@ -235,234 +114,15 @@ "metadata": {}, "outputs": [], "source": [ - "removed_sorting = si.remove_redundant_units(\n", - " analyzer_merged,\n", - " duplicate_threshold=0.8,\n", - " remove_strategy=\"minimum_shift\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "analyzer_merged.compute(\"noise_levels\")\n", - "metrics = si.compute_quality_metrics(analyzer_merged, n_jobs=n_jobs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "metrics.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "output_path = '/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/AxonScan/Test'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sel_unit_ids = removed_sorting.get_unit_ids()\n", - "template_save_path = os.path.join(output_path, \"templates\")\n", - "if not os.path.exists(template_save_path):\n", - " os.makedirs(template_save_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "h5 = h5py.File(full_path)\n", - "rec_names = list(h5[\"wells\"][stream_id].keys())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.scatter(metrics[\"firing_rate\"],metrics[\"rp_contamination\"])\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "h5 = h5py.File(full_path)\n", - "rec_names = list(h5[\"wells\"][stream_id].keys())\n", - "n_units = cleaned_sorting.get_num_units()\n", - "template_matrix = np.full([n_units, sum(cutout_samples), 26400], np.nan)\n", - "\n", - "for sel_idx, rec_name in enumerate(rec_names):\n", - " rec = si.MaxwellRecordingExtractor(\n", - " full_path, stream_id=stream_id, rec_name=rec_name\n", - " )\n", - " \n", - " rec_centered = si.highpass_filter(rec, freq_min=150)\n", - " \n", - " seg_sort = si.SelectSegmentSorting(segment_sorting, sel_idx)\n", - " seg_sort = si.remove_excess_spikes(seg_sort, rec_centered)\n", - " seg_sort.register_recording(rec_centered) \n", - " \n", - " analyzer = si.create_sorting_analyzer(\n", - " sorting=seg_sort,\n", - " recording=rec_centered,\n", - " sparse=False,\n", - " overwrite=overwrite_wf\n", - " )\n", - "\n", - " analyzer.compute(\"random_spikes\",n_jobs=n_jobs,max_spikes_per_unit=900)\n", - " analyzer.compute(\"waveforms\",ms_before=cutout[0], ms_after=cutout[1],n_jobs=n_jobs)\n", - " analyzer.compute(\"templates\",n_jobs=n_jobs)\n", - " tmp = analyzer.get_extension(\n", - " extension_name=\"templates\"\n", - " )\n", - " tmp_data = tmp.get_data()\n", - " \n", - " els = rec.get_property(\"contact_vector\")[\"electrode\"]\n", - " template_matrix[:, :, els] = tmp_data \n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rec.get_property(\"contact_vector\")[\"device_channel_indices\"].shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.plot(np.squeeze(template_matrix[0, :, :]))\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "metrics.iloc[1]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "grid = te.convert_to_grid(template_matrix[1,:,:], pos)\n", - "fig, ax = plt.subplots()\n", - "ax.imshow(np.max(np.abs(grid),axis=2).T,vmax=20)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "noise_tmp = np.squeeze(template_matrix[1,:,:])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "th = -5\n", - "tmp_th = noise_tmp