diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8778eff --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.2.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: https://github.com/psf/black + rev: 22.8.0 + hooks: + - id: black + exclude: ^docs/ diff --git a/nlb_tools/evaluation.py b/nlb_tools/evaluation.py index 6177e02..5b50ef0 100644 --- a/nlb_tools/evaluation.py +++ b/nlb_tools/evaluation.py @@ -7,8 +7,10 @@ from sklearn.model_selection import GridSearchCV import logging + logger = logging.getLogger(__name__) + def evaluate(test_annotation_file, user_submission_file): """ Runs evaluation as it would be run on EvalAI servers @@ -16,10 +18,10 @@ def evaluate(test_annotation_file, user_submission_file): Parameters ---------- test_annotation_file : str or dict - Path to the eval target .h5 file or dict of eval target + Path to the eval target .h5 file or dict of eval target data to evaluate against user_submission_file : str - Path to the .h5 file or dict with user + Path to the .h5 file or dict with user rate predictions Returns @@ -32,66 +34,100 @@ def evaluate(test_annotation_file, user_submission_file): # define prefixes for scaling metrics scaling_tcount = { - 'mc_maze_large': '[500]', - 'mc_maze_medium': '[250]', - 'mc_maze_small': '[100]', + "mc_maze_large": "[500]", + "mc_maze_medium": "[250]", + "mc_maze_small": "[100]", } # read data from files if type(test_annotation_file) == str: - target_data = h5py.File(test_annotation_file, 'r') + target_data = h5py.File(test_annotation_file, "r") else: target_data = test_annotation_file if type(user_submission_file) == str: - user_data = h5py.File(user_submission_file, 'r') + user_data = h5py.File(user_submission_file, "r") else: user_data = user_submission_file - + result_list = [] scaling_dict = {} scaling_dict_20 = {} # evaluate on datasets that are included in both submission and evaluation data - for dataset in ['mc_maze', 'mc_rtt', 'area2_bump', 'dmfc_rsg', 'mc_maze_large', 'mc_maze_medium', 'mc_maze_small']: - for bin_size_ms, suf in zip([5, 20], ['', '_20']): + for dataset in [ + "mc_maze", + "mc_rtt", + "area2_bump", + "dmfc_rsg", + "mc_maze_large", + "mc_maze_medium", + "mc_maze_small", + ]: + for bin_size_ms, suf in zip([5, 20], ["", "_20"]): if (dataset + suf) not in user_data.keys(): continue dataset_name = dataset + suf logger.info(f"Evaluating {dataset_name}") result_dict = {} # check that both submission and evaluation dicts have data for this dataset - if 'eval_rates_heldout' not in user_data[dataset_name].keys(): + if "eval_rates_heldout" not in user_data[dataset_name].keys(): continue - elif (dataset_name) not in target_data.keys() or 'eval_spikes_heldout' not in target_data[dataset_name].keys(): + elif ( + dataset_name + ) not in target_data.keys() or "eval_spikes_heldout" not in target_data[ + dataset_name + ].keys(): logger.warning(f"Evaluation data for {dataset_name} not found") continue # extract evaluation data - eval_spikes_heldout = target_data[dataset_name]['eval_spikes_heldout'][()].astype('float') - train_behavior = target_data[dataset_name]['train_behavior'][()].astype('float') - eval_behavior = target_data[dataset_name]['eval_behavior'][()].astype('float') + eval_spikes_heldout = target_data[dataset_name]["eval_spikes_heldout"][ + () + ].astype("float") + train_behavior = target_data[dataset_name]["train_behavior"][()].astype( + "float" 
+ ) + eval_behavior = target_data[dataset_name]["eval_behavior"][()].astype( + "float" + ) # extract submitted data - eval_rates_heldin = user_data[dataset_name]['eval_rates_heldin'][()].astype('float') - eval_rates_heldout = user_data[dataset_name]['eval_rates_heldout'][()].astype('float') - eval_rates = np.concatenate([eval_rates_heldin, eval_rates_heldout], axis=-1) + eval_rates_heldin = user_data[dataset_name]["eval_rates_heldin"][()].astype( + "float" + ) + eval_rates_heldout = user_data[dataset_name]["eval_rates_heldout"][ + () + ].astype("float") + eval_rates = np.concatenate( + [eval_rates_heldin, eval_rates_heldout], axis=-1 + ) # calculate co-smoothing bits per spike - result_dict['co-bps'] = float(bits_per_spike(eval_rates_heldout, eval_spikes_heldout)) + result_dict["co-bps"] = float( + bits_per_spike(eval_rates_heldout, eval_spikes_heldout) + ) - if dataset == 'dmfc_rsg': + if dataset == "dmfc_rsg": # Compute Pearson's r for the correlation between neural speed and tp result_dict["tp corr"] = speed_tp_correlation( eval_spikes_heldout, eval_rates, eval_behavior ) else: # extract train rates for regression - train_rates_heldin = user_data[dataset_name]['train_rates_heldin'][()].astype('float') - train_rates_heldout = user_data[dataset_name]['train_rates_heldout'][()].astype('float') - train_rates = np.concatenate([train_rates_heldin, train_rates_heldout], axis=-1) + train_rates_heldin = user_data[dataset_name]["train_rates_heldin"][ + () + ].astype("float") + train_rates_heldout = user_data[dataset_name]["train_rates_heldout"][ + () + ].astype("float") + train_rates = np.concatenate( + [train_rates_heldin, train_rates_heldout], axis=-1 + ) # make decode mask if not provided - if 'train_decode_mask' in target_data[dataset_name].keys(): - train_decode_mask = target_data[dataset_name]['train_decode_mask'][()] - eval_decode_mask = target_data[dataset_name]['eval_decode_mask'][()] + if "train_decode_mask" in target_data[dataset_name].keys(): + train_decode_mask = target_data[dataset_name]["train_decode_mask"][ + () + ] + eval_decode_mask = target_data[dataset_name]["eval_decode_mask"][()] else: train_decode_mask = np.full(train_rates.shape[0], True)[:, None] eval_decode_mask = np.full(eval_rates.shape[0], True)[:, None] @@ -103,42 +139,59 @@ def evaluate(test_annotation_file, user_submission_file): eval_behavior, eval_decode_mask, ) - if 'psth' in target_data[dataset_name].keys(): + if "psth" in target_data[dataset_name].keys(): # get PSTH information and evaluate - psth = target_data[dataset_name]['psth'][()].astype('float') - eval_cond_idx = target_data[dataset_name]['eval_cond_idx'][()] - if 'eval_jitter' in target_data[dataset_name].keys(): - jitter = target_data[dataset_name]['eval_jitter'][()] + psth = target_data[dataset_name]["psth"][()].astype("float") + eval_cond_idx = target_data[dataset_name]["eval_cond_idx"][()] + if "eval_jitter" in target_data[dataset_name].keys(): + jitter = target_data[dataset_name]["eval_jitter"][()] else: jitter = np.zeros(eval_rates.shape[0]).astype(int) psth_r2 = eval_psth(psth, eval_rates, eval_cond_idx, jitter=jitter) result_dict["psth R2"] = float(psth_r2) - if 'eval_rates_heldin_forward' in user_data[dataset_name].keys() and 'eval_spikes_heldin_forward' in target_data[dataset_name].keys(): + if ( + "eval_rates_heldin_forward" in user_data[dataset_name].keys() + and "eval_spikes_heldin_forward" in target_data[dataset_name].keys() + ): # extract forward prediction data - eval_spikes_heldin_forward = 
target_data[dataset_name]['eval_spikes_heldin_forward'][()].astype('float') - eval_spikes_heldout_forward = target_data[dataset_name]['eval_spikes_heldout_forward'][()].astype('float') - eval_rates_heldin_forward = user_data[dataset_name]['eval_rates_heldin_forward'][()].astype('float') - eval_rates_heldout_forward = user_data[dataset_name]['eval_rates_heldout_forward'][()].astype('float') + eval_spikes_heldin_forward = target_data[dataset_name][ + "eval_spikes_heldin_forward" + ][()].astype("float") + eval_spikes_heldout_forward = target_data[dataset_name][ + "eval_spikes_heldout_forward" + ][()].astype("float") + eval_rates_heldin_forward = user_data[dataset_name][ + "eval_rates_heldin_forward" + ][()].astype("float") + eval_rates_heldout_forward = user_data[dataset_name][ + "eval_rates_heldout_forward" + ][()].astype("float") # combine held-in and held-out - eval_spikes_forward = np.dstack([eval_spikes_heldin_forward, eval_spikes_heldout_forward]) - eval_rates_forward = np.dstack([eval_rates_heldin_forward, eval_rates_heldout_forward]) + eval_spikes_forward = np.dstack( + [eval_spikes_heldin_forward, eval_spikes_heldout_forward] + ) + eval_rates_forward = np.dstack( + [eval_rates_heldin_forward, eval_rates_heldout_forward] + ) # calculate forward prediction bits per spike - result_dict['fp-bps'] = float(bits_per_spike(eval_rates_forward, eval_spikes_forward)) + result_dict["fp-bps"] = float( + bits_per_spike(eval_rates_forward, eval_spikes_forward) + ) - if dataset in ['mc_maze_large', 'mc_maze_medium', 'mc_maze_small']: - sd = scaling_dict if suf == '' else scaling_dict_20 + if dataset in ["mc_maze_large", "mc_maze_medium", "mc_maze_small"]: + sd = scaling_dict if suf == "" else scaling_dict_20 for key, val in result_dict.items(): sd[scaling_tcount[dataset] + " " + key] = val - elif dataset in ['mc_maze', 'mc_rtt', 'area2_bump', 'dmfc_rsg']: + elif dataset in ["mc_maze", "mc_rtt", "area2_bump", "dmfc_rsg"]: result_list.append({f"{dataset_name}_split": result_dict}) - + # put scaling data in proper split if len(scaling_dict) > 0: - result_list.append({'mc_maze_scaling_split': scaling_dict}) + result_list.append({"mc_maze_scaling_split": scaling_dict}) if len(scaling_dict_20) > 0: - result_list.append({'mc_maze_scaling_20_split': scaling_dict_20}) - + result_list.append({"mc_maze_scaling_20_split": scaling_dict_20}) + logger.info("Completed evaluation") try: @@ -152,11 +205,12 @@ def evaluate(test_annotation_file, user_submission_file): return result_list + def neg_log_likelihood(rates, spikes, zero_warning=True): """Calculates Poisson negative log likelihood given rates and spikes. formula: -log(e^(-r) / n! * r^n) = r - n*log(r) + log(n!) - + Parameters ---------- rates : np.ndarray @@ -164,35 +218,37 @@ def neg_log_likelihood(rates, spikes, zero_warning=True): spikes : np.ndarray numpy array containing true spike counts zero_warning : bool, optional - Whether to print out warning about 0 rate + Whether to print out warning about 0 rate predictions or not - + Returns ------- float Total negative log-likelihood of the data """ - assert spikes.shape == rates.shape, \ - f"neg_log_likelihood: Rates and spikes should be of the same shape. spikes: {spikes.shape}, rates: {rates.shape}" + assert ( + spikes.shape == rates.shape + ), f"neg_log_likelihood: Rates and spikes should be of the same shape. 
spikes: {spikes.shape}, rates: {rates.shape}" if np.any(np.isnan(spikes)): mask = np.isnan(spikes) rates = rates[~mask] spikes = spikes[~mask] - - assert not np.any(np.isnan(rates)), \ - "neg_log_likelihood: NaN rate predictions found" - assert np.all(rates >= 0), \ - "neg_log_likelihood: Negative rate predictions found" - if (np.any(rates == 0)): + assert not np.any(np.isnan(rates)), "neg_log_likelihood: NaN rate predictions found" + + assert np.all(rates >= 0), "neg_log_likelihood: Negative rate predictions found" + if np.any(rates == 0): if zero_warning: - logger.warning("neg_log_likelihood: Zero rate predictions found. Replacing zeros with 1e-9") + logger.warning( + "neg_log_likelihood: Zero rate predictions found. Replacing zeros with 1e-9" + ) rates[rates == 0] = 1e-9 - + result = rates - spikes * np.log(rates) + gammaln(spikes + 1.0) return np.sum(result) + def bits_per_spike(rates, spikes): """Computes bits per spike of rate predictions given spikes. Bits per spike is equal to the difference between the log-likelihoods (in base 2) @@ -205,21 +261,26 @@ def bits_per_spike(rates, spikes): 3d numpy array containing rate predictions spikes : np.ndarray 3d numpy array containing true spike counts - + Returns ------- float Bits per spike of rate predictions """ nll_model = neg_log_likelihood(rates, spikes) - nll_null = neg_log_likelihood(np.tile(np.nanmean(spikes, axis=(0,1), keepdims=True), (spikes.shape[0], spikes.shape[1], 1)), spikes, zero_warning=False) + null_rates = np.tile( + np.nanmean(spikes, axis=tuple(range(spikes.ndim - 1)), keepdims=True), + spikes.shape[:-1] + (1,), + ) + nll_null = neg_log_likelihood(null_rates, spikes, zero_warning=False) return (nll_null - nll_model) / np.nansum(spikes) / np.log(2) + def fit_and_eval_decoder( - train_rates, - train_behavior, - eval_rates, - eval_behavior, + train_rates, + train_behavior, + eval_rates, + eval_behavior, grid_search=True, ): """Fits ridge regression on train data passed @@ -242,9 +303,9 @@ def fit_and_eval_decoder( 2d array with same dimension ordering as train_behavior. Used to evaluate regressor grid_search : bool - Whether to perform a cross-validated grid search to find + Whether to perform a cross-validated grid search to find the best regularization hyperparameters. 
- + Returns ------- float @@ -256,8 +317,9 @@ def fit_and_eval_decoder( if np.any(np.isnan(eval_behavior)): eval_rates = eval_rates[~np.isnan(eval_behavior)[:, 0]] eval_behavior = eval_behavior[~np.isnan(eval_behavior)[:, 0]] - assert not np.any(np.isnan(train_rates)) and not np.any(np.isnan(eval_rates)), \ - "fit_and_eval_decoder: NaNs found in rate predictions within required trial times" + assert not np.any(np.isnan(train_rates)) and not np.any( + np.isnan(eval_rates) + ), "fit_and_eval_decoder: NaNs found in rate predictions within required trial times" if grid_search: decoder = GridSearchCV(Ridge(), {"alpha": np.logspace(-4, 0, 9)}) @@ -266,6 +328,7 @@ def fit_and_eval_decoder( decoder.fit(train_rates, train_behavior) return decoder.score(eval_rates, eval_behavior) + def eval_psth(psth, eval_rates, eval_cond_idx, jitter=None): """Evaluates match to PSTH across conditions Parameters @@ -281,7 +344,7 @@ def eval_psth(psth, eval_rates, eval_cond_idx, jitter=None): corresponding to conditions in `psth` jitter : np.ndarray, optional 1d array containing jitter applied to each eval trial - + Returns ------- float @@ -289,20 +352,29 @@ def eval_psth(psth, eval_rates, eval_cond_idx, jitter=None): to true PSTHs across all conditions, averaged across neurons """ - jitter_trial = lambda x: x[0] if x[1] == 0 else \ - np.vstack([np.full((x[1], x[0].shape[1]), np.nan), x[0][:-x[1]]]) if x[1] > 0 else \ - np.vstack([x[0][-x[1]:], np.full((-x[1], x[0].shape[1]), np.nan)]) + jitter_trial = ( + lambda x: x[0] + if x[1] == 0 + else np.vstack([np.full((x[1], x[0].shape[1]), np.nan), x[0][: -x[1]]]) + if x[1] > 0 + else np.vstack([x[0][-x[1] :], np.full((-x[1], x[0].shape[1]), np.nan)]) + ) if jitter is None: jitter = np.zeros(eval_rates.shape[0]).astype(int) - true_list = []; pred_list = [] + true_list = [] + pred_list = [] for i in range(len(eval_cond_idx)): if eval_cond_idx[i].size == 0: continue - pred_psth = np.mean([jitter_trial((eval_rates[idx], jitter[idx])) for idx in eval_cond_idx[i]], axis=0) + pred_psth = np.mean( + [jitter_trial((eval_rates[idx], jitter[idx])) for idx in eval_cond_idx[i]], + axis=0, + ) true_psth = psth[i, :, :][~np.isnan(psth[i, :, 0])] pred_psth = pred_psth[~np.isnan(psth[i, :, 0])] - assert not np.any(np.isnan(pred_psth)), \ - "eval_psth: NaNs found in rate predictions within required trial times" + assert not np.any( + np.isnan(pred_psth) + ), "eval_psth: NaNs found in rate predictions within required trial times" true_list.append(true_psth) pred_list.append(pred_psth) @@ -339,6 +411,7 @@ def speed_tp_correlation(eval_spikes_heldout, eval_rates, eval_behavior): # Compute neural speed during the set-go period for each trial def compute_speed(trial): return np.mean(np.linalg.norm(np.diff(trial, axis=0), axis=1)) + eval_speeds = [compute_speed(trial[mask]) for trial, mask in zip(eval_rates, masks)] eval_speeds = np.array(eval_speeds) # Compute correlation within each condition @@ -378,28 +451,28 @@ def velocity_decoding( 3d array, with dimensions trial x time x neuron, containing rate predictions for all train split trials. train_behavior : np.ndarray - 3d array, with dimensions trial x time x 2, containing x and y hand velocity + 3d array, with dimensions trial x time x 2, containing x and y hand velocity for all train split trials. train_decode_mask : np.ndarray - 2d array, with dimensions trial x n_masks, containing masks that group trials + 2d array, with dimensions trial x n_masks, containing masks that group trials with the same decoder for all train split trials. 
eval_rates : np.ndarray 3d array, with dimensions trial x time x neuron, containing rate predictions for all test split trials. eval_behavior : np.ndarray - 3d array, with dimensions trial x time x 2, containing x and y hand velocity + 3d array, with dimensions trial x time x 2, containing x and y hand velocity for all test split trials. eval_decode_mask : np.ndarray - 2d array, with dimensions trial x n_masks, containing masks that group trials + 2d array, with dimensions trial x n_masks, containing masks that group trials with the same decoder for all test split trials. grid_search : bool, optional - Whether to use a cross-validated grid search over the ridge regularization + Whether to use a cross-validated grid search over the ridge regularization penalty, by default True Returns ------- float - Average coefficient of determination for hand velocity decoding across masked + Average coefficient of determination for hand velocity decoding across masked groups. """ flatten3d = lambda x: x.reshape(-1, x.shape[2]) if (len(x.shape) > 2) else x diff --git a/nlb_tools/nwb_interface.py b/nlb_tools/nwb_interface.py index 3bcfe4d..be0d2b6 100644 --- a/nlb_tools/nwb_interface.py +++ b/nlb_tools/nwb_interface.py @@ -1,8 +1,8 @@ - from pynwb import NWBFile, NWBHDF5IO, TimeSeries, ProcessingModule from pynwb.core import MultiContainerInterface, NWBDataInterface from scipy.stats import mode from glob import glob +from pathlib import Path import numpy as np import pandas as pd import scipy.signal as signal @@ -19,8 +19,9 @@ class NWBDataset: """A class for loading/preprocessing data from NWB files for the NLB competition """ - def __init__(self, fpath, prefix='', split_heldout=True, skip_fields=[]): - """Initializes an NWBDataset, loading data from + + def __init__(self, fpath, prefix="", split_heldout=True, skip_fields=[]): + """Initializes an NWBDataset, loading data from the indicated file(s) Parameters @@ -32,7 +33,7 @@ def __init__(self, fpath, prefix='', split_heldout=True, skip_fields=[]): A pattern used to filter the NWB files in directory by name. By default, prefix='' loads all .nwb files in the directory. Please refer to documentation for - the `glob` module for more details: + the `glob` module for more details: https://docs.python.org/3/library/glob.html split_heldout : bool, optional Whether to load heldin units and heldout units @@ -41,12 +42,12 @@ def __init__(self, fpath, prefix='', split_heldout=True, skip_fields=[]): List of field names to skip during loading, which may be useful if memory is an issue. Field names must match the names automatically - assigned in the loading process. Spiking data + assigned in the loading process. Spiking data can not be skipped. 
Field names in the list that are not found in the dataset are ignored """ - if isinstance(fpath, str): + if isinstance(fpath, (str, Path)): fpath = os.path.expanduser(fpath) self.fpath = fpath self.prefix = prefix @@ -60,55 +61,91 @@ def __init__(self, fpath, prefix='', split_heldout=True, skip_fields=[]): filenames = [fpath] # If no files found if len(filenames) == 0: - raise FileNotFoundError(f"No matching files with prefix {prefix} found in directory {fpath}") + raise FileNotFoundError( + f"No matching files with prefix {prefix} found in directory {fpath}" + ) # If multiple files found elif len(filenames) > 1: - loaded = [self.load(fname, split_heldout=split_heldout, skip_fields=skip_fields) for fname in filenames] - datas, trial_infos, descriptions, bin_widths = [list(out) for out in zip(*loaded)] - assert np.all(np.array(bin_widths) == bin_widths[0]), "Bin widths of loaded datasets must be the same" + loaded = [ + self.load( + fname, split_heldout=split_heldout, skip_fields=skip_fields + ) + for fname in filenames + ] + datas, trial_infos, descriptions, bin_widths = [ + list(out) for out in zip(*loaded) + ] + assert np.all( + np.array(bin_widths) == bin_widths[0] + ), "Bin widths of loaded datasets must be the same" # Shift loaded files to stack them into continuous array def trial_shift(x, shift_ms, trial_offset): - if x.name.endswith('_time'): - return x + pd.to_timedelta(shift_ms, unit='ms') - elif x.name == 'trial_id': + if x.name.endswith("_time"): + return x + pd.to_timedelta(shift_ms, unit="ms") + elif x.name == "trial_id": return x + trial_offset else: return x + # Loop through files, shifting continuous data - past_end = datas[0].index[-1].total_seconds() + round(50 * bin_widths[0] / 1000, 4) + past_end = datas[0].index[-1].total_seconds() + round( + 50 * bin_widths[0] / 1000, 4 + ) descriptions_full = descriptions[0] tcount = len(trial_infos[0]) for i in range(1, len(datas)): block_start_ms = np.ceil(past_end * 10) * 100 - datas[i] = datas[i].shift(block_start_ms, freq='ms') - trial_infos[i] = trial_infos[i].apply(trial_shift, shift_ms=block_start_ms, trial_offset=tcount) + datas[i] = datas[i].shift(block_start_ms, freq="ms") + trial_infos[i] = trial_infos[i].apply( + trial_shift, shift_ms=block_start_ms, trial_offset=tcount + ) descriptions_full.update(descriptions[i]) - past_end = datas[i].index[-1].total_seconds() + round(50 * bin_widths[i] / 1000, 4) + past_end = datas[i].index[-1].total_seconds() + round( + 50 * bin_widths[i] / 1000, 4 + ) tcount += len(trial_infos[i]) # Stack data and reindex to continuous - self.data = pd.concat(datas, axis=0, join='outer') - self.trial_info = pd.concat(trial_infos, axis=0, join='outer').reset_index(drop=True) + self.data = pd.concat(datas, axis=0, join="outer") + self.trial_info = pd.concat( + trial_infos, axis=0, join="outer" + ).reset_index(drop=True) self.descriptions = descriptions_full self.bin_width = bin_widths[0] - new_index = pd.to_timedelta((np.arange(round(self.data.index[-1].total_seconds() * 1000 / self.bin_width) + 1) * self.bin_width).round(4), unit='ms') + new_index = pd.to_timedelta( + ( + np.arange( + round( + self.data.index[-1].total_seconds() + * 1000 + / self.bin_width + ) + + 1 + ) + * self.bin_width + ).round(4), + unit="ms", + ) self.data = self.data.reindex(new_index) - self.data.index.name = 'clock_time' + self.data.index.name = "clock_time" # If single file found else: - data, trial_info, descriptions, bin_width = self.load(filenames[0], split_heldout=split_heldout, skip_fields=skip_fields) + data, 
trial_info, descriptions, bin_width = self.load( + filenames[0], split_heldout=split_heldout, skip_fields=skip_fields + ) self.data = data self.trial_info = trial_info self.descriptions = descriptions self.bin_width = bin_width else: # case of nwbfile object - data, trial_info, descriptions, bin_width = self.load(fpath, split_heldout=split_heldout, skip_fields=skip_fields) + data, trial_info, descriptions, bin_width = self.load( + fpath, split_heldout=split_heldout, skip_fields=skip_fields + ) self.data = data self.trial_info = trial_info self.descriptions = descriptions self.bin_width = bin_width - def load(self, fpath, split_heldout=True, skip_fields=[]): """Loads data from an NWB file into two dataframes, one for trial info and one for time-varying data @@ -124,7 +161,7 @@ def load(self, fpath, split_heldout=True, skip_fields=[]): List of field names to skip during loading, which may be useful if memory is an issue. Field names must match the names automatically - assigned in the loading process. Spiking data + assigned in the loading process. Spiking data can not be skipped. Field names in the list that are not found in the dataset are ignored @@ -139,9 +176,9 @@ def load(self, fpath, split_heldout=True, skip_fields=[]): """ logger.info(f"Loading {fpath}") - if isinstance(fpath, str): + if isinstance(fpath, (str, Path)): # Open NWB file - io = NWBHDF5IO(fpath, 'r') + io = NWBHDF5IO(fpath, "r") nwbfile = io.read() else: nwbfile = fpath @@ -150,7 +187,8 @@ def load(self, fpath, split_heldout=True, skip_fields=[]): trial_info = ( nwbfile.trials.to_dataframe() .reset_index() - .rename({'id': 'trial_id', 'stop_time': 'end_time'}, axis=1)) + .rename({"id": "trial_id", "stop_time": "end_time"}, axis=1) + ) units = nwbfile.units.to_dataframe() # Load descriptions of trial info fields @@ -168,8 +206,14 @@ def make_df(ts): index = ts.timestamps[()] else: index = np.arange(ts.data.shape[0]) / ts.rate + ts.starting_time - columns = ts.comments.split('[')[-1].split(']')[0].split(',') if 'columns=' in ts.comments else None - df = pd.DataFrame(ts.data[()], index=pd.to_timedelta(index, unit='s'), columns=columns) + columns = ( + ts.comments.split("[")[-1].split("]")[0].split(",") + if "columns=" in ts.comments + else None + ) + df = pd.DataFrame( + ts.data[()], index=pd.to_timedelta(index, unit="s"), columns=columns + ) return df def find_timeseries(nwbobj): @@ -199,15 +243,18 @@ def find_timeseries(nwbobj): # Calculate data index start_time = 0.0 - bin_width = 1 # in ms, this will be the case for all provided datasets - rate = round(1000. 
/ bin_width, 2) # in Hz + bin_width = 1 # in ms, this will be the case for all provided datasets + rate = round(1000.0 / bin_width, 2) # in Hz # Use obs_intervals, or last trial to determine data end - end_time = round(max(units.obs_intervals.apply(lambda x: x[-1][-1])) * rate) * bin_width - if (end_time < trial_info['end_time'].iloc[-1]): - print("obs_interval ends before trial end") # TO REMOVE - end_time = round(trial_info['end_time'].iloc[-1] * rate) * bin_width + end_time = ( + round(max(units.obs_intervals.apply(lambda x: x[-1][-1])) * rate) + * bin_width + ) + if end_time < trial_info["end_time"].iloc[-1]: + print("obs_interval ends before trial end") # TO REMOVE + end_time = round(trial_info["end_time"].iloc[-1] * rate) * bin_width timestamps = (np.arange(start_time, end_time, bin_width) / 1000).round(6) - timestamps_td = pd.to_timedelta(timestamps, unit='s') + timestamps_td = pd.to_timedelta(timestamps, unit="s") # Check that all timeseries match with calculated timestamps for key, val in list(data_dict.items()): @@ -219,73 +266,89 @@ def make_mask(obs_intervals): """Creates boolean mask to indicate when spiking data is not in obs_intervals""" mask = np.full(timestamps.shape, True) for start, end in obs_intervals: - start_idx = np.ceil(round((start - timestamps[0]) * rate, 6)).astype(int) + start_idx = np.ceil(round((start - timestamps[0]) * rate, 6)).astype( + int + ) end_idx = np.floor(round((end - timestamps[0]) * rate, 6)).astype(int) mask[start_idx:end_idx] = False return mask - + # Prepare variables for spike binning - masks = [(~units.heldout).to_numpy(), units.heldout.to_numpy()] if split_heldout else [np.full(len(units), True)] + masks = ( + [(~units.heldout).to_numpy(), units.heldout.to_numpy()] + if split_heldout + else [np.full(len(units), True)] + ) - for mask, name in zip(masks, ['spikes', 'heldout_spikes']): + for mask, name in zip(masks, ["spikes", "heldout_spikes"]): # Check if there are any units if not np.any(mask): continue - + # Allocate array to fill with spikes - spike_arr = np.full((len(timestamps), np.sum(mask)), 0.0, dtype='float16') + spike_arr = np.full((len(timestamps), np.sum(mask)), 0.0, dtype="float16") # Bin spikes using decimal truncation and np.unique - faster than np.histogram with same results for idx, (_, unit) in enumerate(units[mask].iterrows()): - spike_idx, spike_cnt = np.unique(((unit.spike_times - timestamps[0]) * rate).round(6).astype(int), return_counts=True) + spike_idx, spike_cnt = np.unique( + ((unit.spike_times - timestamps[0]) * rate).round(6).astype(int), + return_counts=True, + ) spike_arr[spike_idx, idx] = spike_cnt # Replace invalid intervals in spike recordings with NaNs - if 'obs_intervals' in units.columns: + if "obs_intervals" in units.columns: neur_mask = make_mask(units[mask].iloc[0].obs_intervals) if np.any(spike_arr[neur_mask]): logger.warning("Spikes found outside of observed interval.") spike_arr[neur_mask] = np.nan # Create DataFrames with spike arrays - data_dict[name] = pd.DataFrame(spike_arr, index=timestamps_td, columns=units[mask].index).astype('float16', copy=False) + data_dict[name] = pd.DataFrame( + spike_arr, index=timestamps_td, columns=units[mask].index + ).astype("float16", copy=False) # Create MultiIndex column names data_list = [] for key, val in data_dict.items(): chan_names = None if type(val.columns) == pd.RangeIndex else val.columns - val.columns = self._make_midx(key, chan_names=chan_names, num_channels=val.shape[1]) + val.columns = self._make_midx( + key, chan_names=chan_names, 
num_channels=val.shape[1] + ) data_list.append(val) - + # Assign time-varying data to `self.data` data = pd.concat(data_list, axis=1) - data.index.name = 'clock_time' + data.index.name = "clock_time" data.sort_index(axis=1, inplace=True) # Convert time fields in trial info to timedelta # and assign to `self.trial_info` def to_td(x): - if x.name.endswith('_time'): - return pd.to_timedelta(x, unit='s') + if x.name.endswith("_time"): + return pd.to_timedelta(x, unit="s") else: return x + trial_info = trial_info.apply(to_td, axis=0) if isinstance(fpath, str): io.close() return data, trial_info, descriptions, bin_width - - def make_trial_data(self, - start_field='start_time', - end_field='end_time', - align_field=None, - align_range=(None, None), - margin=0, - ignored_trials=None, - allow_overlap=False, - allow_nans=False): - """Makes a DataFrame of trialized data based on + + def make_trial_data( + self, + start_field="start_time", + end_field="end_time", + align_field=None, + align_range=(None, None), + margin=0, + ignored_trials=None, + allow_overlap=False, + allow_nans=False, + ): + """Makes a DataFrame of trialized data based on an alignment field Parameters @@ -301,17 +364,17 @@ def make_trial_data(self, by default None, which does not align trials and instead takes them in their entirety align_range : tuple of int, optional - The offsets to add to the alignment field to - calculate the alignment window, by default (None, None) + The offsets to add to the alignment field to + calculate the alignment window, by default (None, None) uses `trial_start` and `trial_end` margin : int, optional - The number of ms of extra data to include on either end of - each trial, labeled with the `margin` column for easy + The number of ms of extra data to include on either end of + each trial, labeled with the `margin` column for easy removal. Margins are useful for decoding and smoothing ignored_trials : pd.Series or np.ndarray, optional - A boolean pd.Series or np.ndarray of the same length - as trial_info with True for the trials to ignore, by - default None ignores no trials. This is useful for + A boolean pd.Series or np.ndarray of the same length + as trial_info with True for the trials to ignore, by + default None ignores no trials. This is useful for rejecting trials outside of the alignment process allow_overlap : bool, optional Whether to allow overlap between trials, by default False @@ -319,183 +382,217 @@ def make_trial_data(self, allow_nans : bool, optional Whether to allow NaNs within trials, by default False drops all timestamps containing NaNs in any column - + Returns ------- pd.DataFrame A DataFrame containing trialized data. It has the same - fields as the continuous `self.data` DataFrame, but + fields as the continuous `self.data` DataFrame, but adds `trial_id`, `trial_time`, and `align_time`. It also resets the index so `clock_time` is a column rather than - an index. This DataFrame can be pivoted to plot its - various fields across trials, aligned relative to + an index. This DataFrame can be pivoted to plot its + various fields across trials, aligned relative to `align_time`, `trial_time`, or `clock_time` """ # Allow rejection of trials by passing a boolean series trial_info = self.trial_info.copy() - trial_info['next_start'] = trial_info['start_time'].shift(-1) + trial_info["next_start"] = trial_info["start_time"].shift(-1) if ignored_trials is not None: trial_info = trial_info.loc[~ignored_trials] if len(trial_info) == 0: logger.warning("All trials ignored. 
No trial data made") return - + # Find alignment points - bin_width = pd.to_timedelta(self.bin_width, unit='ms') + bin_width = pd.to_timedelta(self.bin_width, unit="ms") trial_start = trial_info[start_field] trial_end = trial_info[end_field] - next_start = trial_info['next_start'] + next_start = trial_info["next_start"] if align_field is not None: align_left = align_right = trial_info[align_field] else: - align_field = f'{start_field} and {end_field}' # for logging + align_field = f"{start_field} and {end_field}" # for logging align_left = trial_start align_right = trial_end - + # Find start and end points based on the alignment range - start_offset, end_offset = pd.to_timedelta(align_range, unit='ms') + start_offset, end_offset = pd.to_timedelta(align_range, unit="ms") if not pd.isnull(start_offset) and not pd.isnull(end_offset): if not ((end_offset - start_offset) / bin_width).is_integer(): # Round align offsets if alignment range is not multiple of bin width end_offset = start_offset + (end_offset - start_offset).round(bin_width) align_range = ( int(round(start_offset.total_seconds() * 1000)), - int(round(end_offset.total_seconds() * 1000)) + int(round(end_offset.total_seconds() * 1000)), + ) + logger.warning( + "Alignment window not integer multiple of bin width. " + f"Rounded to {align_range}" ) - logger.warning('Alignment window not integer multiple of bin width. ' - f'Rounded to {align_range}') if pd.isnull(start_offset): align_start = trial_start else: align_start = align_left + start_offset if pd.isnull(end_offset): # Subtract small interval to prevent inclusive timedelta .loc indexing - align_end = trial_end - pd.to_timedelta(1, unit='us') + align_end = trial_end - pd.to_timedelta(1, unit="us") else: - align_end = align_right + end_offset - pd.to_timedelta(1, unit='us') + align_end = align_right + end_offset - pd.to_timedelta(1, unit="us") # Add margins to either end of the data - margin_delta = pd.to_timedelta(margin, unit='ms') + margin_delta = pd.to_timedelta(margin, unit="ms") margin_start = align_start - margin_delta margin_end = align_end + margin_delta - trial_ids = trial_info['trial_id'] + trial_ids = trial_info["trial_id"] # Store the alignment data in a dataframe - align_data = pd.DataFrame({ - 'trial_id': trial_ids, - 'margin_start': margin_start, - 'margin_end': margin_end, - 'align_start': align_start, - 'align_end': align_end, - 'trial_start': trial_start, - 'align_left': align_left}).dropna() + align_data = pd.DataFrame( + { + "trial_id": trial_ids, + "margin_start": margin_start, + "margin_end": margin_end, + "align_start": align_start, + "align_end": align_end, + "trial_start": trial_start, + "align_left": align_left, + } + ).dropna() # Bound the end by the next trial / alignment start - align_data['end_bound'] = ( - pd.concat([next_start, align_start], axis=1) - .min(axis=1) - .shift(-1)) + align_data["end_bound"] = ( + pd.concat([next_start, align_start], axis=1).min(axis=1).shift(-1) + ) trial_dfs = [] num_overlap_trials = 0 + def make_trial_df(args): idx, row = args # Handle overlap with the start of the next trial endpoint = row.margin_end trial_id = row.trial_id overlap = False - if not pd.isnull(row.end_bound) and \ - row.align_end > row.end_bound: + if not pd.isnull(row.end_bound) and row.align_end > row.end_bound: overlap = True if not allow_overlap: # Allow overlapping margins, but not aligned data - endpoint = row.end_bound + margin_delta - pd.to_timedelta(1, unit='us') + endpoint = ( + row.end_bound + margin_delta - pd.to_timedelta(1, unit="us") + ) # 
Take a slice of the continuous data - trial_idx = pd.Series(self.data.index[self.data.index.slice_indexer(row.margin_start, endpoint)]) + trial_idx = pd.Series( + self.data.index[ + self.data.index.slice_indexer(row.margin_start, endpoint) + ] + ) # Add trial identifiers - trial_df = pd.DataFrame({ - ('trial_id', ''): np.repeat(trial_id, len(trial_idx)), - ('trial_time', ''): (trial_idx - row.trial_start.ceil(bin_width)), - ('align_time', ''): (trial_idx - row.align_left.ceil(bin_width)), - ('margin', ''): ((trial_idx < row.align_start) | (row.align_end < trial_idx))}) + trial_df = pd.DataFrame( + { + ("trial_id", ""): np.repeat(trial_id, len(trial_idx)), + ("trial_time", ""): (trial_idx - row.trial_start.ceil(bin_width)), + ("align_time", ""): (trial_idx - row.align_left.ceil(bin_width)), + ("margin", ""): ( + (trial_idx < row.align_start) | (row.align_end < trial_idx) + ), + } + ) trial_df.index = trial_idx return overlap, trial_df - overlaps, trial_dfs = zip(*[make_trial_df(args) for args in align_data.iterrows()]) + + overlaps, trial_dfs = zip( + *[make_trial_df(args) for args in align_data.iterrows()] + ) num_overlap_trials = sum(overlaps) # Summarize alignment - logger.info(f'Aligned {len(trial_dfs)} trials to ' - f'{align_field} with offset of {align_range} ms ' - f'and margin of {margin}.') + logger.info( + f"Aligned {len(trial_dfs)} trials to " + f"{align_field} with offset of {align_range} ms " + f"and margin of {margin}." + ) # Report any overlapping trials to the user. if num_overlap_trials > 0: if allow_overlap: - logger.warning( - f'Allowed {num_overlap_trials} overlapping trials.') + logger.warning(f"Allowed {num_overlap_trials} overlapping trials.") else: logger.warning( - f'Shortened {num_overlap_trials} trials to prevent overlap.') + f"Shortened {num_overlap_trials} trials to prevent overlap." + ) # Combine all trials into one DataFrame trial_data = pd.concat(trial_dfs) trial_data.reset_index(inplace=True) - trial_data = trial_data.merge(self.data, how='left', left_on=[('clock_time', '')], right_index=True) + trial_data = trial_data.merge( + self.data, how="left", left_on=[("clock_time", "")], right_index=True + ) # Sanity check to make sure there are no duplicated `clock_time`'s if not allow_overlap: # Duplicated points in the margins are allowed td_nonmargin = trial_data[~trial_data.margin] - assert td_nonmargin.clock_time.duplicated().sum() == 0, \ - 'Duplicated points still found. Double-check overlap code.' + assert ( + td_nonmargin.clock_time.duplicated().sum() == 0 + ), "Duplicated points still found. Double-check overlap code." # Make sure NaN's caused by adding trialized data to self.data are ignored nans_found = trial_data.isnull().sum().max() if nans_found > 0: pct_nan = (nans_found / len(trial_data)) * 100 if allow_nans: - logger.warning(f'NaNs found in {pct_nan:.2f}% of `trial_data`.') + logger.warning(f"NaNs found in {pct_nan:.2f}% of `trial_data`.") else: - logger.warning(f'NaNs found in `self.data`. Dropping {pct_nan:.2f}% ' - 'of points to remove NaNs from `trial_data`.') + logger.warning( + f"NaNs found in `self.data`. Dropping {pct_nan:.2f}% " + "of points to remove NaNs from `trial_data`." + ) trial_data = trial_data.dropna() trial_data.sort_index(axis=1, inplace=True) return trial_data def resample(self, target_bin): - """Rebins spikes and performs antialiasing + downsampling on + """Rebins spikes and performs antialiasing + downsampling on continuous signals Parameters ---------- target_bin : int - The target bin size in milliseconds. 
Note that it must be an + The target bin size in milliseconds. Note that it must be an integer multiple of self.bin_width """ - logger.info(f'Resampling data to {target_bin} ms.') + logger.info(f"Resampling data to {target_bin} ms.") # Check that resample_factor is an integer if target_bin == self.bin_width: - logger.warning(f'Dataset already at {target_bin} ms resolution, skipping resampling...') + logger.warning( + f"Dataset already at {target_bin} ms resolution, skipping resampling..." + ) return - assert target_bin % self.bin_width == 0, \ - 'target_bin must be an integer multiple of bin_width.' + assert ( + target_bin % self.bin_width == 0 + ), "target_bin must be an integer multiple of bin_width." resample_factor = int(round(target_bin / self.bin_width)) - + # Resample data based on signal type cols = self.data.columns data_list = [] for signal_type in cols.get_level_values(0).unique(): - if 'spikes' in signal_type: + if "spikes" in signal_type: # Rebin spikes, preserving original nan locations arr = self.data[signal_type].to_numpy() dtype = self.data[signal_type].dtypes.iloc[0] nan_mask = np.isnan(arr[::resample_factor]) if arr.shape[0] % resample_factor != 0: - extra = arr[-(arr.shape[0] % resample_factor):] - arr = arr[:-(arr.shape[0] % resample_factor)] + extra = arr[-(arr.shape[0] % resample_factor) :] + arr = arr[: -(arr.shape[0] % resample_factor)] else: extra = None - arr = np.nan_to_num(arr, copy=False).reshape((arr.shape[0] // resample_factor, resample_factor, -1)).sum(axis=1) + arr = ( + np.nan_to_num(arr, copy=False) + .reshape((arr.shape[0] // resample_factor, resample_factor, -1)) + .sum(axis=1) + ) if extra is not None: arr = np.vstack([arr, np.nan_to_num(extra, copy=False).sum(axis=0)]) arr[nan_mask] = np.nan - resamp = pd.DataFrame(arr, index=self.data.index[::resample_factor], dtype=dtype) - elif signal_type == 'target_pos': + resamp = pd.DataFrame( + arr, index=self.data.index[::resample_factor], dtype=dtype + ) + elif signal_type == "target_pos": # Resample target pos for MC_RTT resamp = self.data[signal_type].iloc[::resample_factor] else: @@ -503,42 +600,52 @@ def resample(self, target_bin): dtype = self.data[signal_type].dtypes.iloc[0] nan_mask = self.data[signal_type].iloc[::resample_factor].isna() if np.any(self.data[signal_type].isna()): - self.data[signal_type] = self.data[signal_type].apply(lambda x: x.interpolate(limit_direction='both')) + self.data[signal_type] = self.data[signal_type].apply( + lambda x: x.interpolate(limit_direction="both") + ) decimated_df = signal.decimate( - self.data[signal_type], resample_factor, axis=0, n=500, ftype='fir') + self.data[signal_type], resample_factor, axis=0, n=500, ftype="fir" + ) decimated_df[nan_mask] = np.nan - resamp = pd.DataFrame(decimated_df, index=self.data.index[::resample_factor], dtype=dtype) - resamp.columns = pd.MultiIndex.from_product([[signal_type], self.data[signal_type].columns], names=('signal_type', 'channel')) + resamp = pd.DataFrame( + decimated_df, index=self.data.index[::resample_factor], dtype=dtype + ) + resamp.columns = pd.MultiIndex.from_product( + [[signal_type], self.data[signal_type].columns], + names=("signal_type", "channel"), + ) data_list.append(resamp) # Replace old data self.data = pd.concat(data_list, axis=1) - self.data.index.freq = f'{target_bin}ms' + self.data.index.freq = f"{target_bin}ms" self.bin_width = target_bin - def smooth_spk(self, - gauss_width, - signal_type=None, - name=None, - overwrite=False, - ignore_nans=False, - parallelized=True, - dtype="float64"): + def 
smooth_spk( + self, + gauss_width, + signal_type=None, + name=None, + overwrite=False, + ignore_nans=False, + parallelized=True, + dtype="float64", + ): """Applies Gaussian smoothing to the data. Most often applied to spikes - + Parameters ---------- gauss_width : int The standard deviation of the Gaussian to use for smoothing, in ms signal_type : str or list of str, optional - The group of signals to smooth, by default + The group of signals to smooth, by default None, which smooths 'spikes' and 'heldout_spikes' if present in the DataFrame name : str, optional - The name to use for the smoothed data when adding + The name to use for the smoothed data when adding it back to the DataFrame, by default None. If - provided, the new signal_type name will be + provided, the new signal_type name will be the original name + '_' + `name`. Must be provided if overwrite is False overwrite : bool, optional @@ -550,7 +657,7 @@ def smooth_spk(self, into valid data during convolution, but ignoring NaNs is much slower parallelized : bool, optional - Whether to parallelize the smoothing operation + Whether to parallelize the smoothing operation with multiprocessing.Pool.map(). This may cause issues on certain systems, so it can be disabled dtype : str or dtype @@ -559,15 +666,19 @@ def smooth_spk(self, By default 'float64'. Only other float dtypes are recommended """ - assert name or overwrite, \ - ('You must either provide a name for the smoothed ' - 'data or specify to overwrite the existing data.') - + assert name or overwrite, ( + "You must either provide a name for the smoothed " + "data or specify to overwrite the existing data." + ) + if signal_type is None: - signal_type = [field for field in ['spikes', 'heldout_spikes'] if field in self.data.columns] + signal_type = [ + field + for field in ["spikes", "heldout_spikes"] + if field in self.data.columns + ] - logger.info(f'Smoothing {signal_type} with a ' - f'{gauss_width} ms Gaussian.') + logger.info(f"Smoothing {signal_type} with a " f"{gauss_width} ms Gaussian.") # Compute Gauss window and std with respect to bins gauss_bin_std = gauss_width / self.bin_width @@ -575,18 +686,22 @@ def smooth_spk(self, win_len = int(6 * gauss_bin_std) # Create Gaussian kernel window = signal.gaussian(win_len, gauss_bin_std, sym=True) - window /= np.sum(window) + window /= np.sum(window) # Extract spiking data spike_vals = self.data[signal_type].to_numpy() - + # Parallelized implementation for smoothing data if parallelized: - spike_vals_list = [spike_vals[:,i] for i in range(spike_vals.shape[1])] + spike_vals_list = [spike_vals[:, i] for i in range(spike_vals.shape[1])] y_list = _poolmap( - smooth_column, itertools.product(spike_vals_list, [window], [ignore_nans], [dtype])) + smooth_column, + itertools.product(spike_vals_list, [window], [ignore_nans], [dtype]), + ) smoothed_spikes = np.vstack(y_list).T else: - smoothed_spikes = np.apply_along_axis(lambda x: smooth_column((x, window, ignore_nans, dtype)), 0, spike_vals) + smoothed_spikes = np.apply_along_axis( + lambda x: smooth_column((x, window, ignore_nans, dtype)), 0, spike_vals + ) # Create list of column names col_names = [] @@ -597,13 +712,17 @@ def smooth_spk(self, if overwrite: smoothed_name = st else: - smoothed_name = st + '_' + name - col_names += list(zip([smoothed_name]*len(columns), columns)) + smoothed_name = st + "_" + name + col_names += list(zip([smoothed_name] * len(columns), columns)) # Write data to DataFrame if overwrite: self.data.drop(col_names, axis=1, inplace=True) - smoothed_df = 
pd.DataFrame(smoothed_spikes, index=self.data.index, columns=pd.MultiIndex.from_tuples(col_names)) + smoothed_df = pd.DataFrame( + smoothed_spikes, + index=self.data.index, + columns=pd.MultiIndex.from_tuples(col_names), + ) self.data = pd.concat([self.data, smoothed_df], axis=1) self.data.sort_index(axis=1, inplace=True) # deleting and concatenating new data is much faster than overwriting, but less memory efficient @@ -618,28 +737,27 @@ def smooth_spk(self, def add_continuous_data(self, cts_data, signal_type, chan_names=None): """Adds a continuous data field to the main DataFrame - + Parameters ---------- cts_data : np.ndarray - A numpy array whose first dimension matches the DataFrame + A numpy array whose first dimension matches the DataFrame at self.data signal_name : str The label for this group of signals chan_names : list of str, optional The channel names for this data """ - logger.info(f'Adding continuous {signal_type} to the main DataFrame.') + logger.info(f"Adding continuous {signal_type} to the main DataFrame.") # Make MultiIndex columns midx = self._make_midx(signal_type, chan_names, cts_data.shape[1]) # Build the DataFrame and attach it to the current dataframe - new_data = pd.DataFrame( - cts_data, index=self.data.index, columns=midx) + new_data = pd.DataFrame(cts_data, index=self.data.index, columns=midx) self.data = pd.concat([self.data, new_data], axis=1) def add_trialized_data(self, trial_data, signal_type, chan_names=None): """Adds a trialized data field to the main DataFrame - + Parameters ---------- trial_data : pd.DataFrame @@ -650,13 +768,13 @@ def add_trialized_data(self, trial_data, signal_type, chan_names=None): chan_names : list of str, optional The channel names for the data when added """ - logger.info(f'Adding trialized {signal_type} to the main DataFrame') - new_data = trial_data[['clock_time', signal_type]].set_index('clock_time') + logger.info(f"Adding trialized {signal_type} to the main DataFrame") + new_data = trial_data[["clock_time", signal_type]].set_index("clock_time") self.data = pd.concat([self.data, new_data], axis=1) - + def _make_midx(self, signal_type, chan_names=None, num_channels=None): """Creates a pd.MultiIndex for a given signal_type - + Parameters ---------- signal_type : str @@ -671,26 +789,31 @@ def _make_midx(self, signal_type, chan_names=None, num_channels=None): `chan_names` is not provided """ if chan_names is None: - if 'rates' in signal_type: + if "rates" in signal_type: # If merging rates, use the same names as the spikes chan_names = self.data.spikes.columns else: # Otherwise, generate names for the channels - assert num_channels is not None, "`num_channels` must be provided if `chan_names` is not provided" - chan_names = [f'{i:04d}' for i in range(num_channels)] + assert ( + num_channels is not None + ), "`num_channels` must be provided if `chan_names` is not provided" + chan_names = [f"{i:04d}" for i in range(num_channels)] # Create the MultiIndex for this data midx = pd.MultiIndex.from_product( - [[signal_type], chan_names], names=('signal_type', 'channel')) + [[signal_type], chan_names], names=("signal_type", "channel") + ) return midx - def calculate_onset(self, - field_name, - onset_threshold, - peak_prominence=0.1, - peak_distance_s=0.1, - multipeak_threshold=0.2): - """Calculates onset for a given field by finding - peaks and threshold crossings. 
Developed for + def calculate_onset( + self, + field_name, + onset_threshold, + peak_prominence=0.1, + peak_distance_s=0.1, + multipeak_threshold=0.2, + ): + """Calculates onset for a given field by finding + peaks and threshold crossings. Developed for speed onset calculation Parameters @@ -699,17 +822,17 @@ def calculate_onset(self, The field to use for onset calculation, used with recursive getattr on self.data onset_threshold : float - The threshold for onset as a percentage of the + The threshold for onset as a percentage of the peak height peak_prominence : float, optional - Minimum prominence of peaks. Passed to + Minimum prominence of peaks. Passed to `scipy.signal.find_peaks`, by default 0.1 peak_distance_s : float, optional - Minimum distance between peaks. Passed to + Minimum distance between peaks. Passed to `scipy.signal.find_peaks`, by default 0.1 multipeak_threshold : float, optional - Subsequent peaks within a trial must be no - larger than this percentage of the first peak, + Subsequent peaks within a trial must be no + larger than this percentage of the first peak, otherwise the onset calculation fails, by default 0.2 Returns @@ -719,62 +842,69 @@ def calculate_onset(self, """ import functools + def rgetattr(obj, attr, *args): - """A recursive drop-in replacement for getattr, + """A recursive drop-in replacement for getattr, which also handles dotted attr strings """ + def _getattr(obj, attr): return getattr(obj, attr, *args) - return functools.reduce(_getattr, [obj] + attr.split('.')) - - logger.info(f'Calculating {field_name} onset.') + + return functools.reduce(_getattr, [obj] + attr.split(".")) + + logger.info(f"Calculating {field_name} onset.") sig = rgetattr(self.data, field_name) # Find peaks peaks, properties = signal.find_peaks( sig, - prominence=peak_prominence, - distance=peak_distance_s / (self.bin_width / 1000.0)) + prominence=peak_prominence, + distance=peak_distance_s / (self.bin_width / 1000.0), + ) peak_times = pd.Series(self.data.index[peaks]) # Find the onset for each trial onset, onset_index = [], [] for index, row in self.trial_info.iterrows(): - trial_start, trial_end = row['start_time'], row['end_time'] + trial_start, trial_end = row["start_time"], row["end_time"] # Find the peaks within the trial boundaries trial_peaks = peak_times[ - (trial_start < peak_times) & (peak_times < trial_end)] + (trial_start < peak_times) & (peak_times < trial_end) + ] peak_signals = sig.loc[trial_peaks] # Remove trials with multiple larger peaks if multipeak_threshold is not None and len(trial_peaks) > 1: # Make sure the first peak is relatively large - if peak_signals[0]*multipeak_threshold < peak_signals[1:].max(): + if peak_signals[0] * multipeak_threshold < peak_signals[1:].max(): continue elif len(trial_peaks) == 0: # If no peaks are found for this trial, skip it continue # Find the point just before speed crosses the threshold signal_threshold = onset_threshold * peak_signals[0] - under_threshold = sig[trial_start:trial_peaks.iloc[0]] < signal_threshold + under_threshold = sig[trial_start : trial_peaks.iloc[0]] < signal_threshold if under_threshold.sum() > 0: onset.append(under_threshold[::-1].idxmax()) onset_index.append(index) # Add the movement onset for each trial to the DataFrame - onset_name = field_name.split('.')[-1] + '_onset' - logger.info(f'`{onset_name}` field created in trial_info.') + onset_name = field_name.split(".")[-1] + "_onset" + logger.info(f"`{onset_name}` field created in trial_info.") self.trial_info[onset_name] = pd.Series(onset, 
index=onset_index) - + return peak_times -''' Multiprocessing Functions ''' +""" Multiprocessing Functions """ + + def smooth_column(args): """Low-level helper function for smoothing single column - + Parameters ---------- args : tuple - Tuple containing data to smooth in 1d array, - smoothing kernel in 1d array, whether to + Tuple containing data to smooth in 1d array, + smoothing kernel in 1d array, whether to ignore nans, and data dtype Returns @@ -792,18 +922,25 @@ def smooth_column(args): def rectify(arr): arr[arr < 0] = 0 return arr + # smooth only the not-NaN data - seqs = [seq if np.any(np.isnan(seq)) else rectify(signal.convolve(seq, window, 'same')) for seq in seqs] + seqs = [ + seq + if np.any(np.isnan(seq)) + else rectify(signal.convolve(seq, window, "same")) + for seq in seqs + ] # concatenate to single array y = np.concatenate(seqs) else: - y = signal.convolve(x.astype(dtype), window, 'same') + y = signal.convolve(x.astype(dtype), window, "same") return y + def _poolmap(f, X, nprocs=multiprocessing.cpu_count(), chunksize=None): """Wrapper for multiprocessing.Pool.map() to parallelize execution of function f on elements of iterable X - + Parameters ---------- f : function @@ -814,8 +951,8 @@ def _poolmap(f, X, nprocs=multiprocessing.cpu_count(), chunksize=None): Maximum number of parallel processes, by default the number of CPUs chunksize : int, optional - Chunk size to fetch from iterable for - each process. Refer to multiprocessing + Chunk size to fetch from iterable for + each process. Refer to multiprocessing documentation for more information """ with multiprocessing.Pool(processes=nprocs) as pool: diff --git a/pyproject.toml b/pyproject.toml index 0b5c28c..4aff98b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,39 +1,3 @@ -[tool.poetry] -name = "nlb_tools" -version = "0.0.2" -license = "MIT" -description = "Python tools for participating in Neural Latents Benchmark '21" -authors = [ - "Felix Pei " -] -packages = [ - {include = "nlb_tools"} -] - -readme = "README.md" -homepage = "https://github.com/neurallatents/nlb_tools" -classifiers = [ - "Intended Audience :: Science/Research", - "Operating System :: Microsoft :: Windows", - "Operating System :: MacOS", - "Operating System :: Unix", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", -] - -[tool.poetry.dependencies] -python = "^3.7" -pandas = ">=1.0.0,<= 1.3.4" -scipy = "*" -numpy = "*" -scikit-learn = "*" -h5py = ">=2.9,<4" -pynwb = "*" - -[tool.poetry.dev-dependencies] -pytest = "*" -dandi = "*" - [build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" \ No newline at end of file +requires = ["setuptools>=61.0.0", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..228a0e5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +pandas>=1.0.0,<=1.3.4 +scipy +numpy +scikit-learn +h5py<4,>=2.9 +pynwb diff --git a/setup.py b/setup.py index 5dca5d7..96e984e 100644 --- a/setup.py +++ b/setup.py @@ -1,16 +1,35 @@ from setuptools import setup, find_packages +with open("README.md") as f: + long_description = f.read() + +with open("requirements.txt") as f: + requirements = f.readlines() + setup( - name='nlb_tools', - version='0.0.2', + name="nlb_tools", + version="0.0.3", + description="Python tools for participating in Neural Latents Benchmark '21", packages=find_packages(), - install_requires=[ - 'pandas>=1.0.0,<=1.3.4', - 'scipy', - 'numpy', - 'scikit-learn', - 
'h5py<4,>=2.9', - 'pynwb', + install_requires=requirements, + author="Felix Pei", + classifiers=[ + "Intended Audience :: Science/Research", + "Operating System :: Microsoft :: Windows", + "Operating System :: MacOS", + "Operating System :: Unix", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], - author="Neural Latents", + extras_require={ + "dev": ["pytest", "dandi"], + }, + license="MIT", + long_description=long_description, + long_description_content_type="text/markdown", + python_requires=">=3.7", + setup_requires=["setuptools>=61.0.0", "wheel"], + url="https://github.com/neurallatents/nlb_tools", ) diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 026c634..47daae4 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -1 +1,179 @@ -# TODO: make tests of main evaluation functions \ No newline at end of file +import pytest +import numpy as np +from scipy.special import gammaln + +from nlb_tools.evaluation import ( + evaluate, + neg_log_likelihood, + bits_per_spike, + fit_and_eval_decoder, + eval_psth, + speed_tp_correlation, + velocity_decoding, +) + + +# -- NLL and bits/spike ---------- + + +def test_neg_log_likelihood(): + """Test that NLL computation is correct""" + # randomized test + for _ in range(20): + spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float) + rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) + + expected_nll = np.sum(rates - spikes * np.log(rates) + gammaln(spikes + 1.0)) + actual_nll = neg_log_likelihood(rates, spikes) + assert np.isclose(expected_nll, actual_nll) + + +def test_neg_log_likelihood_mismatched_shapes(): + """Test that NLL computation fails when shapes don't match""" + # randomized test + spikes = np.random.randint(low=0, high=5, size=(10, 100, 8)).astype(float) + rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) + + with pytest.raises(AssertionError): + neg_log_likelihood(rates, spikes) + + +def test_neg_log_likelihood_negative_rates(): + """Test that NLL computation fails when rates are negative""" + # randomized test + spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float) + rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) + rates -= np.min(rates) + 5 # guarantee negative rates + + with pytest.raises(AssertionError): + neg_log_likelihood(rates, spikes) + + +def test_neg_log_likelihood_drop_nans(): + """Test that NLL computation is correct when there are nans in either rates or spikes""" + # randomized test + for _ in range(20): + spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float) + rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) + mask = np.random.rand(10, 100, 10) > 0.9 + spikes[mask] = np.nan + if np.random.rand() > 0.5: # rates does not have to have nans + rates[mask] = np.nan + + expected_nll = np.sum( + rates[~mask] + - spikes[~mask] * np.log(rates[~mask]) + + gammaln(spikes[~mask] + 1.0) + ) + actual_nll = neg_log_likelihood(rates, spikes) + assert np.isclose(expected_nll, actual_nll) + + +def test_neg_log_likelihood_mismatched_nans(): + """Test that NLL computation fails when rates and spikes have NaNs in different locations""" + # randomized test + spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float) + rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) + mask = np.random.rand(10, 100, 10) + # make sure spikes and rates have different nans + spikes[mask < 0.1] = np.nan + rates[mask > 0.9] =
np.nan + + with pytest.raises(AssertionError): + neg_log_likelihood(rates, spikes) + + +def test_bits_per_spike(): + for _ in range(20): + spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float) + rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) + null_rates = np.tile( + spikes.mean(axis=(0, 1), keepdims=True), + (spikes.shape[0], spikes.shape[1], 1), + ).squeeze() + + expected_rate_nll = np.sum( + rates - spikes * np.log(rates) + gammaln(spikes + 1.0) + ) + expected_null_nll = np.sum( + null_rates - spikes * np.log(null_rates) + gammaln(spikes + 1.0) + ) + expected_bps = ( + (expected_null_nll - expected_rate_nll) / np.sum(spikes) / np.log(2) + ) + actual_bps = bits_per_spike(rates, spikes) + assert np.isclose(expected_bps, actual_bps) + + +def test_bits_per_spike_drop_nans(): + for _ in range(20): + spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float) + rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) + mask = np.random.rand(10, 100, 10) > 0.9 + spikes[mask] = np.nan + if np.random.rand() > 0.5: # rates does not have to have nans + rates[mask] = np.nan + null_rates = np.tile( + np.nanmean(spikes, axis=(0, 1), keepdims=True), + (spikes.shape[0], spikes.shape[1], 1), + ).squeeze() + + expected_rate_nll = np.sum( + rates[~mask] + - spikes[~mask] * np.log(rates[~mask]) + + gammaln(spikes[~mask] + 1.0) + ) + expected_null_nll = np.sum( + null_rates[~mask] + - spikes[~mask] * np.log(null_rates[~mask]) + + gammaln(spikes[~mask] + 1.0) + ) + expected_bps = ( + (expected_null_nll - expected_rate_nll) / np.nansum(spikes) / np.log(2) + ) + actual_bps = bits_per_spike(rates, spikes) + assert np.isclose(expected_bps, actual_bps) + + +# -- Ridge regression --------------- + + +def test_fit_and_eval_decoder(): + rng = np.random.default_rng(0) + x = rng.standard_normal(size=(1000, 10)) + y = x @ rng.standard_normal(size=(10, 2)) + + # noiseless should have high R^2 + score = fit_and_eval_decoder( + train_rates=x[:800], + train_behavior=y[:800], + eval_rates=x[800:], + eval_behavior=y[800:], + ) + assert score > 0.95 + + # with noise should still have decent R^2 + y += rng.standard_normal(size=(1000, 2)) * 0.1 + score = fit_and_eval_decoder( + train_rates=x[:800], + train_behavior=y[:800], + eval_rates=x[800:], + eval_behavior=y[800:], + ) + assert score > 0.25 # arbitrary heuristic + + # regressing on noise should have poor R^2 + y = rng.standard_normal(size=(1000, 2)) + score = fit_and_eval_decoder( + train_rates=x[:800], + train_behavior=y[:800], + eval_rates=x[800:], + eval_behavior=y[800:], + ) + assert score < 0.95 # arbitrary heuristic + + +# -- PSTH evaluation + +# def test_eval_psth(): +# return
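
The diff above leaves `test_eval_psth` as a commented-out stub. A minimal sketch of such a test follows as a possible starting point; it is not part of the diff, the test name and toy dimensions are hypothetical, and it assumes (per the `eval_psth` docstring) that an R² near 1.0 is returned when every trial's predicted rates exactly reproduce its condition's PSTH.

import numpy as np
from nlb_tools.evaluation import eval_psth


def test_eval_psth_perfect_match():
    # hypothetical toy problem: 3 conditions, 4 trials each, 50 time bins, 7 neurons
    n_cond, n_trials, T, N = 3, 4, 50, 7
    t = np.linspace(0, 2 * np.pi, T)
    # PSTHs with distinct temporal structure for each condition and neuron
    psth = np.stack(
        [
            np.stack([np.sin(t + c + n) + 2.0 for n in range(N)], axis=-1)
            for c in range(n_cond)
        ]
    )  # shape (n_cond, T, N)
    # each trial's predicted rates exactly reproduce its condition's PSTH
    eval_rates = np.concatenate(
        [np.tile(psth[c][None], (n_trials, 1, 1)) for c in range(n_cond)]
    )  # shape (n_cond * n_trials, T, N)
    eval_cond_idx = [
        np.arange(c * n_trials, (c + 1) * n_trials) for c in range(n_cond)
    ]

    psth_r2 = eval_psth(psth, eval_rates, eval_cond_idx)
    # trial-averaged predictions identical to the true PSTHs should score ~1
    assert np.isclose(psth_r2, 1.0)

Leaving `jitter` unset exercises the default path, where `eval_psth` applies zero jitter and the per-condition trial average equals the PSTH passed in.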