From 4ef34c9610b50b66f8110e5fae3e144ec1eb8979 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 4 Nov 2024 14:04:35 +0000 Subject: [PATCH 01/34] Added bonsai environment --- .../.bonsai/Bonsai.config | 103 ++++++++++++++++++ .../.bonsai/NuGet.config | 8 ++ .../.bonsai/Setup.ps1 | 21 ++++ .../.bonsai/Setup.sh | 41 +++++++ .../.bonsai/activate | 15 +++ .../.bonsai/deactivate | 8 ++ .../.bonsai/run | 58 ++++++++++ 7 files changed, 254 insertions(+) create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Bonsai.config create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/NuGet.config create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.ps1 create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.sh create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/activate create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/deactivate create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/run diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Bonsai.config b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Bonsai.config new file mode 100644 index 0000000..d09f619 --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Bonsai.config @@ -0,0 +1,103 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/NuGet.config b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/NuGet.config new file mode 100644 index 0000000..97e8b73 --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/NuGet.config @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.ps1 b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.ps1 new file mode 100644 index 0000000..76b5c46 --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.ps1 @@ -0,0 +1,21 @@ +Push-Location $PSScriptRoot +if (!(Test-Path "./Bonsai.exe")) { + $release = "https://github.com/bonsai-rx/bonsai/releases/latest/download/Bonsai.zip" + $configPath = "./Bonsai.config" + if (Test-Path $configPath) { + [xml]$config = Get-Content $configPath + $bootstrapper = $config.PackageConfiguration.Packages.Package.where{$_.id -eq 'Bonsai'} + if ($bootstrapper) { + $version = $bootstrapper.version + $release = "https://github.com/bonsai-rx/bonsai/releases/download/$version/Bonsai.zip" + } + } + Invoke-WebRequest $release -OutFile "temp.zip" + Move-Item -Path "NuGet.config" "temp.config" -ErrorAction SilentlyContinue + Expand-Archive "temp.zip" -DestinationPath "." 
-Force + Move-Item -Path "temp.config" "NuGet.config" -Force -ErrorAction SilentlyContinue + Remove-Item -Path "temp.zip" + Remove-Item -Path "Bonsai32.exe" +} +& .\Bonsai.exe --no-editor +Pop-Location \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.sh b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.sh new file mode 100644 index 0000000..941d850 --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.sh @@ -0,0 +1,41 @@ +#! /bin/bash + +SETUP_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)" + +DEFAULT_VERSION="latest" +VERSION="$DEFAULT_VERSION" + +while [[ "$#" -gt 0 ]]; do + case $1 in + --version) VERSION="$2"; shift ;; + *) echo "Unknown parameter passed: $1"; exit 1 ;; + esac + shift +done + +echo "Setting up Bonsai v=$VERSION environment..." + +if [ ! -f "$SETUP_SCRIPT_DIR/Bonsai.exe" ]; then + CONFIG="$SETUP_SCRIPT_DIR/Bonsai.config" + if [ -f "$CONFIG" ]; then + DETECTED=$(xmllint --xpath '//PackageConfiguration/Packages/Package[@id="Bonsai"]/@version' "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//') + echo "Version detected v=$DETECTED." + RELEASE="https://github.com/bonsai-rx/bonsai/releases/download/$DETECTED/Bonsai.zip" + else + if [ $VERSION = "latest" ]; then + RELEASE="https://github.com/bonsai-rx/bonsai/releases/latest/download/Bonsai.zip" + else + RELEASE="https://github.com/bonsai-rx/bonsai/releases/download/$VERSION/Bonsai.zip" + fi + fi + echo "Download URL: $RELEASE" + wget $RELEASE -O "$SETUP_SCRIPT_DIR/temp.zip" + mv -f "$SETUP_SCRIPT_DIR/NuGet.config" "$SETUP_SCRIPT_DIR/temp.config" + unzip -d "$SETUP_SCRIPT_DIR" -o "$SETUP_SCRIPT_DIR/temp.zip" + mv -f "$SETUP_SCRIPT_DIR/temp.config" "$SETUP_SCRIPT_DIR/NuGet.config" + rm -rf "$SETUP_SCRIPT_DIR/temp.zip" + rm -rf "$SETUP_SCRIPT_DIR/Bonsai32.exe" +fi + +source "$SETUP_SCRIPT_DIR/activate" +source "$SETUP_SCRIPT_DIR/run" --no-editor diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/activate b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/activate new file mode 100644 index 0000000..ddf75f3 --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/activate @@ -0,0 +1,15 @@ +#!/bin/bash +# activate.sh +if [[ -v BONSAI_EXE_PATH ]]; then + echo "Error! Cannot have multiple bonsai environments activated at the same time. Please deactivate the current environment before activating the new one." 
+ return +fi +BONSAI_ENV_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)" +export BONSAI_ENV_DIR +export BONSAI_EXE_PATH="$BONSAI_ENV_DIR/Bonsai.exe" +export ORIGINAL_PS1="$PS1" +export PS1="($(basename "$BONSAI_ENV_DIR")) $PS1" +alias bonsai='source "$BONSAI_ENV_DIR"/run' +alias bonsai-clean='GTK_DATA_PREFIX= source "$BONSAI_ENV_DIR"/run' +alias deactivate='source "$BONSAI_ENV_DIR"/deactivate' +echo "Activated bonsai environment in $BONSAI_ENV_DIR" \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/deactivate b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/deactivate new file mode 100644 index 0000000..43233d9 --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/deactivate @@ -0,0 +1,8 @@ +#!/bin/bash +unset BONSAI_EXE_PATH +export PS1="$ORIGINAL_PS1" +unset ORIGINAL_PS1 +unalias bonsai +unalias bonsai-clean +unalias deactivate +echo "Deactivated bonsai environment." \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/run b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/run new file mode 100644 index 0000000..bffd6cf --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/run @@ -0,0 +1,58 @@ +#!/bin/bash +# run.sh + +SETUP_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)" +CONFIG="$SETUP_SCRIPT_DIR/Bonsai.config" + +cleanup() { + update_paths_to_windows +} + +update_paths_to_linux() { + ASSEMBLYLOCATIONS=$(xmllint --xpath '//PackageConfiguration/AssemblyLocations/AssemblyLocation/@location' "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//') + for ASSEMBLYLOCATION in $ASSEMBLYLOCATIONS; do + NEWASSEMBLYLOCATION="${ASSEMBLYLOCATION//\\/\/}" + xmlstarlet edit --inplace --update "/PackageConfiguration/AssemblyLocations/AssemblyLocation[@location='$ASSEMBLYLOCATION']/@location" --value "$NEWASSEMBLYLOCATION" "$CONFIG" + done + + LIBRARYFOLDERS=$(xmllint --xpath '//PackageConfiguration/LibraryFolders/LibraryFolder/@path' "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//') + for LIBRARYFOLDER in $LIBRARYFOLDERS; do + NEWLIBRARYFOLDER="${LIBRARYFOLDER//\\/\/}" + xmlstarlet edit --inplace --update "//PackageConfiguration/LibraryFolders/LibraryFolder[@path='$LIBRARYFOLDER']/@path" --value "$NEWLIBRARYFOLDER" "$CONFIG" + done +} + +update_paths_to_windows() { + ASSEMBLYLOCATIONS=$(xmllint --xpath '//PackageConfiguration/AssemblyLocations/AssemblyLocation/@location' "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//') + for ASSEMBLYLOCATION in $ASSEMBLYLOCATIONS; do + NEWASSEMBLYLOCATION="${ASSEMBLYLOCATION//\//\\}" + xmlstarlet edit --inplace --update "/PackageConfiguration/AssemblyLocations/AssemblyLocation[@location='$ASSEMBLYLOCATION']/@location" --value "$NEWASSEMBLYLOCATION" "$CONFIG" + done + + LIBRARYFOLDERS=$(xmllint --xpath '//PackageConfiguration/LibraryFolders/LibraryFolder/@path' "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//') + for LIBRARYFOLDER in $LIBRARYFOLDERS; do + NEWLIBRARYFOLDER="${LIBRARYFOLDER//\//\\}" + xmlstarlet edit --inplace --update "//PackageConfiguration/LibraryFolders/LibraryFolder[@path='$LIBRARYFOLDER']/@path" --value "$NEWLIBRARYFOLDER" "$CONFIG" + done +} + +if [[ -v BONSAI_EXE_PATH ]]; then + if [ ! 
-f "$BONSAI_EXE_PATH" ]; then + bash "$BONSAI_ENV_DIR"/Setup.sh + bash "$BONSAI_ENV_DIR"/run "$@" + else + BONSAI_VERSION=$(xmllint --xpath "//PackageConfiguration/Packages/Package[@id='Bonsai']/@version" "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//') + if [[ -z ${BONSAI_VERSION+x} ]] && [ "$BONSAI_VERSION" \< "2.8.4" ]; then + echo "Updating paths to Linux format..." + trap cleanup EXIT INT TERM + update_paths_to_linux + mono "$BONSAI_EXE_PATH" "$@" + cleanup + else + mono "$BONSAI_EXE_PATH" "$@" + fi + fi +else + echo "BONSAI_EXE_PATH is not set. Please set the path to the Bonsai executable." + return +fi \ No newline at end of file From dd06f6d0f55bef7f0b59ed9999bd857e93a9e90c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 4 Nov 2024 14:32:04 +0000 Subject: [PATCH 02/34] Added decoder python library --- .../decoder/__init__.py | 4 + .../decoder/core.py | 79 +++++++++++++++++++ .../decoder/data_iterator.py | 33 ++++++++ .../decoder/data_loader.py | 73 +++++++++++++++++ .../decoder/model_loader.py | 13 +++ 5 files changed, 202 insertions(+) create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/core.py create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_iterator.py create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_loader.py create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py new file mode 100644 index 0000000..cca164a --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py @@ -0,0 +1,4 @@ +from .core import SortedSpikeDecoder +from .data_loader import DataLoader +from .data_iterator import DataIterator +from .model_loader import ModelLoader \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/core.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/core.py new file mode 100644 index 0000000..8ffc78f --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/core.py @@ -0,0 +1,79 @@ +import replay_trajectory_classification as rtc +from replay_trajectory_classification.core import scaled_likelihood, get_centers +from replay_trajectory_classification.likelihoods import _SORTED_SPIKES_ALGORITHMS +# from likelihoods import _SORTED_SPIKES_ALGORITHMS +from replay_trajectory_classification.likelihoods.spiking_likelihood_kde import combined_likelihood, poisson_log_likelihood + +import numpy as np + +class SortedSpikeDecoder: + def __init__(self, model_dict: dict): + self.model = model_dict["decoder"] + self.Fs = model_dict["Fs"] + self.spikes = model_dict["binned_spikes_times"] + self.is_track_interior = self.model.environment.is_track_interior_.ravel(order="F") + self.st_interior_ind = np.ix_(self.is_track_interior, self.is_track_interior) + self.n_position_bins = self.is_track_interior.shape[0] + + self.initial_conditions = self.model.initial_conditions_[self.is_track_interior].astype(float) + self.state_transition = self.model.state_transition_[self.st_interior_ind].astype(float) + self.place_fields = np.asarray(self.model.place_fields_) + self.position_centers = 
get_centers(self.model.environment.edges_[0]), + + self.posterior = None + super().__init__() + + def decode_spikes( + self, + spikes: np.ndarray + ): + # likelihood = scaled_likelihood(_SORTED_SPIKES_ALGORITHMS[self.model.sorted_spikes_algorithm][1](spikes[np.newaxis], self.place_fields)) + # likelihood = likelihood[:, self.is_track_interior].astype(float) + + conditional_intensity = np.clip(self.place_fields, a_min=1e-15, a_max=None) + + log_likelihood = 0 + for spike, ci in zip(spikes, conditional_intensity.T): + log_likelihood += poisson_log_likelihood(spike[np.newaxis], ci) + + mask = np.ones_like(self.is_track_interior, dtype=float) + mask[~self.is_track_interior] = np.nan + + likelihood = scaled_likelihood(log_likelihood * mask) + likelihood = likelihood[:, self.is_track_interior].astype(float) + + if self.posterior is None: + self.posterior = np.full((1, self.n_position_bins), np.nan, dtype=float) + self.posterior[0, self.is_track_interior] = self.initial_conditions * likelihood[0] + + else: + self.posterior[0, self.is_track_interior] = self.state_transition.T @ self.posterior[0, self.is_track_interior] * likelihood[0] + + norm = np.nansum(self.posterior[0]) + self.posterior[0] /= norm + + return self.posterior + + # def decode_spikes( + # self, + # data: tuple[list, list] + # ) -> tuple[np.ndarray, float]: + + # likelihood = scaled_likelihood(_SORTED_SPIKES_ALGORITHMS[model.sorted_spikes_algorithm][1](spikes, np.asarray(model.place_fields_))) + # likelihood = likelihood[:, is_track_interior].astype(float) + + # n_time = likelihood.shape[0] + # posterior = np.zeros_like(likelihood) + + # posterior[0] = initial_conditions.copy() * likelihood[0] + # norm = np.nansum(posterior[0]) + # log_data_likelihood = np.log(norm) + # posterior[0] /= norm + + # for k in np.arange(1, n_time): + # posterior[k] = state_transition.T @ posterior[k - 1] * likelihood[k] + # norm = np.nansum(posterior[k]) + # log_data_likelihood += np.log(norm) + # posterior[k] /= norm + + # return posterior, log_data_likelihood \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_iterator.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_iterator.py new file mode 100644 index 0000000..7bd19f9 --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_iterator.py @@ -0,0 +1,33 @@ +import numpy as np + +class DataIterator: + + def __init__(self, + data: dict, + start_index: int = 0): + self.data = data + self.position_bins = self.data["position_bins"] + self.index = start_index + super().__init__() + + def next(self, + loop: bool = True) -> tuple[list, list]: + + output = None + position_data = self.data["position_data"] + spike_times = self.data["spike_times"] + decoding_results = self.data["decoding_results"] + + if self.index > len(position_data) and loop: + self.index = 0 + + if self.index < len(position_data): + position = position_data[self.index] + spikes = spike_times[self.index] + decoding = decoding_results[self.index][np.newaxis] + + self.index += 1 + + output = (position, spikes, decoding, self.position_bins) + + return output \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_loader.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_loader.py new file mode 100644 index 0000000..a0d2ae6 --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_loader.py @@ -0,0 
+1,73 @@ +import pickle +import pandas as pd +import numpy as np +import track_linearization as tl +import replay_trajectory_classification as rtc +import os + +class DataLoader: + def __init__(self): + super().__init__() + + @classmethod + def load_data(cls, + dataset_path: str = "../../../datasets/decoder_data", + bin_spikes: bool = True) -> dict: + + if len([file for file in os.listdir(dataset_path) if file == "position_info.pkl" or file == "sorted_spike_times.pkl" or file == "decoding_results.pkl"]) != 3: + raise Exception("Dataset incorrect. Missing at least one of the following files: 'position_info.pkl', 'sorted_spike_times.pkl', 'decoding_results.pkl'") + + position_data = pd.read_pickle(os.path.join(dataset_path, "position_info.pkl")) + position_index = position_data.index.to_numpy() + position_index = np.insert(position_index, 0, position_index[0] - (position_index[1] - position_index[0])) + position_data = position_data[["nose_x", "nose_y"]].to_numpy() + + node_positions = [(120.0, 100.0), + ( 5.0, 100.0), + ( 5.0, 55.0), + (120.0, 55.0), + ( 5.0, 8.5), + (120.0, 8.5), + ] + edges = [ + (3, 2), + (0, 1), + (1, 2), + (5, 4), + (4, 2), + ] + track_graph = rtc.make_track_graph(node_positions, edges) + + edge_order = [ + (3, 2), + (0, 1), + (1, 2), + (5, 4), + (4, 2), + ] + + edge_spacing = [16, 0, 16, 0] + + linearized_positions = tl.get_linearized_position(position_data, track_graph, edge_order=edge_order, edge_spacing=edge_spacing, use_HMM=False) + position_data = linearized_positions.linear_position + + with open(os.path.join(dataset_path, "sorted_spike_times.pkl"), "rb") as f: + spike_times = pickle.load(f) + + if bin_spikes: + spike_mat = np.zeros((len(position_data), len(spike_times))) + for neuron in range(len(spike_times)): + spike_mat[:, neuron] = np.histogram(spike_times[neuron], position_index)[0] + spike_times = spike_mat + + with open(os.path.join(dataset_path, "decoding_results.pkl"), "rb") as f: + results = pickle.load(f)["decoding_results"] + position_bins = results.position.to_numpy() + decoding_results = results.acausal_posterior.to_numpy() + + return { + "position_data": position_data, + "spike_times": spike_times, + "decoding_results": decoding_results, + "position_bins": position_bins + } diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py new file mode 100644 index 0000000..54b94ac --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py @@ -0,0 +1,13 @@ +import pickle +from .core import SortedSpikeDecoder + +class ModelLoader: + def __init__(self): + super().__init__() + + @classmethod + def load_model(cls, + model_path: str = "../../../datasets/decoder_data/model.pkl") -> SortedSpikeDecoder: + with open(model_path, "rb") as f: + model = pickle.load(f) + return SortedSpikeDecoder(model) \ No newline at end of file From 32491dbb6da254caffcbeb546019348f42b6794d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 4 Nov 2024 14:32:23 +0000 Subject: [PATCH 03/34] Added workflow with extensions --- .../Extensions.csproj | 1 + .../Extensions/PositionBins.cs | 19 + .../Extensions/Posterior.cs | 42 +++ .../PositionDecoding.bonsai | 326 ++++++++++++++++++ 4 files changed, 388 insertions(+) create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions.csproj create mode 100644 
examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/PositionBins.cs create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/Posterior.cs create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/PositionDecoding.bonsai diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions.csproj b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions.csproj new file mode 100644 index 0000000..e99fc36 --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions.csproj @@ -0,0 +1 @@ +net472 \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/PositionBins.cs b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/PositionBins.cs new file mode 100644 index 0000000..f0e4a12 --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/PositionBins.cs @@ -0,0 +1,19 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using Python.Runtime; +using Bonsai.ML.Python; + +[Combinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class PositionBins +{ + public IObservable Process(IObservable source) + { + return source.Select(value => (double[])PythonHelper.ConvertPythonObjectToCSharp(value)); + } +} diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/Posterior.cs b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/Posterior.cs new file mode 100644 index 0000000..21364bd --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/Posterior.cs @@ -0,0 +1,42 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using Python.Runtime; +using Bonsai.ML.Python; + +[Combinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Posterior +{ + public IObservable Process(IObservable source) + { + return source.Select(value => { + return new PosteriorData(value); + }); + } +} + +public class PosteriorData +{ + public PosteriorData(PyObject posterior) + { + _data = (double[,])PythonHelper.ConvertPythonObjectToCSharp(posterior); + _mapEstimate = 0; + for (int i = 1; i < _data.GetLength(1); i++) + { + if (_data[0, i] > _data[0, _mapEstimate]) + { + _mapEstimate = i; + } + } + } + public double[,] Data => _data; + private double[,] _data; + + public int MapEstimate => _mapEstimate; + private int _mapEstimate; +} diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/PositionDecoding.bonsai b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/PositionDecoding.bonsai new file mode 100644 index 0000000..7d7eddd --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/PositionDecoding.bonsai @@ -0,0 +1,326 @@ + + + + + + + + + + + ImportDecoderLibrary + + + + Source1 + + + + import sys +import os +sys.path.append(os.getcwd()) +from decoder.data_loader import DataLoader +from decoder.data_iterator import DataIterator +from decoder.model_loader import ModelLoader + + + + + + + + + + + + ModelLoader + + + + + + + + model = ModelLoader.load_model() + + + + + + + + + + + + DataLoader + + + + + + + + data = DataLoader.load_data() + + + + + + + + + + + + + + + IterateData + + + + + + + + 
PT0S + PT0.1S + + + + + + + + iterator = DataIterator(data) + + + + + + + + + + + iterator.next() + + + + new(it[0] as Position, it[1] as Spikes, it[2] as Decoding, it[3] as PositionBins) + + + + + + + + + + + + + + + + + Data + + + Data + + + Spikes + + + + spikes + + + + + model.decode_spikes(spikes) + + + + + + + OnlineDecoderResults + + + Data + + + OnlinePosterior + + + Data + + + Position + + + Convert.ToDouble(it.ToString()) + + + Position + + + Data + + + Decoding + + + + + + OfflineDecoderResults + + + Data + + + OfflinePosterior + + + Data + + + PositionBins + + + + + + PositionBins + + + PositionBins + + + OnlineDecoderResults + + + MapEstimate + + + + + + + OnlinePositionEstimate + + + PositionBins + + + OfflineDecoderResults + + + MapEstimate + + + + + + + OfflinePositionEstimate + + + Visualizer + + + + OnlinePositionEstimate + + + + OfflinePositionEstimate + + + + Position + + + + OnlinePosterior + + + + OfflinePosterior + + + + true + true + 3 + 2 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file From 034fd526d11a4641aa0ffe30c340fd7fae9ba1e3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 5 Nov 2024 18:26:52 +0000 Subject: [PATCH 04/34] Added README --- .../README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/README.md diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/README.md b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/README.md new file mode 100644 index 0000000..843caed --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/README.md @@ -0,0 +1,13 @@ +# Position Decoding from Hippocampal Sorted Spikes + +In the following example, you can find how to use the decoder from [here](https://github.com/Eden-Kramer-Lab/replay_trajectory_classification/tree/master?tab=readme-ov-file) to decode an animals position from sorted hippocampal units sampled with tetrodes. + +### Dataset + +We thank Eric Denovellis for sharing his data and for his help with the decoder. Please cite his work: Eric L Denovellis, Anna K Gillespie, Michael E Coulter, Marielena Sosa, Jason E Chung, Uri T Eden, Loren M Frank (2021). Hippocampal replay of experience at real-world speeds. eLife 10:e64505. + +You can download the data [here](https://drive.google.com/file/d/1ddRC28w0U4_q3pcGfY-1vPHjO9mjEaJb/view?usp=sharing). The workflow expects the zip file to be extracted into the `datasets/decoder_data` folder. The workflow also expects the files to be renamed to just `sorted_spike_times.pkl` and `position_info.pkl`. + +### Python + +You need to install the [replay_trajectory_classification](https://github.com/Eden-Kramer-Lab/replay_trajectory_classification) package into your python virtual environment. 
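A note on the decoding math in patch 02's `decoder/core.py` (and its later refactors): `SortedSpikeDecoder.decode` runs one step of a causal Bayes filter over linearized position bins — a one-step prediction through the fitted state transition, multiplied by a Poisson spiking likelihood and renormalized. The sketch below is a minimal, self-contained NumPy illustration of that recursion; the toy place fields, transition matrix, and spike counts are fabricated here and stand in for the `place_fields_`, `state_transition_`, and `initial_conditions_` arrays the decoder unpickles from the fitted replay_trajectory_classification model.

```python
import numpy as np

def poisson_log_likelihood(counts, place_fields, dt):
    """Log-likelihood of one bin of spike counts at every position bin,
    assuming independent Poisson neurons (constant factorial term dropped)."""
    lam = np.clip(place_fields * dt, 1e-15, None)      # (n_bins, n_neurons)
    return (counts * np.log(lam)).sum(axis=1) - lam.sum(axis=1)

def filter_step(posterior, counts, transition, place_fields, dt):
    """One causal update: predict through the state transition, weight by
    the likelihood, renormalize -- mirroring SortedSpikeDecoder.decode."""
    log_like = poisson_log_likelihood(counts, place_fields, dt)
    likelihood = np.exp(log_like - log_like.max())     # scaled to avoid underflow
    prior = transition.T @ posterior                   # one-step prediction
    posterior = prior * likelihood
    return posterior / posterior.sum()

rng = np.random.default_rng(0)
n_bins, n_neurons, dt = 50, 8, 0.1
bins = np.arange(n_bins, dtype=float)
# Gaussian place fields with random centers, plus a small baseline rate.
fields = 5.0 * np.exp(-0.5 * ((bins[:, None] - rng.uniform(0, n_bins, n_neurons)) / 3.0) ** 2) + 0.1
# Random-walk transition: Gaussian around the diagonal, rows normalized.
transition = np.exp(-0.5 * ((bins[:, None] - bins[None, :]) / 1.5) ** 2)
transition /= transition.sum(axis=1, keepdims=True)

posterior = np.full(n_bins, 1.0 / n_bins)              # flat initial conditions
counts = rng.poisson(fields[20] * dt)                  # spikes emitted at bin 20
posterior = filter_step(posterior, counts, transition, fields, dt)
print(posterior.argmax())                              # MAP bin, cf. Posterior.cs
```

Note the convention: with a row-stochastic transition `T[i, j] = P(j | i)`, the one-step prediction is `T.T @ posterior`, which is exactly the `state_transition.T @ posterior` product in the patch.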
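Relatedly, the `bin_spikes` branch of `DataLoader.load_data` (patch 02) turns per-neuron spike-time lists into a time-bins × neurons count matrix by histogramming each neuron against the position timestamps, after prepending one extra edge so every position sample owns a bin. A self-contained version of just that transformation — the 250 Hz clock and the spike times below are fabricated; in the example itself the edges come from `position_info.pkl`:

```python
import numpy as np

rng = np.random.default_rng(1)
fs = 250.0                                     # assumed position sampling rate (Hz)
t = np.arange(0.0, 10.0, 1.0 / fs)             # position timestamps
edges = np.insert(t, 0, t[0] - (t[1] - t[0]))  # prepend an edge, as in DataLoader
spike_times = [np.sort(rng.uniform(0.0, 10.0, rng.integers(50, 200)))
               for _ in range(12)]             # 12 fabricated sorted units

# One column per neuron: spike counts per position bin.
spike_mat = np.column_stack([np.histogram(st, edges)[0] for st in spike_times])
print(spike_mat.shape)                         # (n_time_bins, n_neurons)
```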
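Finally, a note on patch 01's `run` script: before launching `Bonsai.exe` under mono it rewrites every `AssemblyLocation@location` and `LibraryFolder@path` attribute in `Bonsai.config` from Windows to POSIX separators (and back on exit) using xmllint and xmlstarlet. The same round trip in Python's standard library, for readers without those tools — the element names are taken from the script's XPath expressions; namespace handling and error checking are omitted for brevity:

```python
import xml.etree.ElementTree as ET

def convert_separators(config_path, to_posix=True):
    """Rewrite AssemblyLocation/LibraryFolder paths in Bonsai.config in place."""
    old, new = ("\\", "/") if to_posix else ("/", "\\")
    tree = ET.parse(config_path)
    root = tree.getroot()  # <PackageConfiguration>
    for element, attr in (("AssemblyLocations/AssemblyLocation", "location"),
                          ("LibraryFolders/LibraryFolder", "path")):
        for node in root.findall(element):
            node.set(attr, node.get(attr, "").replace(old, new))
    tree.write(config_path)

# convert_separators(".bonsai/Bonsai.config", to_posix=True)   # before running
# convert_separators(".bonsai/Bonsai.config", to_posix=False)  # restore on exit
```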
From 3d011971c10218f77f2aedbb0988dc3dcb77f28c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 5 Nov 2024 18:27:24 +0000 Subject: [PATCH 05/34] Added clusterless spike decoder model --- .../decoder/__init__.py | 2 +- .../decoder/core.py | 138 +++++++++++++----- .../decoder/data_loader.py | 77 +++++++++- .../decoder/model_loader.py | 15 +- 4 files changed, 184 insertions(+), 48 deletions(-) diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py index cca164a..8378a5a 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py @@ -1,4 +1,4 @@ -from .core import SortedSpikeDecoder +from .core import SortedSpikeDecoder, ClusterlessSpikeDecoder from .data_loader import DataLoader from .data_iterator import DataIterator from .model_loader import ModelLoader \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/core.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/core.py index 8ffc78f..2cd9a24 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/core.py +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/core.py @@ -1,35 +1,123 @@ import replay_trajectory_classification as rtc from replay_trajectory_classification.core import scaled_likelihood, get_centers -from replay_trajectory_classification.likelihoods import _SORTED_SPIKES_ALGORITHMS -# from likelihoods import _SORTED_SPIKES_ALGORITHMS +from replay_trajectory_classification.likelihoods import _SORTED_SPIKES_ALGORITHMS, _ClUSTERLESS_ALGORITHMS from replay_trajectory_classification.likelihoods.spiking_likelihood_kde import combined_likelihood, poisson_log_likelihood +from replay_trajectory_classification.likelihoods.multiunit_likelihood import estimate_position_distance, estimate_log_joint_mark_intensity import numpy as np +class ClusterlessSpikeDecoder: + def __init__(self, model_dict: dict): + self.decoder = model_dict["decoder"] + self.Fs = model_dict["Fs"] + self.features = model_dict["features"] + + encoding_model = self.decoder.encoding_model_ + self.encoding_marks = encoding_model["encoding_marks"] + self.mark_std = encoding_model["mark_std"] + self.encoding_positions = encoding_model["encoding_positions"] + self.position_std = encoding_model["position_std"] + self.occupancy = encoding_model["occupancy"] + self.mean_rates = encoding_model["mean_rates"] + self.summed_ground_process_intensity = encoding_model["summed_ground_process_intensity"] + self.block_size = encoding_model["block_size"] + self.bin_diffusion_distances = encoding_model["bin_diffusion_distances"] + self.edges = encoding_model["edges"] + + self.place_bin_centers = self.decoder.environment.place_bin_centers_ + self.is_track_interior = self.decoder.environment.is_track_interior_.ravel(order="F") + self.st_interior_ind = np.ix_(self.is_track_interior, self.is_track_interior) + self.interior_place_bin_centers = np.asarray( + self.place_bin_centers[self.is_track_interior], dtype=np.float32 + ) + self.interior_occupancy = np.asarray( + self.occupancy[self.is_track_interior], dtype=np.float32 + ) + self.n_position_bins = self.is_track_interior.shape[0] + self.n_track_bins = self.is_track_interior.sum() + + self.initial_conditions = self.decoder.initial_conditions_[self.is_track_interior].astype(float) 
+ self.state_transition = self.decoder.state_transition_[self.st_interior_ind].astype(float) + + self.posterior = None + super().__init__() + + def decode(self, + multiunits: np.ndarray): + + log_likelihood = -self.summed_ground_process_intensity * np.ones((1,1), dtype=np.float32) + + if not np.isnan(multiunits).all(): + multiunit_idxs = np.where(~np.isnan(multiunits, axis=0))[0] + + + for multiunit, enc_marks, enc_pos, mean_rate in zip( + multiunits.T, + self.encoding_marks, + self.encoding_positions, + self.mean_rates, + ): + is_spike = np.any(~np.isnan(multiunit)) + if is_spike: + decoding_marks = np.asarray( + multiunit, dtype=np.float32 + )[np.newaxis] + log_joint_mark_intensity = np.zeros( + (1, self.n_track_bins), dtype=np.float32 + ) + position_distance = estimate_position_distance( + self.interior_place_bin_centers, + np.asarray(enc_pos, dtype=np.float32), + self.position_std, + ).astype(np.float32) + log_joint_mark_intensity[0] = estimate_log_joint_mark_intensity( + decoding_marks, + enc_marks, + self.mark_std, + self.interior_occupancy, + mean_rate, + position_distance=position_distance, + ) + log_likelihood[:, self.is_track_interior] += np.nan_to_num( + log_joint_mark_intensity + ) + + log_likelihood[:, ~self.is_track_interior] = np.nan + likelihood = scaled_likelihood(log_likelihood) + + if self.posterior is None: + self.posterior = np.full((1, self.n_position_bins), np.nan, dtype=float) + self.posterior[0, self.is_track_interior] = self.initial_conditions * likelihood[0, self.is_track_interior] + + else: + self.posterior[0, self.is_track_interior] = self.state_transition.T @ self.posterior[0, self.is_track_interior] * likelihood[0, self.is_track_interior] + + norm = np.nansum(self.posterior[0]) + self.posterior[0] /= norm + + return self.posterior + class SortedSpikeDecoder: def __init__(self, model_dict: dict): - self.model = model_dict["decoder"] + self.decoder = model_dict["decoder"] self.Fs = model_dict["Fs"] self.spikes = model_dict["binned_spikes_times"] - self.is_track_interior = self.model.environment.is_track_interior_.ravel(order="F") + self.is_track_interior = self.decoder.environment.is_track_interior_.ravel(order="F") self.st_interior_ind = np.ix_(self.is_track_interior, self.is_track_interior) self.n_position_bins = self.is_track_interior.shape[0] - self.initial_conditions = self.model.initial_conditions_[self.is_track_interior].astype(float) - self.state_transition = self.model.state_transition_[self.st_interior_ind].astype(float) - self.place_fields = np.asarray(self.model.place_fields_) - self.position_centers = get_centers(self.model.environment.edges_[0]), + self.initial_conditions = self.decoder.initial_conditions_[self.is_track_interior].astype(float) + self.state_transition = self.decoder.state_transition_[self.st_interior_ind].astype(float) + self.place_fields = np.asarray(self.decoder.place_fields_) + self.position_centers = get_centers(self.decoder.environment.edges_[0]), self.posterior = None super().__init__() - def decode_spikes( + def decode( self, spikes: np.ndarray ): - # likelihood = scaled_likelihood(_SORTED_SPIKES_ALGORITHMS[self.model.sorted_spikes_algorithm][1](spikes[np.newaxis], self.place_fields)) - # likelihood = likelihood[:, self.is_track_interior].astype(float) - conditional_intensity = np.clip(self.place_fields, a_min=1e-15, a_max=None) log_likelihood = 0 @@ -52,28 +140,4 @@ def decode_spikes( norm = np.nansum(self.posterior[0]) self.posterior[0] /= norm - return self.posterior - - # def decode_spikes( - # self, - # data: 
tuple[list, list] - # ) -> tuple[np.ndarray, float]: - - # likelihood = scaled_likelihood(_SORTED_SPIKES_ALGORITHMS[model.sorted_spikes_algorithm][1](spikes, np.asarray(model.place_fields_))) - # likelihood = likelihood[:, is_track_interior].astype(float) - - # n_time = likelihood.shape[0] - # posterior = np.zeros_like(likelihood) - - # posterior[0] = initial_conditions.copy() * likelihood[0] - # norm = np.nansum(posterior[0]) - # log_data_likelihood = np.log(norm) - # posterior[0] /= norm - - # for k in np.arange(1, n_time): - # posterior[k] = state_transition.T @ posterior[k - 1] * likelihood[k] - # norm = np.nansum(posterior[k]) - # log_data_likelihood += np.log(norm) - # posterior[k] /= norm - - # return posterior, log_data_likelihood \ No newline at end of file + return self.posterior \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_loader.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_loader.py index a0d2ae6..18c229a 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_loader.py +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_loader.py @@ -10,12 +10,12 @@ def __init__(self): super().__init__() @classmethod - def load_data(cls, + def load_sorted_spike_data(cls, dataset_path: str = "../../../datasets/decoder_data", bin_spikes: bool = True) -> dict: - if len([file for file in os.listdir(dataset_path) if file == "position_info.pkl" or file == "sorted_spike_times.pkl" or file == "decoding_results.pkl"]) != 3: - raise Exception("Dataset incorrect. Missing at least one of the following files: 'position_info.pkl', 'sorted_spike_times.pkl', 'decoding_results.pkl'") + if len([file for file in os.listdir(dataset_path) if file == "position_info.pkl" or file == "sorted_spike_times.pkl" or file == "sorted_spike_decoding_results.pkl"]) != 3: + raise Exception("Dataset incorrect. Missing at least one of the following files: 'position_info.pkl', 'sorted_spike_times.pkl', 'sorted_spike_decoding_results.pkl'") position_data = pd.read_pickle(os.path.join(dataset_path, "position_info.pkl")) position_index = position_data.index.to_numpy() @@ -60,10 +60,10 @@ def load_data(cls, spike_mat[:, neuron] = np.histogram(spike_times[neuron], position_index)[0] spike_times = spike_mat - with open(os.path.join(dataset_path, "decoding_results.pkl"), "rb") as f: + with open(os.path.join(dataset_path, "sorted_spike_decoding_results.pkl"), "rb") as f: results = pickle.load(f)["decoding_results"] - position_bins = results.position.to_numpy() - decoding_results = results.acausal_posterior.to_numpy() + position_bins = results.position.to_numpy()[np.newaxis] + decoding_results = results.acausal_posterior.to_numpy()[:,np.newaxis] return { "position_data": position_data, @@ -71,3 +71,68 @@ def load_data(cls, "decoding_results": decoding_results, "position_bins": position_bins } + + @classmethod + def load_clusterless_spike_data(cls, + dataset_path: str = "../../../datasets/decoder_data") -> dict: + + if len([file for file in os.listdir(dataset_path) if file == "position_info.pkl" or file == "clusterless_spike_times.pkl" or file == "clusterless_spike_features.pkl" or file == "clusterless_spike_decoding_results.pkl"]) != 4: + raise Exception("Dataset incorrect. 
Missing at least one of the following files: 'position_info.pkl', 'clusterless_spike_times.pkl', 'clusterless_spike_features.pkl', 'clusterless_spike_decoding_results.pkl'") + + position_data = pd.read_pickle(os.path.join(dataset_path, "position_info.pkl")) + position_index = position_data.index.to_numpy() + position_index = np.insert(position_index, 0, position_index[0] - (position_index[1] - position_index[0])) + position_data = position_data[["nose_x", "nose_y"]].to_numpy() + + node_positions = [(120.0, 100.0), + ( 5.0, 100.0), + ( 5.0, 55.0), + (120.0, 55.0), + ( 5.0, 8.5), + (120.0, 8.5), + ] + edges = [ + (3, 2), + (0, 1), + (1, 2), + (5, 4), + (4, 2), + ] + track_graph = rtc.make_track_graph(node_positions, edges) + + edge_order = [ + (3, 2), + (0, 1), + (1, 2), + (5, 4), + (4, 2), + ] + + edge_spacing = [16, 0, 16, 0] + + linearized_positions = tl.get_linearized_position(position_data, track_graph, edge_order=edge_order, edge_spacing=edge_spacing, use_HMM=False) + position_data = linearized_positions.linear_position + + with open(os.path.join(dataset_path, "clusterless_spike_times.pkl"), "rb") as f: + spike_times = pickle.load(f) + + with open(os.path.join(dataset_path, "clusterless_spike_features.pkl"), "rb") as f: + spike_features = pickle.load(f) + + features = np.ones((len(position_data), len(spike_features[0][0]), len(spike_times)), dtype=float) * np.nan + for n in range(len(spike_times)): + in_spikes_window = np.digitize(spike_times[n], position_index) + features[in_spikes_window, :, n] = spike_features[n] + + with open(os.path.join(dataset_path, "clusterless_spike_decoding_results.pkl"), "rb") as f: + results = pickle.load(f)["decoding_results"] + position_bins = results.position.to_numpy()[np.newaxis] + decoding_results = results.acausal_posterior.to_numpy()[:,np.newaxis] + + return { + "position_data": position_data, + "spike_times": spike_times, + "features": features, + "decoding_results": decoding_results, + "position_bins": position_bins + } diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py index 54b94ac..4b5652f 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py @@ -1,13 +1,20 @@ import pickle -from .core import SortedSpikeDecoder +from .core import SortedSpikeDecoder, ClusterlessSpikeDecoder class ModelLoader: def __init__(self): super().__init__() @classmethod - def load_model(cls, - model_path: str = "../../../datasets/decoder_data/model.pkl") -> SortedSpikeDecoder: + def load_sorted_spike_model(cls, + model_path: str = "../../../datasets/decoder_data/sorted_spike_decoder.pkl") -> SortedSpikeDecoder: with open(model_path, "rb") as f: model = pickle.load(f) - return SortedSpikeDecoder(model) \ No newline at end of file + return SortedSpikeDecoder(model) + + @classmethod + def load_clusterless_spike_model(cls, + model_path: str = "../../../datasets/decoder_data/clusterless_spike_decoder.pkl") -> ClusterlessSpikeDecoder: + with open(model_path, "rb") as f: + model = pickle.load(f) + return ClusterlessSpikeDecoder(model) \ No newline at end of file From 99e1bfec9122ee83dda6a4c8f9ffd3d72dc8a186 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 5 Nov 2024 18:27:41 +0000 Subject: [PATCH 06/34] Changed name to sorted spikes --- ...PositionDecoding.bonsai => SortedSpikes.bonsai} | 14 ++++++++------ 1 
file changed, 8 insertions(+), 6 deletions(-) rename examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/{PositionDecoding.bonsai => SortedSpikes.bonsai} (96%) diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/PositionDecoding.bonsai b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/SortedSpikes.bonsai similarity index 96% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/PositionDecoding.bonsai rename to examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/SortedSpikes.bonsai index 7d7eddd..572bc33 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/PositionDecoding.bonsai +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/SortedSpikes.bonsai @@ -26,9 +26,7 @@ import sys import os sys.path.append(os.getcwd()) -from decoder.data_loader import DataLoader -from decoder.data_iterator import DataIterator -from decoder.model_loader import ModelLoader +from decoder import * @@ -48,7 +46,7 @@ from decoder.model_loader import ModelLoader - model = ModelLoader.load_model() + model = ModelLoader.load_sorted_spike_model() @@ -68,7 +66,7 @@ from decoder.model_loader import ModelLoader - data = DataLoader.load_data() + data = DataLoader.load_sorted_spike_data() @@ -114,6 +112,9 @@ from decoder.model_loader import ModelLoader iterator.next() + + it[0] + new(it[0] as Position, it[1] as Spikes, it[2] as Decoding, it[3] as PositionBins) @@ -128,6 +129,7 @@ from decoder.model_loader import ModelLoader + @@ -147,7 +149,7 @@ from decoder.model_loader import ModelLoader - model.decode_spikes(spikes) + model.decode(spikes) From 501a622e8096f16c22760ceb5c9fa391f832ee88 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 5 Nov 2024 18:27:59 +0000 Subject: [PATCH 07/34] Added clusterless spikes bonsai workflow --- .../ClusterlessSpikes.bonsai | 328 ++++++++++++++++++ 1 file changed, 328 insertions(+) create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/ClusterlessSpikes.bonsai diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/ClusterlessSpikes.bonsai b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/ClusterlessSpikes.bonsai new file mode 100644 index 0000000..a9175cf --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/ClusterlessSpikes.bonsai @@ -0,0 +1,328 @@ + + + + + + + + + + + ImportDecoderLibrary + + + + Source1 + + + + import sys +import os +sys.path.append(os.getcwd()) +from decoder import * + + + + + + + + + + + + ModelLoader + + + + + + + + model = ModelLoader.load_clusterless_spike_model() + + + + + + + + + + + + DataLoader + + + + + + + + data = DataLoader.load_clusterless_spike_data() + + + + + + + + + + + + + + + IterateData + + + + + + + + PT0S + PT0.1S + + + + + + + + iterator = DataIterator(data) + + + + + + + + + + + iterator.next() + + + + it[0] + + + new(it[0] as Position, it[1] as Spikes, it[2] as Features, it[3] as Decoding, it[4] as PositionBins) + + + + + + + + + + + + + + + + + + Data + + + Data + + + Features + + + + spikes + + + + + model.decode(spikes) + + + + + + + OnlineDecoderResults + + + Data + + + OnlinePosterior + + + Data + + + Position + + + Convert.ToDouble(it.ToString()) + + + Position + + + Data + + + Decoding + + + + + + OfflineDecoderResults + + + Data + + + OfflinePosterior + + + Data + + + PositionBins + + + + + + PositionBins + + + PositionBins + + + OnlineDecoderResults + + + MapEstimate + + + + + + + OnlinePositionEstimate + + + PositionBins + + + 
OfflineDecoderResults + + + MapEstimate + + + + + + + OfflinePositionEstimate + + + Visualizer + + + + OnlinePositionEstimate + + + + OfflinePositionEstimate + + + + Position + + + + OnlinePosterior + + + + OfflinePosterior + + + + true + true + 3 + 2 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file From dcd18b5e171aefe2fce58f4d051d205ce4994617 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 5 Nov 2024 18:28:22 +0000 Subject: [PATCH 08/34] Updated data iterator to handle outputting position bins --- .../decoder/data_iterator.py | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_iterator.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_iterator.py index 7bd19f9..98c5eb1 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_iterator.py +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_iterator.py @@ -6,28 +6,20 @@ def __init__(self, data: dict, start_index: int = 0): self.data = data - self.position_bins = self.data["position_bins"] + self.keys = list(self.data.keys()) self.index = start_index super().__init__() - def next(self, - loop: bool = True) -> tuple[list, list]: + def next(self) -> tuple[list, list]: - output = None - position_data = self.data["position_data"] - spike_times = self.data["spike_times"] - decoding_results = self.data["decoding_results"] + output = [] - if self.index > len(position_data) and loop: - self.index = 0 + for key in self.keys: + try: + output.append(self.data[key][self.index]) + except IndexError: + output.append(self.data[key][0]) - if self.index < len(position_data): - position = position_data[self.index] - spikes = spike_times[self.index] - decoding = decoding_results[self.index][np.newaxis] + self.index += 1 - self.index += 1 - - output = (position, spikes, decoding, self.position_bins) - - return output \ No newline at end of file + return (output, self.keys) \ No newline at end of file From 02f8cc0e6f313b3ec4bb42ca86dff311ed5e2df8 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Nov 2024 12:53:24 +0000 Subject: [PATCH 09/34] Changed name of core to decoder --- .../decoder/{core.py => decoder.py} | 84 +++++++------------ 1 file changed, 32 insertions(+), 52 deletions(-) rename examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/{core.py => decoder.py} (64%) diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/core.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/decoder.py similarity index 64% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/core.py rename to examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/decoder.py index 2cd9a24..1f659d0 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/core.py +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/decoder.py @@ -4,9 +4,19 @@ from replay_trajectory_classification.likelihoods.spiking_likelihood_kde import combined_likelihood, poisson_log_likelihood from replay_trajectory_classification.likelihoods.multiunit_likelihood import estimate_position_distance, estimate_log_joint_mark_intensity +from .likelihood import LIKELIHOOD_FUNCTION + import numpy as np +import pickle as pkl + +class Decoder(): + def __init__(self): + 
super().__init__() + + def decode(self): + raise NotImplementedError -class ClusterlessSpikeDecoder: +class ClusterlessSpikeDecoder(Decoder): def __init__(self, model_dict: dict): self.decoder = model_dict["decoder"] self.Fs = model_dict["Fs"] @@ -39,51 +49,27 @@ def __init__(self, model_dict: dict): self.initial_conditions = self.decoder.initial_conditions_[self.is_track_interior].astype(float) self.state_transition = self.decoder.state_transition_[self.st_interior_ind].astype(float) + self.likelihood_funcion = LIKELIHOOD_FUNCTION[self.decoder.clusterless_algorithm] + self.posterior = None super().__init__() def decode(self, multiunits: np.ndarray): - log_likelihood = -self.summed_ground_process_intensity * np.ones((1,1), dtype=np.float32) - - if not np.isnan(multiunits).all(): - multiunit_idxs = np.where(~np.isnan(multiunits, axis=0))[0] - - - for multiunit, enc_marks, enc_pos, mean_rate in zip( - multiunits.T, - self.encoding_marks, - self.encoding_positions, - self.mean_rates, - ): - is_spike = np.any(~np.isnan(multiunit)) - if is_spike: - decoding_marks = np.asarray( - multiunit, dtype=np.float32 - )[np.newaxis] - log_joint_mark_intensity = np.zeros( - (1, self.n_track_bins), dtype=np.float32 - ) - position_distance = estimate_position_distance( - self.interior_place_bin_centers, - np.asarray(enc_pos, dtype=np.float32), - self.position_std, - ).astype(np.float32) - log_joint_mark_intensity[0] = estimate_log_joint_mark_intensity( - decoding_marks, - enc_marks, - self.mark_std, - self.interior_occupancy, - mean_rate, - position_distance=position_distance, - ) - log_likelihood[:, self.is_track_interior] += np.nan_to_num( - log_joint_mark_intensity - ) - - log_likelihood[:, ~self.is_track_interior] = np.nan - likelihood = scaled_likelihood(log_likelihood) + likelihood = self.likelihood_function( + multiunits, + self.summed_ground_process_intensity, + self.encoding_marks, + self.encoding_positions, + self.mean_rates, + self.is_track_interior, + self.interior_place_bin_centers, + self.position_std, + self.mark_std, + self.interior_occupancy, + self.n_track_bins + ) if self.posterior is None: self.posterior = np.full((1, self.n_position_bins), np.nan, dtype=float) @@ -97,7 +83,7 @@ def decode(self, return self.posterior -class SortedSpikeDecoder: +class SortedSpikeDecoder(Decoder): def __init__(self, model_dict: dict): self.decoder = model_dict["decoder"] self.Fs = model_dict["Fs"] @@ -109,7 +95,10 @@ def __init__(self, model_dict: dict): self.initial_conditions = self.decoder.initial_conditions_[self.is_track_interior].astype(float) self.state_transition = self.decoder.state_transition_[self.st_interior_ind].astype(float) self.place_fields = np.asarray(self.decoder.place_fields_) - self.position_centers = get_centers(self.decoder.environment.edges_[0]), + self.position_centers = get_centers(self.decoder.environment.edges_[0]) + self.conditional_intensity = np.clip(self.place_fields, a_min=1e-15, a_max=None) + + self.likelihood_function = LIKELIHOOD_FUNCTION[self.decoder.sorted_spikes_algorithm] self.posterior = None super().__init__() @@ -118,17 +107,8 @@ def decode( self, spikes: np.ndarray ): - conditional_intensity = np.clip(self.place_fields, a_min=1e-15, a_max=None) - - log_likelihood = 0 - for spike, ci in zip(spikes, conditional_intensity.T): - log_likelihood += poisson_log_likelihood(spike[np.newaxis], ci) - mask = np.ones_like(self.is_track_interior, dtype=float) - mask[~self.is_track_interior] = np.nan - - likelihood = scaled_likelihood(log_likelihood * mask) - likelihood = 
likelihood[:, self.is_track_interior].astype(float) + likelihood = self.likelihood_function(spikes, self.conditional_intensity, self.is_track_interior) if self.posterior is None: self.posterior = np.full((1, self.n_position_bins), np.nan, dtype=float) From 8ad156c4a0cb0bc8d17f17e638fdfec84faa3d86 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Nov 2024 12:53:36 +0000 Subject: [PATCH 10/34] Added likelihood module --- .../decoder/likelihood.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/likelihood.py diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/likelihood.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/likelihood.py new file mode 100644 index 0000000..00aa9f6 --- /dev/null +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/likelihood.py @@ -0,0 +1,67 @@ +from replay_trajectory_classification.core import scaled_likelihood +from replay_trajectory_classification.likelihoods.multiunit_likelihood import estimate_position_distance, estimate_log_joint_mark_intensity +from replay_trajectory_classification.likelihoods.spiking_likelihood_kde import poisson_log_likelihood + +import numpy as np + +def spiking_likelihood_kde(spikes, conditional_intensity, is_track_interior): + + log_likelihood = 0 + for spike, ci in zip(spikes, conditional_intensity.T): + log_likelihood += poisson_log_likelihood(spike[np.newaxis], ci) + + mask = np.ones_like(is_track_interior, dtype=float) + mask[~is_track_interior] = np.nan + + likelihood = scaled_likelihood(log_likelihood * mask) + likelihood = likelihood[:, is_track_interior].astype(float) + + return likelihood + +def multiunit_likelihood(multiunits, summed_ground_process_intensity, encoding_marks, encoding_positions, mean_rates, is_track_interior, interior_place_bin_centers, position_std, mark_std, interior_occupancy, n_track_bins): + log_likelihood = -summed_ground_process_intensity * np.ones((1,1), dtype=np.float32) + + if not np.isnan(multiunits).all(): + multiunit_idxs = np.where(~np.isnan(multiunits, axis=0))[0] + + + for multiunit, enc_marks, enc_pos, mean_rate in zip( + multiunits.T, + encoding_marks, + encoding_positions, + mean_rates, + ): + is_spike = np.any(~np.isnan(multiunit)) + if is_spike: + decoding_marks = np.asarray( + multiunit, dtype=np.float32 + )[np.newaxis] + log_joint_mark_intensity = np.zeros( + (1, n_track_bins), dtype=np.float32 + ) + position_distance = estimate_position_distance( + interior_place_bin_centers, + np.asarray(enc_pos, dtype=np.float32), + position_std, + ).astype(np.float32) + log_joint_mark_intensity[0] = estimate_log_joint_mark_intensity( + decoding_marks, + enc_marks, + mark_std, + interior_occupancy, + mean_rate, + position_distance=position_distance, + ) + log_likelihood[:, is_track_interior] += np.nan_to_num( + log_joint_mark_intensity + ) + + log_likelihood[:, ~is_track_interior] = np.nan + likelihood = scaled_likelihood(log_likelihood) + + return likelihood + +LIKELIHOOD_FUNCTION = { + "multiunit_likelihood": multiunit_likelihood, + "spiking_likelihood_kde": spiking_likelihood_kde +} \ No newline at end of file From 4d5d5e0b5d6477bb113ee751188973d343c88a90 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Nov 2024 12:53:55 +0000 Subject: [PATCH 11/34] Updated model loader to use decoder module --- .../decoder/model_loader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py index 4b5652f..d972f7b 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py @@ -1,19 +1,19 @@ import pickle -from .core import SortedSpikeDecoder, ClusterlessSpikeDecoder +from .decoder import SortedSpikeDecoder, ClusterlessSpikeDecoder class ModelLoader: def __init__(self): super().__init__() @classmethod - def load_sorted_spike_model(cls, + def load_sorted_spike_decoder(cls, model_path: str = "../../../datasets/decoder_data/sorted_spike_decoder.pkl") -> SortedSpikeDecoder: with open(model_path, "rb") as f: model = pickle.load(f) return SortedSpikeDecoder(model) @classmethod - def load_clusterless_spike_model(cls, + def load_clusterless_spike_decoder(cls, model_path: str = "../../../datasets/decoder_data/clusterless_spike_decoder.pkl") -> ClusterlessSpikeDecoder: with open(model_path, "rb") as f: model = pickle.load(f) From ebd1f6cb1dd67746f002523be4a70d438a7d249a Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Nov 2024 12:54:06 +0000 Subject: [PATCH 12/34] Updated init file --- .../decoder/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py index 8378a5a..2163c67 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py @@ -1,4 +1,5 @@ -from .core import SortedSpikeDecoder, ClusterlessSpikeDecoder +from .decoder import SortedSpikeDecoder, ClusterlessSpikeDecoder from .data_loader import DataLoader from .data_iterator import DataIterator -from .model_loader import ModelLoader \ No newline at end of file +from .model_loader import ModelLoader +from .likelihood import LIKELIHOOD_FUNCTION \ No newline at end of file From 6cb45dc2da845f07026584b602498f5e989aa581 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 7 Nov 2024 11:13:51 +0000 Subject: [PATCH 13/34] Add bonsai shaders package --- .../.bonsai/Bonsai.config | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Bonsai.config b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Bonsai.config index d09f619..0f6f004 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Bonsai.config +++ b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Bonsai.config @@ -14,12 +14,15 @@ + + + @@ -52,6 +55,7 @@ + @@ -68,12 +72,15 @@ + + + From e4d05458554dd723bca5bc95f4fb27e045395e00 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 7 Nov 2024 11:14:33 +0000 Subject: [PATCH 14/34] Updated to use shaders render frequency for timing --- .../ClusterlessSpikes.bonsai | 99 +++++++++++++------ 1 file changed, 67 insertions(+), 32 deletions(-) diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/ClusterlessSpikes.bonsai b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/ClusterlessSpikes.bonsai index a9175cf..42d4e33 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/ClusterlessSpikes.bonsai +++ 
b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/ClusterlessSpikes.bonsai @@ -1,6 +1,7 @@  + + + + + + 640 + 480 + On + false + Black + DepthBufferBit ColorBufferBit + true + + Resizable + Minimized + Primary + 50 + + + + + 8 + 8 + 8 + 8 + + 16 + 0 + 0 + + 0 + 0 + 0 + 0 + + 2 + false + + + @@ -38,7 +79,7 @@ from decoder import * - ModelLoader + LoadDecoder @@ -46,7 +87,7 @@ from decoder import * - model = ModelLoader.load_clusterless_spike_model() + decoder = ClusterlessSpikeDecoder.load("../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int.pkl") @@ -58,7 +99,7 @@ from decoder import * - DataLoader + LoadData @@ -84,14 +125,8 @@ from decoder import * IterateData - - - - - PT0S - PT0.1S - + @@ -121,15 +156,14 @@ from decoder import * - - - - + + + + - @@ -149,7 +183,7 @@ from decoder import * - model.decode(spikes) + decoder.decode(spikes) @@ -291,38 +325,39 @@ from decoder import * - - - - + + + + - + - + - + - - - - + + + - - - - + + + + - + + + \ No newline at end of file From 8fce417128250aa3eefb2705b7046e6b163781ed Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 7 Nov 2024 14:07:53 +0000 Subject: [PATCH 15/34] Renamed folders more appropriately to NeuralDecoding/PositionDecodingFromHippocampus --- .../decoder/likelihood.py | 67 - .../decoder/model_loader.py | 20 - .../.bonsai/Bonsai.config | 0 .../.bonsai/NuGet.config | 0 .../.bonsai/Setup.ps1 | 0 .../.bonsai/Setup.sh | 0 .../.bonsai/activate | 0 .../.bonsai/deactivate | 0 .../.bonsai/run | 0 .../ClusterlessSpikes.bonsai | 4 +- .../Extensions.csproj | 0 .../Extensions/PositionBins.cs | 0 .../Extensions/Posterior.cs | 0 .../README.md | 0 .../SortedSpikes.bonsai | 2 +- .../decoder/__init__.py | 1 - .../decoder/data_iterator.py | 0 .../decoder/data_loader.py | 0 .../decoder/decoder.py | 54 +- .../decoder/likelihood.py | 159 ++ .../notebooks/ClusterlessDecoder.ipynb | 1388 +++++++++++++++++ .../notebooks/SortedSpikesDecoder.ipynb | 1219 +++++++++++++++ .../scripts/clusterless_spike_decoder.py | 138 ++ .../scripts/sorted_spike_decoder.py | 129 ++ 24 files changed, 3077 insertions(+), 104 deletions(-) delete mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/likelihood.py delete mode 100644 examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/.bonsai/Bonsai.config (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/.bonsai/NuGet.config (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/.bonsai/Setup.ps1 (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/.bonsai/Setup.sh (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/.bonsai/activate (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/.bonsai/deactivate (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/.bonsai/run (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/ClusterlessSpikes.bonsai (98%) rename 
examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/Extensions.csproj (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/Extensions/PositionBins.cs (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/Extensions/Posterior.cs (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/README.md (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/SortedSpikes.bonsai (99%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/decoder/__init__.py (82%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/decoder/data_iterator.py (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/decoder/data_loader.py (100%) rename examples/{Decoder/PositionDecodingFromHippocampalSortedSpikes => NeuralDecoding/PositionDecodingFromHippocampus}/decoder/decoder.py (74%) create mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/likelihood.py create mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb create mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb create mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/scripts/clusterless_spike_decoder.py create mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/scripts/sorted_spike_decoder.py diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/likelihood.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/likelihood.py deleted file mode 100644 index 00aa9f6..0000000 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/likelihood.py +++ /dev/null @@ -1,67 +0,0 @@ -from replay_trajectory_classification.core import scaled_likelihood -from replay_trajectory_classification.likelihoods.multiunit_likelihood import estimate_position_distance, estimate_log_joint_mark_intensity -from replay_trajectory_classification.likelihoods.spiking_likelihood_kde import poisson_log_likelihood - -import numpy as np - -def spiking_likelihood_kde(spikes, conditional_intensity, is_track_interior): - - log_likelihood = 0 - for spike, ci in zip(spikes, conditional_intensity.T): - log_likelihood += poisson_log_likelihood(spike[np.newaxis], ci) - - mask = np.ones_like(is_track_interior, dtype=float) - mask[~is_track_interior] = np.nan - - likelihood = scaled_likelihood(log_likelihood * mask) - likelihood = likelihood[:, is_track_interior].astype(float) - - return likelihood - -def multiunit_likelihood(multiunits, summed_ground_process_intensity, encoding_marks, encoding_positions, mean_rates, is_track_interior, interior_place_bin_centers, position_std, mark_std, interior_occupancy, n_track_bins): - log_likelihood = -summed_ground_process_intensity * np.ones((1,1), dtype=np.float32) - - if not np.isnan(multiunits).all(): - multiunit_idxs = np.where(~np.isnan(multiunits, axis=0))[0] - - - for multiunit, enc_marks, enc_pos, mean_rate in zip( - multiunits.T, - encoding_marks, - encoding_positions, - mean_rates, - ): - 
is_spike = np.any(~np.isnan(multiunit)) - if is_spike: - decoding_marks = np.asarray( - multiunit, dtype=np.float32 - )[np.newaxis] - log_joint_mark_intensity = np.zeros( - (1, n_track_bins), dtype=np.float32 - ) - position_distance = estimate_position_distance( - interior_place_bin_centers, - np.asarray(enc_pos, dtype=np.float32), - position_std, - ).astype(np.float32) - log_joint_mark_intensity[0] = estimate_log_joint_mark_intensity( - decoding_marks, - enc_marks, - mark_std, - interior_occupancy, - mean_rate, - position_distance=position_distance, - ) - log_likelihood[:, is_track_interior] += np.nan_to_num( - log_joint_mark_intensity - ) - - log_likelihood[:, ~is_track_interior] = np.nan - likelihood = scaled_likelihood(log_likelihood) - - return likelihood - -LIKELIHOOD_FUNCTION = { - "multiunit_likelihood": multiunit_likelihood, - "spiking_likelihood_kde": spiking_likelihood_kde -} \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py b/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py deleted file mode 100644 index d972f7b..0000000 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/model_loader.py +++ /dev/null @@ -1,20 +0,0 @@ -import pickle -from .decoder import SortedSpikeDecoder, ClusterlessSpikeDecoder - -class ModelLoader: - def __init__(self): - super().__init__() - - @classmethod - def load_sorted_spike_decoder(cls, - model_path: str = "../../../datasets/decoder_data/sorted_spike_decoder.pkl") -> SortedSpikeDecoder: - with open(model_path, "rb") as f: - model = pickle.load(f) - return SortedSpikeDecoder(model) - - @classmethod - def load_clusterless_spike_decoder(cls, - model_path: str = "../../../datasets/decoder_data/clusterless_spike_decoder.pkl") -> ClusterlessSpikeDecoder: - with open(model_path, "rb") as f: - model = pickle.load(f) - return ClusterlessSpikeDecoder(model) \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Bonsai.config b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Bonsai.config similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Bonsai.config rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Bonsai.config diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/NuGet.config b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/NuGet.config similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/NuGet.config rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/NuGet.config diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.ps1 b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Setup.ps1 similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.ps1 rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Setup.ps1 diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.sh b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Setup.sh similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/Setup.sh rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Setup.sh diff --git 
a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/activate b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/activate similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/activate rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/activate diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/deactivate b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/deactivate similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/deactivate rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/deactivate diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/run b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/run similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/.bonsai/run rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/run diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/ClusterlessSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai similarity index 98% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/ClusterlessSpikes.bonsai rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai index 42d4e33..decf1d0 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/ClusterlessSpikes.bonsai +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai @@ -26,7 +26,7 @@ Resizable Minimized Primary - 50 + 200 @@ -87,7 +87,7 @@ from decoder import * - decoder = ClusterlessSpikeDecoder.load("../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int.pkl") + decoder = ClusterlessSpikeDecoder.load("../../../datasets/decoder_data/clusterless_spike_decoder.pkl") diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions.csproj b/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions.csproj similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions.csproj rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions.csproj diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/PositionBins.cs b/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/PositionBins.cs similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/PositionBins.cs rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/PositionBins.cs diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/Posterior.cs b/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/Posterior.cs similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/Extensions/Posterior.cs rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/Posterior.cs diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/README.md b/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/README.md rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md diff --git 
a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/SortedSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai similarity index 99% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/SortedSpikes.bonsai rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai index 572bc33..1d1d834 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/SortedSpikes.bonsai +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai @@ -46,7 +46,7 @@ from decoder import * - model = ModelLoader.load_sorted_spike_model() + model = ModelLoader.load_sorted_spike_decoder() diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/__init__.py similarity index 82% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/__init__.py index 2163c67..d48ce44 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/__init__.py +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/__init__.py @@ -1,5 +1,4 @@ from .decoder import SortedSpikeDecoder, ClusterlessSpikeDecoder from .data_loader import DataLoader from .data_iterator import DataIterator -from .model_loader import ModelLoader from .likelihood import LIKELIHOOD_FUNCTION \ No newline at end of file diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_iterator.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_iterator.py similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_iterator.py rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_iterator.py diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_loader.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_loader.py similarity index 100% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/data_loader.py rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_loader.py diff --git a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/decoder.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/decoder.py similarity index 74% rename from examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/decoder.py rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/decoder.py index 1f659d0..f4bb48f 100644 --- a/examples/Decoder/PositionDecodingFromHippocampalSortedSpikes/decoder/decoder.py +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/decoder.py @@ -8,13 +8,19 @@ import numpy as np import pickle as pkl +import cupy as cp class Decoder(): def __init__(self): super().__init__() - def decode(self): + def decode(self, data: np.ndarray): raise NotImplementedError + + @classmethod + def load(cls, filename: str): + with open(filename, "rb") as f: + return cls(pkl.load(f)) class ClusterlessSpikeDecoder(Decoder): def __init__(self, model_dict: dict): @@ -37,28 +43,46 @@ def __init__(self, model_dict: dict): self.place_bin_centers = self.decoder.environment.place_bin_centers_ self.is_track_interior = self.decoder.environment.is_track_interior_.ravel(order="F") self.st_interior_ind = np.ix_(self.is_track_interior, 
self.is_track_interior) - self.interior_place_bin_centers = np.asarray( - self.place_bin_centers[self.is_track_interior], dtype=np.float32 - ) - self.interior_occupancy = np.asarray( - self.occupancy[self.is_track_interior], dtype=np.float32 - ) + + self.likelihood_function = LIKELIHOOD_FUNCTION[self.decoder.clusterless_algorithm] + + if "gpu" in self.decoder.clusterless_algorithm: + self.is_track_interior_gpu = cp.asarray(self.is_track_interior) + self.occupancy = cp.asarray(self.occupancy) + self.interior_place_bin_centers = cp.asarray( + self.place_bin_centers[self.is_track_interior], dtype=cp.float32 + ) + self.interior_occupancy = cp.asarray( + self.occupancy[self.is_track_interior_gpu], dtype=cp.float32 + ) + + else: + self.is_track_interior_gpu = None + self.interior_place_bin_centers = np.asarray( + self.place_bin_centers[self.is_track_interior], dtype=np.float32 + ) + self.interior_occupancy = np.asarray( + self.occupancy[self.is_track_interior], dtype=np.float32 + ) + self.n_position_bins = self.is_track_interior.shape[0] self.n_track_bins = self.is_track_interior.sum() self.initial_conditions = self.decoder.initial_conditions_[self.is_track_interior].astype(float) self.state_transition = self.decoder.state_transition_[self.st_interior_ind].astype(float) - self.likelihood_funcion = LIKELIHOOD_FUNCTION[self.decoder.clusterless_algorithm] - self.posterior = None super().__init__() + + @classmethod + def load(cls, filename: str = "../../../datasets/decoder_data/clusterless_spike_decoder.pkl"): + return super().load(filename) def decode(self, - multiunits: np.ndarray): + data: np.ndarray): likelihood = self.likelihood_function( - multiunits, + data, self.summed_ground_process_intensity, self.encoding_marks, self.encoding_positions, @@ -103,12 +127,16 @@ def __init__(self, model_dict: dict): self.posterior = None super().__init__() + @classmethod + def load(cls, filename: str = "../../../datasets/decoder_data/sorted_spike_decoder.pkl"): + return super().load(filename) + def decode( self, - spikes: np.ndarray + data: np.ndarray ): - likelihood = self.likelihood_function(spikes, self.conditional_intensity, self.is_track_interior) + likelihood = self.likelihood_function(data, self.conditional_intensity, self.is_track_interior) if self.posterior is None: self.posterior = np.full((1, self.n_position_bins), np.nan, dtype=float) diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/likelihood.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/likelihood.py new file mode 100644 index 0000000..ed71cd7 --- /dev/null +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/likelihood.py @@ -0,0 +1,159 @@ +from replay_trajectory_classification.core import scaled_likelihood +import replay_trajectory_classification.likelihoods.multiunit_likelihood as ml +import replay_trajectory_classification.likelihoods.multiunit_likelihood_gpu as mlgpu +import replay_trajectory_classification.likelihoods.multiunit_likelihood_integer_gpu as mligpu +from replay_trajectory_classification.likelihoods.spiking_likelihood_kde import poisson_log_likelihood + +import numpy as np +import cupy as cp + +def spiking_likelihood_kde(spikes, conditional_intensity, is_track_interior): + + log_likelihood = 0 + for spike, ci in zip(spikes, conditional_intensity.T): + log_likelihood += poisson_log_likelihood(spike[np.newaxis], ci) + + mask = np.ones_like(is_track_interior, dtype=float) + mask[~is_track_interior] = np.nan + + likelihood = scaled_likelihood(log_likelihood * mask) 
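+    # Keep only bins inside the track: `mask` set the off-track bins to NaN
+    # above, and this indexing drops them, so the returned likelihood has one
+    # column per interior track bin (n_track_bins) rather than n_position_bins.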
+ likelihood = likelihood[:, is_track_interior].astype(float) + + return likelihood + +def multiunit_likelihood(multiunits, summed_ground_process_intensity, encoding_marks, encoding_positions, mean_rates, is_track_interior, interior_place_bin_centers, position_std, mark_std, interior_occupancy, n_track_bins): + log_likelihood = -summed_ground_process_intensity * np.ones((1,1), dtype=np.float32) + + if not np.isnan(multiunits).all(): + # multiunit_idxs = np.where(~np.isnan(multiunits, axis=0))[0] + + for multiunit, enc_marks, enc_pos, mean_rate in zip( + multiunits.T, + encoding_marks, + encoding_positions, + mean_rates, + ): + is_spike = np.any(~np.isnan(multiunit)) + if is_spike: + decoding_marks = np.asarray( + multiunit, dtype=np.float32 + )[np.newaxis] + log_joint_mark_intensity = np.zeros( + (1, n_track_bins), dtype=np.float32 + ) + position_distance = ml.estimate_position_distance( + interior_place_bin_centers, + np.asarray(enc_pos, dtype=np.float32), + position_std, + ).astype(np.float32) + log_joint_mark_intensity[0] = ml.estimate_log_joint_mark_intensity( + decoding_marks, + enc_marks, + mark_std, + interior_occupancy, + mean_rate, + position_distance=position_distance, + ) + log_likelihood[:, is_track_interior] += np.nan_to_num( + log_joint_mark_intensity + ) + + log_likelihood[:, ~is_track_interior] = np.nan + likelihood = scaled_likelihood(log_likelihood) + + return likelihood + +def multiunit_likelihood_gpu(multiunits, summed_ground_process_intensity, encoding_marks, encoding_positions, mean_rates, is_track_interior, interior_place_bin_centers, position_std, mark_std, interior_occupancy, n_track_bins): + log_likelihood = -summed_ground_process_intensity * np.ones((1,1), dtype=np.float32) + + if not np.isnan(multiunits).all(): + # multiunit_idxs = np.where(~np.isnan(multiunits, axis=0))[0] + + for multiunit, enc_marks, enc_pos, mean_rate in zip( + multiunits.T, + encoding_marks, + encoding_positions, + mean_rates, + ): + is_spike = np.any(~np.isnan(multiunit)) + if is_spike: + decoding_marks = cp.asarray( + multiunit, dtype=cp.float32 + )[cp.newaxis] + log_joint_mark_intensity = np.zeros( + (1, n_track_bins), dtype=np.float32 + ) + position_distance = mlgpu.estimate_position_distance( + interior_place_bin_centers, + cp.asarray(enc_pos, dtype=cp.float32), + position_std, + ).astype(cp.float32) + log_joint_mark_intensity[0] = mlgpu.estimate_log_joint_mark_intensity( + decoding_marks, + enc_marks, + mark_std, + interior_occupancy, + mean_rate, + position_distance=position_distance, + ) + log_likelihood[:, is_track_interior] += np.nan_to_num( + log_joint_mark_intensity + ) + + mempool = cp.get_default_memory_pool() + mempool.free_all_blocks() + + log_likelihood[:, ~is_track_interior] = np.nan + likelihood = scaled_likelihood(log_likelihood) + return likelihood + +def multiunit_likelihood_integer_gpu(multiunits, summed_ground_process_intensity, encoding_marks, encoding_positions, mean_rates, is_track_interior, interior_place_bin_centers, position_std, mark_std, interior_occupancy, n_track_bins): + log_likelihood = -summed_ground_process_intensity * np.ones((1,1), dtype=np.float32) + + if not np.isnan(multiunits).all(): + # multiunit_idxs = np.where(~np.isnan(multiunits, axis=0))[0] + + for multiunit, enc_marks, enc_pos, mean_rate in zip( + multiunits.T, + encoding_marks, + encoding_positions, + mean_rates, + ): + is_spike = np.any(~np.isnan(multiunit)) + if is_spike: + decoding_marks = cp.asarray( + multiunit, dtype=cp.int16 + )[cp.newaxis] + log_joint_mark_intensity = np.zeros( + 
(1, n_track_bins), dtype=np.float32 + ) + position_distance = mligpu.estimate_position_distance( + interior_place_bin_centers, + cp.asarray(enc_pos, dtype=cp.float32), + position_std, + ).astype(cp.float32) + log_joint_mark_intensity[0] = mligpu.estimate_log_joint_mark_intensity( + decoding_marks, + enc_marks, + mark_std, + interior_occupancy, + mean_rate, + position_distance=position_distance, + ) + log_likelihood[:, is_track_interior] += np.nan_to_num( + log_joint_mark_intensity + ) + + mempool = cp.get_default_memory_pool() + mempool.free_all_blocks() + + log_likelihood[:, ~is_track_interior] = np.nan + likelihood = scaled_likelihood(log_likelihood) + return likelihood + +LIKELIHOOD_FUNCTION = { + "multiunit_likelihood": multiunit_likelihood, + "spiking_likelihood_kde": spiking_likelihood_kde, + "multiunit_likelihood_gpu": multiunit_likelihood_gpu, + "multiunit_likelihood_integer_gpu": multiunit_likelihood_integer_gpu +} \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb new file mode 100644 index 0000000..3976807 --- /dev/null +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb @@ -0,0 +1,1388 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Learn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/likelihoods/multiunit_likelihood.py:8: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from tqdm.autonotebook import tqdm\n" + ] + } + ], + "source": [ + "!export CUDA_PATH=\"/usr/local/cuda\"\n", + "# %set_env CUDA_PATH=\"/usr/local/cuda\"\n", + "import os\n", + "os.environ[\"CUDA_PATH\"] = \"/usr/local/cuda\"\n", + "\n", + "import numpy as np\n", + "import sys\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import pickle as pkl\n", + "\n", + "import replay_trajectory_classification as rtc\n", + "import track_linearization as tl" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/usr/local/cuda'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from cupy.cuda import get_cuda_path\n", + "get_cuda_path()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "positions_filename = \"../../../../datasets/decoder_data/position_info.pkl\"\n", + "spikes_filename = \"../../../../datasets/decoder_data/clusterless_spike_times.pkl\"\n", + "features_filename = \"../../../../datasets/decoder_data/clusterless_spike_features.pkl\"\n", + "model_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int.pkl\"\n", + "decoding_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoding_results.pkl\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "positions_df = pd.read_pickle(positions_filename)\n", + "timestamps = positions_df.index.to_numpy()\n", + "dt = timestamps[1] - timestamps[0]\n", + "Fs = 1.0 / dt\n", + "spikes_bins = np.append(timestamps-dt/2, timestamps[-1]+dt/2)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
nose_xnose_ynose_veltailBase_xtailBase_ytailBase_veltailMid_xtailMid_ytailMid_veltailTip_x...hindpawR_velforelimb_mid_xforelimb_mid_yforelimb_velbody_dirlinear_positiontrack_segment_idprojected_x_positionprojected_y_positionarm_name
time
22389.0828756.3026285.23117445.32077012.8560645.54749648.29090911.3679136.83302156.50818616.296467...NaN5.0806713.215346125.7070052.338472162.28501936.3451018.359380Left Arm
22389.0848756.7552765.64450345.57971613.5034485.68928648.48375112.1428277.31534556.75125617.387420...NaN5.1946843.302497124.9766022.347791162.73201436.7920558.353311Left Arm
22389.0868757.2079256.05783345.83866114.1508325.83107648.67659312.9177407.79766956.99432518.478373...NaN5.3086973.389648124.2462002.357110163.17901037.2390098.347243Left Arm
22389.0888757.6605736.47116246.09760714.7982155.97286648.86943513.6926548.27999357.23739519.569326...NaN5.4227093.476798123.5157972.366429163.62600537.6859638.341174Left Arm
22389.0908758.1132226.88449246.35655215.4455996.11465749.06227714.4675688.76231757.48046420.660279...NaN5.5367223.563949122.7853942.375748164.07300038.1329178.335106Left Arm
..................................................................
23293.7228750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7248750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7268750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7288750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7308750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
\n", + "

452325 rows × 33 columns

\n", + "
" + ], + "text/plain": [ + " nose_x nose_y nose_vel tailBase_x tailBase_y \\\n", + "time \n", + "22389.082875 6.302628 5.231174 45.320770 12.856064 5.547496 \n", + "22389.084875 6.755276 5.644503 45.579716 13.503448 5.689286 \n", + "22389.086875 7.207925 6.057833 45.838661 14.150832 5.831076 \n", + "22389.088875 7.660573 6.471162 46.097607 14.798215 5.972866 \n", + "22389.090875 8.113222 6.884492 46.356552 15.445599 6.114657 \n", + "... ... ... ... ... ... \n", + "23293.722875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "23293.724875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "23293.726875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "23293.728875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "23293.730875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "\n", + " tailBase_vel tailMid_x tailMid_y tailMid_vel tailTip_x ... \\\n", + "time ... \n", + "22389.082875 48.290909 11.367913 6.833021 56.508186 16.296467 ... \n", + "22389.084875 48.483751 12.142827 7.315345 56.751256 17.387420 ... \n", + "22389.086875 48.676593 12.917740 7.797669 56.994325 18.478373 ... \n", + "22389.088875 48.869435 13.692654 8.279993 57.237395 19.569326 ... \n", + "22389.090875 49.062277 14.467568 8.762317 57.480464 20.660279 ... \n", + "... ... ... ... ... ... ... \n", + "23293.722875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", + "23293.724875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", + "23293.726875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", + "23293.728875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", + "23293.730875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", + "\n", + " hindpawR_vel forelimb_mid_x forelimb_mid_y forelimb_vel \\\n", + "time \n", + "22389.082875 NaN 5.080671 3.215346 125.707005 \n", + "22389.084875 NaN 5.194684 3.302497 124.976602 \n", + "22389.086875 NaN 5.308697 3.389648 124.246200 \n", + "22389.088875 NaN 5.422709 3.476798 123.515797 \n", + "22389.090875 NaN 5.536722 3.563949 122.785394 \n", + "... ... ... ... ... \n", + "23293.722875 0.0 0.000000 0.000000 0.000000 \n", + "23293.724875 0.0 0.000000 0.000000 0.000000 \n", + "23293.726875 0.0 0.000000 0.000000 0.000000 \n", + "23293.728875 0.0 0.000000 0.000000 0.000000 \n", + "23293.730875 0.0 0.000000 0.000000 0.000000 \n", + "\n", + " body_dir linear_position track_segment_id \\\n", + "time \n", + "22389.082875 2.338472 162.285019 3 \n", + "22389.084875 2.347791 162.732014 3 \n", + "22389.086875 2.357110 163.179010 3 \n", + "22389.088875 2.366429 163.626005 3 \n", + "22389.090875 2.375748 164.073000 3 \n", + "... ... ... ... \n", + "23293.722875 -0.000057 161.447258 3 \n", + "23293.724875 -0.000057 161.447258 3 \n", + "23293.726875 -0.000057 161.447258 3 \n", + "23293.728875 -0.000057 161.447258 3 \n", + "23293.730875 -0.000057 161.447258 3 \n", + "\n", + " projected_x_position projected_y_position arm_name \n", + "time \n", + "22389.082875 6.345101 8.359380 Left Arm \n", + "22389.084875 6.792055 8.353311 Left Arm \n", + "22389.086875 7.239009 8.347243 Left Arm \n", + "22389.088875 7.685963 8.341174 Left Arm \n", + "22389.090875 8.132917 8.335106 Left Arm \n", + "... ... ... ... 
\n", + "23293.722875 5.507417 8.370753 Left Arm \n", + "23293.724875 5.507417 8.370753 Left Arm \n", + "23293.726875 5.507417 8.370753 Left Arm \n", + "23293.728875 5.507417 8.370753 Left Arm \n", + "23293.730875 5.507417 8.370753 Left Arm \n", + "\n", + "[452325 rows x 33 columns]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "positions_df" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "x = positions_df[\"nose_x\"].to_numpy()\n", + "y = positions_df[\"nose_y\"].to_numpy()\n", + "positions = np.column_stack((x, y))\n", + "node_positions = [(120.0, 100.0),\n", + " ( 5.0, 100.0),\n", + " ( 5.0, 55.0),\n", + " (120.0, 55.0),\n", + " ( 5.0, 8.5),\n", + " (120.0, 8.5),\n", + " ]\n", + "edges = [\n", + " (3, 2),\n", + " (0, 1),\n", + " (1, 2),\n", + " (5, 4),\n", + " (4, 2),\n", + " ]\n", + "track_graph = rtc.make_track_graph(node_positions, edges)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "edge_order = [\n", + " (3, 2),\n", + " (0, 1),\n", + " (1, 2),\n", + " (5, 4),\n", + " (4, 2),\n", + " ]\n", + "\n", + "edge_spacing = [16, 0, 16, 0]\n", + "\n", + "linearized_positions = tl.get_linearized_position(positions, track_graph, edge_order=edge_order, edge_spacing=edge_spacing, use_HMM=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "with open(features_filename, \"rb\") as f:\n", + " clusterless_spike_features = pkl.load(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "with open(spikes_filename, \"rb\") as f:\n", + " clusterless_spike_times = pkl.load(f)\n", + "\n", + "features = np.ones((len(timestamps), len(clusterless_spike_features[0][0]), len(clusterless_spike_times)), dtype=float) * np.nan\n", + "for n in range(len(clusterless_spike_times)):\n", + " in_spikes_window = np.digitize(clusterless_spike_times[n], spikes_bins)\n", + " features[in_spikes_window, :, n] = clusterless_spike_features[n]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "place_bin_size = 0.5\n", + "movement_var = 0.25\n", + "\n", + "environment = rtc.Environment(place_bin_size=place_bin_size,\n", + " track_graph=track_graph,\n", + " edge_order=edge_order,\n", + " edge_spacing=edge_spacing)\n", + "\n", + "transition_type = rtc.RandomWalk(movement_var=movement_var)\n", + "\n", + "decoder = rtc.ClusterlessDecoder(\n", + " environment=environment,\n", + " transition_type=transition_type,\n", + " clusterless_algorithm=\"multiunit_likelihood_integer_gpu\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Learning model parameters\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/continuous_state_transitions.py:24: RuntimeWarning: invalid value encountered in divide\n", + " x /= x.sum(axis=1, keepdims=True)\n" + ] + }, + { + "data": { + "text/html": [ + "
[scikit-learn-style HTML repr of the fitted ClusterlessDecoder omitted; the equivalent text/plain repr follows.]
" + ], + "text/plain": [ + "ClusterlessDecoder(clusterless_algorithm='multiunit_likelihood_integer_gpu',\n", + " clusterless_algorithm_params={'mark_std': 24.0,\n", + " 'position_std': 6.0},\n", + " environment=Environment(environment_name='',\n", + " place_bin_size=0.5,\n", + " track_graph=,\n", + " edge_order=[(3, 2), (0, 1), (1, 2),\n", + " (5, 4), (4, 2)],\n", + " edge_spacing=[16, 0, 16, 0],\n", + " is_track_interior=None,\n", + " position_range=None,\n", + " infer_track_interior=True,\n", + " fill_holes=False,\n", + " dilate=False,\n", + " bin_count_threshold=0),\n", + " infer_track_interior=True,\n", + " initial_conditions_type=UniformInitialConditions(),\n", + " transition_type=RandomWalk(environment_name='',\n", + " movement_var=0.25,\n", + " movement_mean=0.0,\n", + " use_diffusion=False))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(\"Learning model parameters\")\n", + "decoder.fit(linearized_positions.linear_position, features)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saving model to ../../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int.pkl\n" + ] + } + ], + "source": [ + "print(f\"Saving model to {model_filename}\")\n", + "\n", + "results = dict(decoder=decoder, linearized_positions=linearized_positions,\n", + " clusterless_spike_times=clusterless_spike_times, features=features, Fs=Fs)\n", + "\n", + "with open(model_filename, \"wb\") as f:\n", + " pkl.dump(results, f)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Decode" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "decoding_start_secs = 0\n", + "decoding_duration_secs = 100" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "with open(model_filename, \"rb\") as f:\n", + " model_results = pkl.load(f)\n", + " \n", + "decoder = model_results[\"decoder\"]\n", + "Fs = model_results[\"Fs\"]\n", + "clusterless_spike_times = model_results[\"clusterless_spike_times\"]\n", + "features = model_results[\"features\"]\n", + "linearized_positions = model_results[\"linearized_positions\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding positions from features\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "n_electrodes: 100%|██████████| 28/28 [00:10<00:00, 2.55it/s]\n", + "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/core.py:84: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'F', False, aligned=True), Array(float64, 1, 'A', False, aligned=True))\n", + " posterior[k] = state_transition.T @ posterior[k - 1] * likelihood[k]\n", + "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/core.py:116: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'F', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))\n", + " acausal_prior = state_transition.T @ causal_posterior[time_ind]\n" + ] + } + ], + "source": [ + "print(\"Decoding positions from features\")\n", + "decoding_start_samples = int(decoding_start_secs * Fs)\n", + "decoding_duration_samples = int(decoding_duration_secs * 
Fs)\n", + "time_ind = slice(decoding_start_samples, decoding_start_samples + decoding_duration_samples)\n", + "time = np.arange(linearized_positions.linear_position.size) / Fs\n", + "decoding_results = decoder.predict(features[time_ind], time=time[time_ind])" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saving decoding results to ../../../../datasets/decoder_data/clusterless_decoding_results.pkl\n" + ] + } + ], + "source": [ + "print(f\"Saving decoding results to {decoding_filename}\")\n", + "\n", + "results = dict(decoding_results=decoding_results, time=time[time_ind],\n", + " linearized_positions=linearized_positions.iloc[time_ind],\n", + " features=features[time_ind])\n", + "\n", + "with open(decoding_filename, \"wb\") as f:\n", + " pkl.dump(results, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import replay_trajectory_classification as rtc\n", + "from replay_trajectory_classification.core import scaled_likelihood, get_centers\n", + "from replay_trajectory_classification.likelihoods import _SORTED_SPIKES_ALGORITHMS, _ClUSTERLESS_ALGORITHMS\n", + "from replay_trajectory_classification.likelihoods.spiking_likelihood_kde import combined_likelihood, poisson_log_likelihood\n", + "from replay_trajectory_classification.likelihoods.multiunit_likelihood import estimate_position_distance, estimate_log_joint_mark_intensity\n", + "\n", + "class ClusterlessSpikeDecoder:\n", + " def __init__(self, model_dict: dict):\n", + " self.decoder = model_dict[\"decoder\"]\n", + " self.Fs = model_dict[\"Fs\"]\n", + " self.features = model_dict[\"features\"]\n", + "\n", + " encoding_model = self.decoder.encoding_model_\n", + " self.encoding_marks = encoding_model[\"encoding_marks\"]\n", + " self.mark_std = encoding_model[\"mark_std\"]\n", + " self.encoding_positions = encoding_model[\"encoding_positions\"]\n", + " self.position_std = encoding_model[\"position_std\"]\n", + " self.occupancy = encoding_model[\"occupancy\"]\n", + " self.mean_rates = encoding_model[\"mean_rates\"]\n", + " self.summed_ground_process_intensity = encoding_model[\"summed_ground_process_intensity\"]\n", + " self.block_size = encoding_model[\"block_size\"]\n", + " self.bin_diffusion_distances = encoding_model[\"bin_diffusion_distances\"]\n", + " self.edges = encoding_model[\"edges\"]\n", + "\n", + " self.place_bin_centers = self.decoder.environment.place_bin_centers_\n", + " self.is_track_interior = self.decoder.environment.is_track_interior_.ravel(order=\"F\")\n", + " self.st_interior_ind = np.ix_(self.is_track_interior, self.is_track_interior)\n", + " self.interior_place_bin_centers = np.asarray(\n", + " self.place_bin_centers[self.is_track_interior], dtype=np.float32\n", + " )\n", + " self.interior_occupancy = np.asarray(\n", + " self.occupancy[self.is_track_interior], dtype=np.float32\n", + " )\n", + " self.n_position_bins = self.is_track_interior.shape[0]\n", + " self.n_track_bins = self.is_track_interior.sum()\n", + "\n", + " self.initial_conditions = self.decoder.initial_conditions_[self.is_track_interior].astype(float)\n", + " self.state_transition = self.decoder.state_transition_[self.st_interior_ind].astype(float)\n", + "\n", + " self.posterior = None\n", + " super().__init__()\n", + " \n", + " def decode(self,\n", + " multiunits: np.ndarray):\n", + "\n", + " log_likelihood = -self.summed_ground_process_intensity * np.ones((1,1), 
dtype=np.float32)\n", + "\n", + " if not np.isnan(multiunits).all():\n", + " multiunit_idxs = np.where(~np.isnan(multiunits).all(axis=0))[0]\n", + " decoding_marks = np.asarray(\n", + " multiunits[:, multiunit_idxs], dtype=np.float32\n", + " )[np.newaxis]\n", + " log_joint_mark_intensity = np.zeros(\n", + " (1, self.n_track_bins), dtype=np.float32\n", + " )\n", + " position_distance = estimate_position_distance(\n", + " self.interior_place_bin_centers,\n", + " np.asarray(enc_pos, dtype=np.float32),\n", + " self.position_std,\n", + " ).astype(np.float32)\n", + " \n", + " for multiunit, enc_marks, enc_pos, mean_rate in zip(\n", + " multiunits.T,\n", + " self.encoding_marks,\n", + " self.encoding_positions,\n", + " self.mean_rates,\n", + " ):\n", + " is_spike = np.any(~np.isnan(multiunit))\n", + " if is_spike:\n", + " decoding_marks = np.asarray(\n", + " multiunit, dtype=np.float32\n", + " )[np.newaxis]\n", + " log_joint_mark_intensity = np.zeros(\n", + " (1, self.n_track_bins), dtype=np.float32\n", + " )\n", + " position_distance = estimate_position_distance(\n", + " self.interior_place_bin_centers,\n", + " np.asarray(enc_pos, dtype=np.float32),\n", + " self.position_std,\n", + " ).astype(np.float32)\n", + " log_joint_mark_intensity[0] = estimate_log_joint_mark_intensity(\n", + " decoding_marks,\n", + " enc_marks,\n", + " self.mark_std,\n", + " self.interior_occupancy,\n", + " mean_rate,\n", + " position_distance=position_distance,\n", + " )\n", + " log_likelihood[:, self.is_track_interior] += np.nan_to_num(\n", + " log_joint_mark_intensity\n", + " )\n", + " \n", + " log_likelihood[:, ~self.is_track_interior] = np.nan\n", + " likelihood = scaled_likelihood(log_likelihood)\n", + "\n", + " if self.posterior is None:\n", + " self.posterior = np.full((1, self.n_position_bins), np.nan, dtype=float)\n", + " self.posterior[0, self.is_track_interior] = self.initial_conditions * likelihood[0, self.is_track_interior]\n", + "\n", + " else:\n", + " self.posterior[0, self.is_track_interior] = self.state_transition.T @ self.posterior[0, self.is_track_interior] * likelihood[0, self.is_track_interior]\n", + "\n", + " norm = np.nansum(self.posterior[0])\n", + " self.posterior[0] /= norm\n", + "\n", + " return self.posterior" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "ename": "UnboundLocalError", + "evalue": "cannot access local variable 'enc_pos' where it is not associated with a value", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mUnboundLocalError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m decoder \u001b[38;5;241m=\u001b[39m ClusterlessSpikeDecoder(model_results)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mdecoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[11], line 58\u001b[0m, in \u001b[0;36mClusterlessSpikeDecoder.decode\u001b[0;34m(self, multiunits)\u001b[0m\n\u001b[1;32m 50\u001b[0m decoding_marks \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39masarray(\n\u001b[1;32m 51\u001b[0m multiunits[:, multiunit_idxs], dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat32\n\u001b[1;32m 52\u001b[0m )[np\u001b[38;5;241m.\u001b[39mnewaxis]\n\u001b[1;32m 
53\u001b[0m log_joint_mark_intensity \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mzeros(\n\u001b[1;32m 54\u001b[0m (\u001b[38;5;241m1\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_track_bins), dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat32\n\u001b[1;32m 55\u001b[0m )\n\u001b[1;32m 56\u001b[0m position_distance \u001b[38;5;241m=\u001b[39m estimate_position_distance(\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minterior_place_bin_centers,\n\u001b[0;32m---> 58\u001b[0m np\u001b[38;5;241m.\u001b[39masarray(\u001b[43menc_pos\u001b[49m, dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat32),\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mposition_std,\n\u001b[1;32m 60\u001b[0m )\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m multiunit, enc_marks, enc_pos, mean_rate \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(\n\u001b[1;32m 63\u001b[0m multiunits\u001b[38;5;241m.\u001b[39mT,\n\u001b[1;32m 64\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mencoding_marks,\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mencoding_positions,\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmean_rates,\n\u001b[1;32m 67\u001b[0m ):\n\u001b[1;32m 68\u001b[0m is_spike \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39many(\u001b[38;5;241m~\u001b[39mnp\u001b[38;5;241m.\u001b[39misnan(multiunit))\n", + "\u001b[0;31mUnboundLocalError\u001b[0m: cannot access local variable 'enc_pos' where it is not associated with a value" + ] + } + ], + "source": [ + "decoder = ClusterlessSpikeDecoder(model_results)\n", + "decoder.decode(features[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb new file mode 100644 index 0000000..6979c7e --- /dev/null +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb @@ -0,0 +1,1219 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Learn" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/likelihoods/multiunit_likelihood.py:8: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from tqdm.autonotebook import tqdm\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import sys\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import pickle as pkl\n", + "\n", + "import replay_trajectory_classification as rtc\n", + "import track_linearization as tl" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "positions_filename = \"../../../../datasets/decoder_data/position_info.pkl\"\n", + "spikes_filename = \"../../../../datasets/decoder_data/sorted_spike_times.pkl\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "positions_df = pd.read_pickle(positions_filename)\n", + "timestamps = positions_df.index.to_numpy()\n", + "dt = timestamps[1] - timestamps[0]\n", + "Fs = 1.0 / dt\n", + "spikes_bins = np.append(timestamps-dt/2, timestamps[-1]+dt/2)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
nose_xnose_ynose_veltailBase_xtailBase_ytailBase_veltailMid_xtailMid_ytailMid_veltailTip_x...hindpawR_velforelimb_mid_xforelimb_mid_yforelimb_velbody_dirlinear_positiontrack_segment_idprojected_x_positionprojected_y_positionarm_name
time
22389.0828756.3026285.23117445.32077012.8560645.54749648.29090911.3679136.83302156.50818616.296467...NaN5.0806713.215346125.7070052.338472162.28501936.3451018.359380Left Arm
22389.0848756.7552765.64450345.57971613.5034485.68928648.48375112.1428277.31534556.75125617.387420...NaN5.1946843.302497124.9766022.347791162.73201436.7920558.353311Left Arm
22389.0868757.2079256.05783345.83866114.1508325.83107648.67659312.9177407.79766956.99432518.478373...NaN5.3086973.389648124.2462002.357110163.17901037.2390098.347243Left Arm
22389.0888757.6605736.47116246.09760714.7982155.97286648.86943513.6926548.27999357.23739519.569326...NaN5.4227093.476798123.5157972.366429163.62600537.6859638.341174Left Arm
22389.0908758.1132226.88449246.35655215.4455996.11465749.06227714.4675688.76231757.48046420.660279...NaN5.5367223.563949122.7853942.375748164.07300038.1329178.335106Left Arm
..................................................................
23293.7228750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7248750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7268750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7288750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7308750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
\n", + "

452325 rows × 33 columns

\n", + "
" + ], + "text/plain": [ + " nose_x nose_y nose_vel tailBase_x tailBase_y \\\n", + "time \n", + "22389.082875 6.302628 5.231174 45.320770 12.856064 5.547496 \n", + "22389.084875 6.755276 5.644503 45.579716 13.503448 5.689286 \n", + "22389.086875 7.207925 6.057833 45.838661 14.150832 5.831076 \n", + "22389.088875 7.660573 6.471162 46.097607 14.798215 5.972866 \n", + "22389.090875 8.113222 6.884492 46.356552 15.445599 6.114657 \n", + "... ... ... ... ... ... \n", + "23293.722875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "23293.724875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "23293.726875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "23293.728875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "23293.730875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "\n", + " tailBase_vel tailMid_x tailMid_y tailMid_vel tailTip_x ... \\\n", + "time ... \n", + "22389.082875 48.290909 11.367913 6.833021 56.508186 16.296467 ... \n", + "22389.084875 48.483751 12.142827 7.315345 56.751256 17.387420 ... \n", + "22389.086875 48.676593 12.917740 7.797669 56.994325 18.478373 ... \n", + "22389.088875 48.869435 13.692654 8.279993 57.237395 19.569326 ... \n", + "22389.090875 49.062277 14.467568 8.762317 57.480464 20.660279 ... \n", + "... ... ... ... ... ... ... \n", + "23293.722875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", + "23293.724875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", + "23293.726875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", + "23293.728875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", + "23293.730875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", + "\n", + " hindpawR_vel forelimb_mid_x forelimb_mid_y forelimb_vel \\\n", + "time \n", + "22389.082875 NaN 5.080671 3.215346 125.707005 \n", + "22389.084875 NaN 5.194684 3.302497 124.976602 \n", + "22389.086875 NaN 5.308697 3.389648 124.246200 \n", + "22389.088875 NaN 5.422709 3.476798 123.515797 \n", + "22389.090875 NaN 5.536722 3.563949 122.785394 \n", + "... ... ... ... ... \n", + "23293.722875 0.0 0.000000 0.000000 0.000000 \n", + "23293.724875 0.0 0.000000 0.000000 0.000000 \n", + "23293.726875 0.0 0.000000 0.000000 0.000000 \n", + "23293.728875 0.0 0.000000 0.000000 0.000000 \n", + "23293.730875 0.0 0.000000 0.000000 0.000000 \n", + "\n", + " body_dir linear_position track_segment_id \\\n", + "time \n", + "22389.082875 2.338472 162.285019 3 \n", + "22389.084875 2.347791 162.732014 3 \n", + "22389.086875 2.357110 163.179010 3 \n", + "22389.088875 2.366429 163.626005 3 \n", + "22389.090875 2.375748 164.073000 3 \n", + "... ... ... ... \n", + "23293.722875 -0.000057 161.447258 3 \n", + "23293.724875 -0.000057 161.447258 3 \n", + "23293.726875 -0.000057 161.447258 3 \n", + "23293.728875 -0.000057 161.447258 3 \n", + "23293.730875 -0.000057 161.447258 3 \n", + "\n", + " projected_x_position projected_y_position arm_name \n", + "time \n", + "22389.082875 6.345101 8.359380 Left Arm \n", + "22389.084875 6.792055 8.353311 Left Arm \n", + "22389.086875 7.239009 8.347243 Left Arm \n", + "22389.088875 7.685963 8.341174 Left Arm \n", + "22389.090875 8.132917 8.335106 Left Arm \n", + "... ... ... ... 
\n", + "23293.722875 5.507417 8.370753 Left Arm \n", + "23293.724875 5.507417 8.370753 Left Arm \n", + "23293.726875 5.507417 8.370753 Left Arm \n", + "23293.728875 5.507417 8.370753 Left Arm \n", + "23293.730875 5.507417 8.370753 Left Arm \n", + "\n", + "[452325 rows x 33 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "positions_df" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "x = positions_df[\"nose_x\"].to_numpy()\n", + "y = positions_df[\"nose_y\"].to_numpy()\n", + "positions = np.column_stack((x, y))\n", + "node_positions = [(120.0, 100.0),\n", + " ( 5.0, 100.0),\n", + " ( 5.0, 55.0),\n", + " (120.0, 55.0),\n", + " ( 5.0, 8.5),\n", + " (120.0, 8.5),\n", + " ]\n", + "edges = [\n", + " (3, 2),\n", + " (0, 1),\n", + " (1, 2),\n", + " (5, 4),\n", + " (4, 2),\n", + " ]\n", + "track_graph = rtc.make_track_graph(node_positions, edges)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "edge_order = [\n", + " (3, 2),\n", + " (0, 1),\n", + " (1, 2),\n", + " (5, 4),\n", + " (4, 2),\n", + " ]\n", + "\n", + "edge_spacing = [16, 0, 16, 0]\n", + "\n", + "linearized_positions = tl.get_linearized_position(positions, track_graph, edge_order=edge_order, edge_spacing=edge_spacing, use_HMM=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "with open(spikes_filename, \"rb\") as f:\n", + " sorted_spike_times = pkl.load(f)\n", + "\n", + "binned_spikes_times = np.empty((len(timestamps), len(sorted_spike_times)), dtype=float)\n", + "for n in range(len(sorted_spike_times)):\n", + " binned_spikes_times[:, n] = np.histogram(sorted_spike_times[n], spikes_bins)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "place_bin_size = 0.5\n", + "movement_var = 0.25\n", + "\n", + "environment = rtc.Environment(place_bin_size=place_bin_size,\n", + " track_graph=track_graph,\n", + " edge_order=edge_order,\n", + " edge_spacing=edge_spacing)\n", + "\n", + "transition_type = rtc.RandomWalk(movement_var=movement_var)\n", + "\n", + "decoder = rtc.SortedSpikesDecoder(\n", + " environment=environment,\n", + " transition_type=transition_type,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Learning model parameters\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/continuous_state_transitions.py:24: RuntimeWarning: invalid value encountered in divide\n", + " x /= x.sum(axis=1, keepdims=True)\n", + "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/likelihoods/spiking_likelihood_kde.py:117: RuntimeWarning: divide by zero encountered in log\n", + " return np.exp(np.log(mean_rate) + np.log(marginal_density) - np.log(occupancy))\n", + "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/likelihoods/spiking_likelihood_kde.py:117: RuntimeWarning: invalid value encountered in subtract\n", + " return np.exp(np.log(mean_rate) + np.log(marginal_density) - np.log(occupancy))\n" + ] + }, + { + "data": { + "text/html": [ + "
SortedSpikesDecoder(environment=Environment(environment_name='',\n",
+       "                                            place_bin_size=0.5,\n",
+       "                                            track_graph=<networkx.classes.graph.Graph object at 0x7a3ee3555d30>,\n",
+       "                                            edge_order=[(3, 2), (0, 1), (1, 2),\n",
+       "                                                        (5, 4), (4, 2)],\n",
+       "                                            edge_spacing=[16, 0, 16, 0],\n",
+       "                                            is_track_interior=None,\n",
+       "                                            position_range=None,\n",
+       "                                            infer_track_interior=True,\n",
+       "                                            fill_holes=False,\n",
+       "                                            dilate=False,\n",
+       "                                            bin_count_threshold=0),\n",
+       "                    infer_track_interior=True,\n",
+       "                    initial_conditions_type=UniformInitialConditions(),\n",
+       "                    sorted_spikes_algorithm='spiking_likelihood_kde',\n",
+       "                    sorted_spikes_algorithm_params={'block_size': None,\n",
+       "                                                    'position_std': 6.0,\n",
+       "                                                    'use_diffusion': False},\n",
+       "                    transition_type=RandomWalk(environment_name='',\n",
+       "                                               movement_var=0.25,\n",
+       "                                               movement_mean=0.0,\n",
+       "                                               use_diffusion=False))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" + ], + "text/plain": [ + "SortedSpikesDecoder(environment=Environment(environment_name='',\n", + " place_bin_size=0.5,\n", + " track_graph=,\n", + " edge_order=[(3, 2), (0, 1), (1, 2),\n", + " (5, 4), (4, 2)],\n", + " edge_spacing=[16, 0, 16, 0],\n", + " is_track_interior=None,\n", + " position_range=None,\n", + " infer_track_interior=True,\n", + " fill_holes=False,\n", + " dilate=False,\n", + " bin_count_threshold=0),\n", + " infer_track_interior=True,\n", + " initial_conditions_type=UniformInitialConditions(),\n", + " sorted_spikes_algorithm='spiking_likelihood_kde',\n", + " sorted_spikes_algorithm_params={'block_size': None,\n", + " 'position_std': 6.0,\n", + " 'use_diffusion': False},\n", + " transition_type=RandomWalk(environment_name='',\n", + " movement_var=0.25,\n", + " movement_mean=0.0,\n", + " use_diffusion=False))" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(\"Learning model parameters\")\n", + "decoder.fit(linearized_positions.linear_position, binned_spikes_times)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saving model to ../../../../datasets/decoder_data/sorted_spike_model.pkl\n" + ] + } + ], + "source": [ + "model_filename = \"../../../../datasets/decoder_data/sorted_spike_model.pkl\"\n", + "\n", + "print(f\"Saving model to {model_filename}\")\n", + "\n", + "results = dict(decoder=decoder, linearized_positions=linearized_positions,\n", + " binned_spikes_times=binned_spikes_times, Fs=Fs)\n", + "\n", + "with open(model_filename, \"wb\") as f:\n", + " pkl.dump(results, f)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Decode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "decoding_start_secs = 0\n", + "decoding_duration_secs = 100\n", + "decoding_filename = \"../../../../datasets/decoder_data/sorted_spike_decoding_results.pkl\"" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "with open(model_filename, \"rb\") as f:\n", + " model_results = pkl.load(f)\n", + " \n", + "decoder = model_results[\"decoder\"]\n", + "Fs = model_results[\"Fs\"]\n", + "binned_spikes_times = model_results[\"binned_spikes_times\"]\n", + "linearized_positions = model_results[\"linearized_positions\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding positions from spikes\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/104 [00:00 Date: Thu, 7 Nov 2024 14:09:20 +0000 Subject: [PATCH 16/34] Updated to remove cuda path declaration and copy of decoder class --- .../notebooks/ClusterlessDecoder.ipynb | 167 +----------------- 1 file changed, 2 insertions(+), 165 deletions(-) diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb index 3976807..d1459a7 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb @@ -22,11 +22,6 @@ } ], "source": [ - "!export CUDA_PATH=\"/usr/local/cuda\"\n", - "# %set_env 
CUDA_PATH=\"/usr/local/cuda\"\n", - "import os\n", - "os.environ[\"CUDA_PATH\"] = \"/usr/local/cuda\"\n", - "\n", "import numpy as np\n", "import sys\n", "import pandas as pd\n", @@ -39,35 +34,14 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'/usr/local/cuda'" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from cupy.cuda import get_cuda_path\n", - "get_cuda_path()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "positions_filename = \"../../../../datasets/decoder_data/position_info.pkl\"\n", "spikes_filename = \"../../../../datasets/decoder_data/clusterless_spike_times.pkl\"\n", "features_filename = \"../../../../datasets/decoder_data/clusterless_spike_features.pkl\"\n", - "model_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int.pkl\"\n", + "model_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoder.pkl\"\n", "decoding_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoding_results.pkl\"" ] }, @@ -1219,143 +1193,6 @@ " pkl.dump(results, f)" ] }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "import replay_trajectory_classification as rtc\n", - "from replay_trajectory_classification.core import scaled_likelihood, get_centers\n", - "from replay_trajectory_classification.likelihoods import _SORTED_SPIKES_ALGORITHMS, _ClUSTERLESS_ALGORITHMS\n", - "from replay_trajectory_classification.likelihoods.spiking_likelihood_kde import combined_likelihood, poisson_log_likelihood\n", - "from replay_trajectory_classification.likelihoods.multiunit_likelihood import estimate_position_distance, estimate_log_joint_mark_intensity\n", - "\n", - "class ClusterlessSpikeDecoder:\n", - " def __init__(self, model_dict: dict):\n", - " self.decoder = model_dict[\"decoder\"]\n", - " self.Fs = model_dict[\"Fs\"]\n", - " self.features = model_dict[\"features\"]\n", - "\n", - " encoding_model = self.decoder.encoding_model_\n", - " self.encoding_marks = encoding_model[\"encoding_marks\"]\n", - " self.mark_std = encoding_model[\"mark_std\"]\n", - " self.encoding_positions = encoding_model[\"encoding_positions\"]\n", - " self.position_std = encoding_model[\"position_std\"]\n", - " self.occupancy = encoding_model[\"occupancy\"]\n", - " self.mean_rates = encoding_model[\"mean_rates\"]\n", - " self.summed_ground_process_intensity = encoding_model[\"summed_ground_process_intensity\"]\n", - " self.block_size = encoding_model[\"block_size\"]\n", - " self.bin_diffusion_distances = encoding_model[\"bin_diffusion_distances\"]\n", - " self.edges = encoding_model[\"edges\"]\n", - "\n", - " self.place_bin_centers = self.decoder.environment.place_bin_centers_\n", - " self.is_track_interior = self.decoder.environment.is_track_interior_.ravel(order=\"F\")\n", - " self.st_interior_ind = np.ix_(self.is_track_interior, self.is_track_interior)\n", - " self.interior_place_bin_centers = np.asarray(\n", - " self.place_bin_centers[self.is_track_interior], dtype=np.float32\n", - " )\n", - " self.interior_occupancy = np.asarray(\n", - " self.occupancy[self.is_track_interior], dtype=np.float32\n", - " )\n", - " self.n_position_bins = self.is_track_interior.shape[0]\n", - " self.n_track_bins = self.is_track_interior.sum()\n", - "\n", - " self.initial_conditions = 
self.decoder.initial_conditions_[self.is_track_interior].astype(float)\n", - " self.state_transition = self.decoder.state_transition_[self.st_interior_ind].astype(float)\n", - "\n", - " self.posterior = None\n", - " super().__init__()\n", - " \n", - " def decode(self,\n", - " multiunits: np.ndarray):\n", - "\n", - " log_likelihood = -self.summed_ground_process_intensity * np.ones((1,1), dtype=np.float32)\n", - "\n", - " if not np.isnan(multiunits).all():\n", - " multiunit_idxs = np.where(~np.isnan(multiunits).all(axis=0))[0]\n", - " decoding_marks = np.asarray(\n", - " multiunits[:, multiunit_idxs], dtype=np.float32\n", - " )[np.newaxis]\n", - " log_joint_mark_intensity = np.zeros(\n", - " (1, self.n_track_bins), dtype=np.float32\n", - " )\n", - " position_distance = estimate_position_distance(\n", - " self.interior_place_bin_centers,\n", - " np.asarray(enc_pos, dtype=np.float32),\n", - " self.position_std,\n", - " ).astype(np.float32)\n", - " \n", - " for multiunit, enc_marks, enc_pos, mean_rate in zip(\n", - " multiunits.T,\n", - " self.encoding_marks,\n", - " self.encoding_positions,\n", - " self.mean_rates,\n", - " ):\n", - " is_spike = np.any(~np.isnan(multiunit))\n", - " if is_spike:\n", - " decoding_marks = np.asarray(\n", - " multiunit, dtype=np.float32\n", - " )[np.newaxis]\n", - " log_joint_mark_intensity = np.zeros(\n", - " (1, self.n_track_bins), dtype=np.float32\n", - " )\n", - " position_distance = estimate_position_distance(\n", - " self.interior_place_bin_centers,\n", - " np.asarray(enc_pos, dtype=np.float32),\n", - " self.position_std,\n", - " ).astype(np.float32)\n", - " log_joint_mark_intensity[0] = estimate_log_joint_mark_intensity(\n", - " decoding_marks,\n", - " enc_marks,\n", - " self.mark_std,\n", - " self.interior_occupancy,\n", - " mean_rate,\n", - " position_distance=position_distance,\n", - " )\n", - " log_likelihood[:, self.is_track_interior] += np.nan_to_num(\n", - " log_joint_mark_intensity\n", - " )\n", - " \n", - " log_likelihood[:, ~self.is_track_interior] = np.nan\n", - " likelihood = scaled_likelihood(log_likelihood)\n", - "\n", - " if self.posterior is None:\n", - " self.posterior = np.full((1, self.n_position_bins), np.nan, dtype=float)\n", - " self.posterior[0, self.is_track_interior] = self.initial_conditions * likelihood[0, self.is_track_interior]\n", - "\n", - " else:\n", - " self.posterior[0, self.is_track_interior] = self.state_transition.T @ self.posterior[0, self.is_track_interior] * likelihood[0, self.is_track_interior]\n", - "\n", - " norm = np.nansum(self.posterior[0])\n", - " self.posterior[0] /= norm\n", - "\n", - " return self.posterior" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "ename": "UnboundLocalError", - "evalue": "cannot access local variable 'enc_pos' where it is not associated with a value", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mUnboundLocalError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m decoder \u001b[38;5;241m=\u001b[39m ClusterlessSpikeDecoder(model_results)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mdecoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[11], line 58\u001b[0m, in 
\u001b[0;36mClusterlessSpikeDecoder.decode\u001b[0;34m(self, multiunits)\u001b[0m\n\u001b[1;32m 50\u001b[0m decoding_marks \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39masarray(\n\u001b[1;32m 51\u001b[0m multiunits[:, multiunit_idxs], dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat32\n\u001b[1;32m 52\u001b[0m )[np\u001b[38;5;241m.\u001b[39mnewaxis]\n\u001b[1;32m 53\u001b[0m log_joint_mark_intensity \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mzeros(\n\u001b[1;32m 54\u001b[0m (\u001b[38;5;241m1\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_track_bins), dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat32\n\u001b[1;32m 55\u001b[0m )\n\u001b[1;32m 56\u001b[0m position_distance \u001b[38;5;241m=\u001b[39m estimate_position_distance(\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minterior_place_bin_centers,\n\u001b[0;32m---> 58\u001b[0m np\u001b[38;5;241m.\u001b[39masarray(\u001b[43menc_pos\u001b[49m, dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat32),\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mposition_std,\n\u001b[1;32m 60\u001b[0m )\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m multiunit, enc_marks, enc_pos, mean_rate \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(\n\u001b[1;32m 63\u001b[0m multiunits\u001b[38;5;241m.\u001b[39mT,\n\u001b[1;32m 64\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mencoding_marks,\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mencoding_positions,\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmean_rates,\n\u001b[1;32m 67\u001b[0m ):\n\u001b[1;32m 68\u001b[0m is_spike \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39many(\u001b[38;5;241m~\u001b[39mnp\u001b[38;5;241m.\u001b[39misnan(multiunit))\n", - "\u001b[0;31mUnboundLocalError\u001b[0m: cannot access local variable 'enc_pos' where it is not associated with a value" - ] - } - ], - "source": [ - "decoder = ClusterlessSpikeDecoder(model_results)\n", - "decoder.decode(features[0])" - ] - }, { "cell_type": "code", "execution_count": null, From 97b753137da621e4428909b055fbd54c99a70610 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 7 Nov 2024 14:11:30 +0000 Subject: [PATCH 17/34] Updated to have filename declarations at top --- .../notebooks/SortedSpikesDecoder.ipynb | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb index 6979c7e..6bdc182 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb @@ -34,12 +34,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "positions_filename = \"../../../../datasets/decoder_data/position_info.pkl\"\n", - "spikes_filename = \"../../../../datasets/decoder_data/sorted_spike_times.pkl\"" + "spikes_filename = \"../../../../datasets/decoder_data/sorted_spike_times.pkl\"\n", + "model_filename = \"../../../../datasets/decoder_data/sorted_spike_decoder.pkl\"\n", + "decoding_filename = 
\"../../../../datasets/decoder_data/sorted_spike_decoding_results.pkl\"" ] }, { @@ -1074,7 +1076,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1086,8 +1088,6 @@ } ], "source": [ - "model_filename = \"../../../../datasets/decoder_data/sorted_spike_model.pkl\"\n", - "\n", "print(f\"Saving model to {model_filename}\")\n", "\n", "results = dict(decoder=decoder, linearized_positions=linearized_positions,\n", @@ -1111,8 +1111,7 @@ "outputs": [], "source": [ "decoding_start_secs = 0\n", - "decoding_duration_secs = 100\n", - "decoding_filename = \"../../../../datasets/decoder_data/sorted_spike_decoding_results.pkl\"" + "decoding_duration_secs = 100" ] }, { From 0b53a4a72c630da881ee30ac66655ea9a19a4522 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 7 Nov 2024 14:52:18 +0000 Subject: [PATCH 18/34] Updated --- .../ClusterlessSpikes.bonsai | 479 ++++++++++-------- .../SortedSpikes.bonsai | 418 +++++++++------ .../notebooks/ClusterlessDecoder.ipynb | 2 +- 3 files changed, 544 insertions(+), 355 deletions(-) diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai index decf1d0..23426a9 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai @@ -1,55 +1,15 @@  - - - - - - 640 - 480 - On - false - Black - DepthBufferBit ColorBufferBit - true - - Resizable - Minimized - Primary - 200 - - - - - 8 - 8 - 8 - 8 - - 16 - 0 - 0 - - 0 - 0 - 0 - 0 - - 2 - false - - - @@ -87,7 +47,7 @@ from decoder import * - decoder = ClusterlessSpikeDecoder.load("../../../datasets/decoder_data/clusterless_spike_decoder.pkl") + decoder = ClusterlessSpikeDecoder.load("../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int.pkl") @@ -125,8 +85,14 @@ from decoder import * IterateData + + + - + + PT0S + PT0.01S + @@ -156,150 +122,296 @@ from decoder import * - - - - + + + + + Data - - Data - - - Features - - - - spikes - - - - - decoder.decode(spikes) - - - - - - - OnlineDecoderResults - - - Data - - - OnlinePosterior - - - Data - - - Position - - - Convert.ToDouble(it.ToString()) - - - Position - - - Data - - - Decoding - - - - - - OfflineDecoderResults - - - Data - - - OfflinePosterior - - - Data - - - PositionBins - - - - - - PositionBins - - - PositionBins - - - OnlineDecoderResults - - - MapEstimate - - - - - - - OnlinePositionEstimate - - - PositionBins - - - OfflineDecoderResults - - - MapEstimate - - - - - - - OfflinePositionEstimate - - Visualizer + RunDecoder - OnlinePositionEstimate + Data - - - OfflinePositionEstimate + + Features + + + + spikes + + + + + decoder.decode(spikes) + + + + + + + OnlineDecoderResults + + + Data + + + OnlinePosterior - + Data + + + Position + + + Convert.ToDouble(it.ToString()) + + Position - - OnlinePosterior + Data - - + + Decoding + + + + + + OfflineDecoderResults + + + Data + + OfflinePosterior + + Data + + + PositionBins + + + + + + PositionBins + + + PositionBins + + + OnlineDecoderResults + + + MapEstimate + + + + + + + OnlinePositionEstimate + + + PositionBins + + + OfflineDecoderResults + + + MapEstimate + + + + + + + OfflinePositionEstimate + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + VisualizeDecoder + + + + PositionEstimates + + + + true + true + Online Prediction + + + + true + true + Offline Acausal Prediction + + + + true + 
true + True Position + + + + OnlinePositionEstimate + + + + OfflinePositionEstimate + + + + Position + + + + 3 + 2 + + + + Percent + 10 + + + Percent + 90 + + + + + + + + + + + + + + + + + + + + + + + + + + Posteriors + + + + true + true + Online Posterior + + + + true + true + Offline Acausal Posterior + + + + OnlinePosterior + + + + OfflinePosterior + + + + true + true + 2 + 2 + + + + Percent + 10 + + + Percent + 90 + + + + + + + + + + + + + + + + + + + true true - 3 + 1 2 @@ -309,55 +421,20 @@ from decoder import * - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai index 1d1d834..13efb12 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai @@ -6,6 +6,7 @@ xmlns:scr="clr-namespace:Bonsai.Scripting.Expressions;assembly=Bonsai.Scripting.Expressions" xmlns:p1="clr-namespace:;assembly=Extensions" xmlns:gui="clr-namespace:Bonsai.Gui;assembly=Bonsai.Gui" + xmlns:viz="clr-namespace:Bonsai.Design.Visualizers;assembly=Bonsai.Design.Visualizers" xmlns="https://bonsai-rx.org/2018/workflow"> @@ -38,7 +39,7 @@ from decoder import *
- ModelLoader + LoadDecoder @@ -46,7 +47,7 @@ from decoder import * - model = ModelLoader.load_sorted_spike_decoder() + model = SortedSpikeDecoder.load() @@ -58,7 +59,7 @@ from decoder import * - DataLoader + LoadData @@ -90,7 +91,7 @@ from decoder import * PT0S - PT0.1S + PT0.01S @@ -136,136 +137,281 @@ from decoder import * Data - - Data - - - Spikes - - - - spikes - - - - - model.decode(spikes) - - - - - - - OnlineDecoderResults - - - Data - - - OnlinePosterior - - - Data - - - Position - - - Convert.ToDouble(it.ToString()) - - - Position - - - Data - - - Decoding - - - - - - OfflineDecoderResults - - - Data - - - OfflinePosterior - - - Data - - - PositionBins - - - - - - PositionBins - - - PositionBins - - - OnlineDecoderResults - - - MapEstimate - - - - - - - OnlinePositionEstimate - - - PositionBins - - - OfflineDecoderResults - - - MapEstimate - - - - - - - OfflinePositionEstimate - - Visualizer + RunDecoder - OnlinePositionEstimate + Data - - - OfflinePositionEstimate + + Spikes + + + + spikes + + + + + model.decode(spikes) + + + + + + + OnlineDecoderResults + + + Data + + + OnlinePosterior - + Data + + + Position + + + Convert.ToDouble(it.ToString()) + + Position - - OnlinePosterior + Data - - + + Decoding + + + + + + OfflineDecoderResults + + + Data + + OfflinePosterior + + Data + + + PositionBins + + + + + + PositionBins + + + PositionBins + + + OnlineDecoderResults + + + MapEstimate + + + + + + + OnlinePositionEstimate + + + PositionBins + + + OfflineDecoderResults + + + MapEstimate + + + + + + + OfflinePositionEstimate + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + VisualizeDecoder + + + + PositionEstimates + + + + true + true + Online Prediction + + + + true + true + Offline Acausal Prediction + + + + true + true + True Position + + + + OnlinePositionEstimate + + + + OfflinePositionEstimate + + + + Position + + + + 3 + 2 + + + + Percent + 10 + + + Percent + 90 + + + + + + + + + + + + + + + + + + + + + + + + + + Posteriors + + + + true + true + Online Posterior + + + + true + true + Offline Acausal Posterior + + + + OnlinePosterior + + + + OfflinePosterior + + + + true + true + 2 + 2 + + + + Percent + 10 + + + Percent + 90 + + + + + + + + + + + + + + + + + + + true true - 3 + 1 2 @@ -275,16 +421,10 @@ from decoder import * - + - + - - - - - - @@ -295,34 +435,6 @@ from decoder import * - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb index d1459a7..2b57902 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb @@ -41,7 +41,7 @@ "positions_filename = \"../../../../datasets/decoder_data/position_info.pkl\"\n", "spikes_filename = \"../../../../datasets/decoder_data/clusterless_spike_times.pkl\"\n", "features_filename = \"../../../../datasets/decoder_data/clusterless_spike_features.pkl\"\n", - "model_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoder.pkl\"\n", + "model_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int.pkl\"\n", "decoding_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoding_results.pkl\"" ] }, From 9bc727f77a62c3326e2c267549490cab772f8f98 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 8 Nov 
2024 15:07:47 +0000
Subject: [PATCH 19/34] Added README to example

---
 .../PositionDecodingFromHippocampus/README.md | 43 ++++++++++++++++---
 1 file changed, 38 insertions(+), 5 deletions(-)

diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md b/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md
index 843caed..0799111 100644
--- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md
@@ -1,13 +1,46 @@
 # Position Decoding from Hippocampal Sorted Spikes
 
-In the following example, you can find how to use the decoder from [here](https://github.com/Eden-Kramer-Lab/replay_trajectory_classification/tree/master?tab=readme-ov-file) to decode an animals position from sorted hippocampal units sampled with tetrodes.
+In the following example, you can find how to use the decoder from [here](https://github.com/Eden-Kramer-Lab/replay_trajectory_classification/tree/master?tab=readme-ov-file) to decode an animal's position from hippocampal activity.
 
-### Dataset
+## Dataset
 
-We thank Eric Denovellis for sharing his data and for his help with the decoder. Please cite his work: Eric L Denovellis, Anna K Gillespie, Michael E Coulter, Marielena Sosa, Jason E Chung, Uri T Eden, Loren M Frank (2021). Hippocampal replay of experience at real-world speeds. eLife 10:e64505.
+We thank Eric Denovellis for sharing his data and for his help with the decoder. If you use this example dataset, please cite: Eric L Denovellis, Anna K Gillespie, Michael E Coulter, Marielena Sosa, Jason E Chung, Uri T Eden, Loren M Frank (2021). Hippocampal replay of experience at real-world speeds. eLife 10:e64505.
 
-You can download the data [here](https://drive.google.com/file/d/1ddRC28w0U4_q3pcGfY-1vPHjO9mjEaJb/view?usp=sharing). The workflow expects the zip file to be extracted into the `datasets/decoder_data` folder. The workflow also expects the files to be renamed to just `sorted_spike_times.pkl` and `position_info.pkl`.
+## Installation
 
 ### Python
 
-You need to install the [replay_trajectory_classification](https://github.com/Eden-Kramer-Lab/replay_trajectory_classification) package into your python virtual environment.
+You can bootstrap the Python environment by running:
+
+```
+cd \path\to\examples\NeuralDecoding\PositionDecodingFromHippocampus
+python -m venv .venv
+.\.venv\Scripts\activate
+pip install -r requirements.txt
+```
+
+You can test whether the installation was successful by launching Python and running `import replay_trajectory_classification`.
+
+### Bonsai
+
+You can bootstrap the Bonsai environment using:
+
+```
+cd \path\to\examples\NeuralDecoding\PositionDecodingFromHippocampus
+dotnet new bonsaienv --allow-scripts yes
+```
+
+Alternatively, you can copy the `.bonsai\Bonsai.config` file into your Bonsai installation folder. You can test if it worked by opening Bonsai and searching for the `CreateRuntime` node, which should appear in the toolbox.
+
+### Training the decoder offline
+
+You first need to train the decoder model and save it to disk. Open the `notebooks` folder and select either `SortedSpikeDecoder.ipynb` or `ClusterlessDecoder.ipynb`, depending on which model type you would like to use. Run the notebook. Once completed, this should create two new files: 1) `datasets\decoder_data\[ModelType]_decoder.pkl` for the trained decoder model; and 2) `datasets\decoder_data\[ModelType]_decoding_results.pkl` for the predictions.
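+
+As a quick sanity check, you can confirm the trained model loads back from disk before moving on. This is a minimal sketch, assuming the pickle path below (relative to this example folder) and the dictionary keys the notebooks use when saving results:
+
+```python
+import pickle as pkl
+
+# Hypothetical path for the sorted-spikes model; swap in the clusterless
+# filename if you ran ClusterlessDecoder.ipynb instead.
+model_filename = "../../../datasets/decoder_data/sorted_spike_decoder.pkl"
+
+with open(model_filename, "rb") as f:
+    results = pkl.load(f)
+
+# The notebooks save a dict; expect keys such as 'decoder', 'Fs',
+# and (for the sorted-spikes model) 'linearized_positions' and 'binned_spikes_times'.
+print(results.keys())
+```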
+ +### Running the decoder online + +Launch the Bonsai.exe file inside of the .bonsai folder and open the workflow corresponding to the model type you used in the previous step. Press the `Start Workflow` button. The workflow may take some time to initialize and load the data. Once the workflow is running, open the `VisualizeDecoder` node to bring up the following: +1. Online prediction of position +2. Offline acausal prediction of position +3. True position +4. Latest online posterior distribution +5. Latest offline acausal posterior distribution From c022daafc1b3081940394021f634d8a49e1d9379 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 8 Nov 2024 15:07:59 +0000 Subject: [PATCH 20/34] Added requirements.txt file --- .../requirements.txt | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/requirements.txt diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/requirements.txt b/examples/NeuralDecoding/PositionDecodingFromHippocampus/requirements.txt new file mode 100644 index 0000000..98692e0 --- /dev/null +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/requirements.txt @@ -0,0 +1,77 @@ +asttokens==2.4.1 +click==8.1.7 +cloudpickle==3.1.0 +comm==0.2.2 +contourpy==1.3.0 +cupy-cuda12x==13.3.0 +cycler==0.12.1 +dask==2024.10.0 +debugpy==1.8.7 +decorator==5.1.1 +distributed==2024.10.0 +executing==2.1.0 +fastrlock==0.8.2 +fonttools==4.54.1 +fsspec==2024.10.0 +imageio==2.36.0 +ipykernel==6.29.5 +ipython==8.29.0 +jedi==0.19.1 +Jinja2==3.1.4 +joblib==1.4.2 +jupyter_client==8.6.3 +jupyter_core==5.7.2 +kiwisolver==1.4.7 +lazy_loader==0.4 +llvmlite==0.43.0 +locket==1.0.0 +MarkupSafe==3.0.2 +matplotlib==3.9.2 +matplotlib-inline==0.1.7 +msgpack==1.1.0 +nest-asyncio==1.6.0 +networkx==3.4.2 +numba==0.60.0 +numpy==2.0.2 +packaging==24.1 +pandas==2.2.3 +parso==0.8.4 +partd==1.4.2 +patsy==0.5.6 +pexpect==4.9.0 +pillow==11.0.0 +platformdirs==4.3.6 +prompt_toolkit==3.0.48 +psutil==6.1.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +Pygments==2.18.0 +pyparsing==3.2.0 +python-dateutil==2.9.0.post0 +pytz==2024.2 +PyYAML==6.0.2 +pyzmq==26.2.0 +regularized-glm==1.0.2 +-e git+https://github.com/Eden-Kramer-Lab/replay_trajectory_classification.git@a61ba5c52f368ac28329806c4e081dc51f5f03fb#egg=replay_trajectory_classification +scikit-image==0.24.0 +scikit-learn==1.5.2 +scipy==1.14.1 +seaborn==0.13.2 +setuptools==75.3.0 +six==1.16.0 +sortedcontainers==2.4.0 +stack-data==0.6.3 +statsmodels==0.14.4 +tblib==3.0.0 +threadpoolctl==3.5.0 +tifffile==2024.9.20 +toolz==1.0.0 +tornado==6.4.1 +tqdm==4.66.6 +track_linearization==2.3.2 +traitlets==5.14.3 +tzdata==2024.2 +urllib3==2.2.3 +wcwidth==0.2.13 +xarray==2024.10.0 +zict==3.0.0 From 8a775f7e1d104d8fd876aaa27bd7dd12f62eda26 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 8 Nov 2024 15:08:19 +0000 Subject: [PATCH 21/34] Updated notebook for 50Hz down sampling --- .../ClusterlessSpikes.bonsai | 3 +- .../notebooks/ClusterlessDecoder.ipynb | 106 ++++++++++++------ 2 files changed, 75 insertions(+), 34 deletions(-) diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai index 23426a9..f323a38 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai @@ -27,6 +27,7 @@ import sys import os sys.path.append(os.getcwd()) 
+sys.path.append(os.path.join(os.getcwd(), ".venv/Scripts")) from decoder import * @@ -47,7 +48,7 @@ from decoder import * - decoder = ClusterlessSpikeDecoder.load("../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int.pkl") + decoder = ClusterlessSpikeDecoder.load("../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int_50Hz.pkl") diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb index 2b57902..f47b1d1 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -34,33 +34,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "positions_filename = \"../../../../datasets/decoder_data/position_info.pkl\"\n", "spikes_filename = \"../../../../datasets/decoder_data/clusterless_spike_times.pkl\"\n", "features_filename = \"../../../../datasets/decoder_data/clusterless_spike_features.pkl\"\n", - "model_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int.pkl\"\n", - "decoding_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoding_results.pkl\"" + "model_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int_50Hz.pkl\"\n", + "decoding_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoding_results_50Hz.pkl\"" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "positions_df = pd.read_pickle(positions_filename)\n", - "timestamps = positions_df.index.to_numpy()\n", - "dt = timestamps[1] - timestamps[0]\n", + "time_start = positions_df.index.to_numpy()[0]\n", + "time_end = positions_df.index.to_numpy()[-1]\n", + "dt = 0.02\n", "Fs = 1.0 / dt\n", - "spikes_bins = np.append(timestamps-dt/2, timestamps[-1]+dt/2)" + "spikes_bins = np.arange(time_start - dt, time_end + dt, dt)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -475,7 +476,7 @@ "[452325 rows x 33 columns]" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -486,7 +487,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -512,7 +513,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -531,7 +532,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -541,22 +542,26 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "with open(spikes_filename, \"rb\") as f:\n", " clusterless_spike_times = pkl.load(f)\n", "\n", - "features = np.ones((len(timestamps), len(clusterless_spike_features[0][0]), len(clusterless_spike_times)), dtype=float) * np.nan\n", + "features = np.ones((len(spikes_bins - 1), len(clusterless_spike_features[0][0]), len(clusterless_spike_times)), dtype=float) * np.nan\n", "for n in range(len(clusterless_spike_times)):\n", " in_spikes_window = np.digitize(clusterless_spike_times[n], 
spikes_bins)\n", - " features[in_spikes_window, :, n] = clusterless_spike_features[n]" + " features[in_spikes_window, :, n] = np.nanmax([features[in_spikes_window, :, n], clusterless_spike_features[n]], axis=0)\n", + "\n", + "linear_position = np.ones(len(spikes_bins - 1)) * np.nan\n", + "in_position_window = np.digitize(positions_df.index, spikes_bins)\n", + "linear_position[in_position_window] = linearized_positions.linear_position" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -579,7 +584,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -1009,7 +1014,7 @@ " 'position_std': 6.0},\n", " environment=Environment(environment_name='',\n", " place_bin_size=0.5,\n", - " track_graph=<networkx.classes.graph.Graph object at 0x73b71c0ee570>,\n", + " track_graph=<networkx.classes.graph.Graph object at 0x77606450df10>,\n", " edge_order=[(3, 2), (0, 1), (1, 2),\n", " (5, 4), (4, 2)],\n", " edge_spacing=[16, 0, 16, 0],\n", @@ -1029,7 +1034,7 @@ " 'position_std': 6.0},\n", " environment=Environment(environment_name='',\n", " place_bin_size=0.5,\n", - " track_graph=<networkx.classes.graph.Graph object at 0x73b71c0ee570>,\n", + " track_graph=<networkx.classes.graph.Graph object at 0x77606450df10>,\n", " edge_order=[(3, 2), (0, 1), (1, 2),\n", " (5, 4), (4, 2)],\n", " edge_spacing=[16, 0, 16, 0],\n", @@ -1052,7 +1057,7 @@ " 'position_std': 6.0},\n", " environment=Environment(environment_name='',\n", " place_bin_size=0.5,\n", - " track_graph=,\n", + " track_graph=,\n", " edge_order=[(3, 2), (0, 1), (1, 2),\n", " (5, 4), (4, 2)],\n", " edge_spacing=[16, 0, 16, 0],\n", @@ -1070,33 +1075,33 @@ " use_diffusion=False))" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Learning model parameters\")\n", - "decoder.fit(linearized_positions.linear_position, features)" + "decoder.fit(linear_position, features)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Saving model to ../../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int.pkl\n" + "Saving model to ../../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int_50Hz.pkl\n" ] } ], "source": [ "print(f\"Saving model to {model_filename}\")\n", "\n", - "results = dict(decoder=decoder, linearized_positions=linearized_positions,\n", + "results = dict(decoder=decoder, linear_position=linear_position,\n", " clusterless_spike_times=clusterless_spike_times, features=features, Fs=Fs)\n", "\n", "with open(model_filename, \"wb\") as f:\n", @@ -1112,7 +1117,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -1133,7 +1138,7 @@ "Fs = model_results[\"Fs\"]\n", "clusterless_spike_times = model_results[\"clusterless_spike_times\"]\n", "features = model_results[\"features\"]\n", - "linearized_positions = model_results[\"linearized_positions\"]" + "linear_position = model_results[\"linear_position\"]" ] }, { @@ -1152,7 +1157,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "n_electrodes: 100%|██████████| 28/28 [00:10<00:00, 2.55it/s]\n", + "n_electrodes: 100%|██████████| 28/28 [00:02<00:00, 9.94it/s]\n", "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/core.py:84: NumbaPerformanceWarning: '@' is faster on contiguous 
arrays, called on (Array(float64, 2, 'F', False, aligned=True), Array(float64, 1, 'A', False, aligned=True))\n", " posterior[k] = state_transition.T @ posterior[k - 1] * likelihood[k]\n", "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/core.py:116: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'F', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))\n", @@ -1165,20 +1170,20 @@ "decoding_start_samples = int(decoding_start_secs * Fs)\n", "decoding_duration_samples = int(decoding_duration_secs * Fs)\n", "time_ind = slice(decoding_start_samples, decoding_start_samples + decoding_duration_samples)\n", - "time = np.arange(linearized_positions.linear_position.size) / Fs\n", + "time = np.arange(linear_position.size) / Fs\n", "decoding_results = decoder.predict(features[time_ind], time=time[time_ind])" ] }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Saving decoding results to ../../../../datasets/decoder_data/clusterless_decoding_results.pkl\n" + "Saving decoding results to ../../../../datasets/decoder_data/clusterless_spike_decoding_results_50Hz.pkl\n" ] } ], @@ -1186,13 +1191,48 @@ "print(f\"Saving decoding results to {decoding_filename}\")\n", "\n", "results = dict(decoding_results=decoding_results, time=time[time_ind],\n", - " linearized_positions=linearized_positions.iloc[time_ind],\n", + " linear_position=linear_position[time_ind],\n", " features=features[time_ind])\n", "\n", "with open(decoding_filename, \"wb\") as f:\n", " pkl.dump(results, f)" ] }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.graph_objects as go" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "'linear_position'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[21], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m decoding_results \u001b[38;5;241m=\u001b[39m load_res[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdecoding_results\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 5\u001b[0m time \u001b[38;5;241m=\u001b[39m load_res[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtime\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m----> 6\u001b[0m linear_position \u001b[38;5;241m=\u001b[39m \u001b[43mload_res\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlinear_position\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n", + "\u001b[0;31mKeyError\u001b[0m: 'linear_position'" + ] + } + ], + "source": [ + "with open(decoding_filename, \"rb\") as f:\n", + " load_res = pkl.load(f)\n", + "\n", + "decoding_results = load_res[\"decoding_results\"]\n", + "time = load_res[\"time\"]\n", + "linear_position = load_res[\"linear_position\"]" + ] + }, { "cell_type": "code", "execution_count": null, From 031c46d8791d6a368906471fc30ba6cfcafd9ba1 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 14 Nov 2024 12:27:55 +0000 Subject: [PATCH 22/34] Added cell for plotting --- .../notebooks/ClusterlessDecoder.ipynb | 1038 +---------------- 1 file changed, 48 insertions(+), 990 deletions(-) diff --git 
a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb index f47b1d1..01f2ff9 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb @@ -9,18 +9,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/likelihoods/multiunit_likelihood.py:8: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from tqdm.autonotebook import tqdm\n" - ] - } - ], + "outputs": [], "source": [ "import numpy as np\n", "import sys\n", @@ -34,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -47,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -61,433 +52,16 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
-        [removed HTML table render of positions_df: 452325 rows × 33 columns; tags lost in extraction, plain-text repr follows]
" - ], - "text/plain": [ - " nose_x nose_y nose_vel tailBase_x tailBase_y \\\n", - "time \n", - "22389.082875 6.302628 5.231174 45.320770 12.856064 5.547496 \n", - "22389.084875 6.755276 5.644503 45.579716 13.503448 5.689286 \n", - "22389.086875 7.207925 6.057833 45.838661 14.150832 5.831076 \n", - "22389.088875 7.660573 6.471162 46.097607 14.798215 5.972866 \n", - "22389.090875 8.113222 6.884492 46.356552 15.445599 6.114657 \n", - "... ... ... ... ... ... \n", - "23293.722875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", - "23293.724875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", - "23293.726875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", - "23293.728875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", - "23293.730875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", - "\n", - " tailBase_vel tailMid_x tailMid_y tailMid_vel tailTip_x ... \\\n", - "time ... \n", - "22389.082875 48.290909 11.367913 6.833021 56.508186 16.296467 ... \n", - "22389.084875 48.483751 12.142827 7.315345 56.751256 17.387420 ... \n", - "22389.086875 48.676593 12.917740 7.797669 56.994325 18.478373 ... \n", - "22389.088875 48.869435 13.692654 8.279993 57.237395 19.569326 ... \n", - "22389.090875 49.062277 14.467568 8.762317 57.480464 20.660279 ... \n", - "... ... ... ... ... ... ... \n", - "23293.722875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", - "23293.724875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", - "23293.726875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", - "23293.728875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", - "23293.730875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", - "\n", - " hindpawR_vel forelimb_mid_x forelimb_mid_y forelimb_vel \\\n", - "time \n", - "22389.082875 NaN 5.080671 3.215346 125.707005 \n", - "22389.084875 NaN 5.194684 3.302497 124.976602 \n", - "22389.086875 NaN 5.308697 3.389648 124.246200 \n", - "22389.088875 NaN 5.422709 3.476798 123.515797 \n", - "22389.090875 NaN 5.536722 3.563949 122.785394 \n", - "... ... ... ... ... \n", - "23293.722875 0.0 0.000000 0.000000 0.000000 \n", - "23293.724875 0.0 0.000000 0.000000 0.000000 \n", - "23293.726875 0.0 0.000000 0.000000 0.000000 \n", - "23293.728875 0.0 0.000000 0.000000 0.000000 \n", - "23293.730875 0.0 0.000000 0.000000 0.000000 \n", - "\n", - " body_dir linear_position track_segment_id \\\n", - "time \n", - "22389.082875 2.338472 162.285019 3 \n", - "22389.084875 2.347791 162.732014 3 \n", - "22389.086875 2.357110 163.179010 3 \n", - "22389.088875 2.366429 163.626005 3 \n", - "22389.090875 2.375748 164.073000 3 \n", - "... ... ... ... \n", - "23293.722875 -0.000057 161.447258 3 \n", - "23293.724875 -0.000057 161.447258 3 \n", - "23293.726875 -0.000057 161.447258 3 \n", - "23293.728875 -0.000057 161.447258 3 \n", - "23293.730875 -0.000057 161.447258 3 \n", - "\n", - " projected_x_position projected_y_position arm_name \n", - "time \n", - "22389.082875 6.345101 8.359380 Left Arm \n", - "22389.084875 6.792055 8.353311 Left Arm \n", - "22389.086875 7.239009 8.347243 Left Arm \n", - "22389.088875 7.685963 8.341174 Left Arm \n", - "22389.090875 8.132917 8.335106 Left Arm \n", - "... ... ... ... 
\n", - "23293.722875 5.507417 8.370753 Left Arm \n", - "23293.724875 5.507417 8.370753 Left Arm \n", - "23293.726875 5.507417 8.370753 Left Arm \n", - "23293.728875 5.507417 8.370753 Left Arm \n", - "23293.730875 5.507417 8.370753 Left Arm \n", - "\n", - "[452325 rows x 33 columns]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "positions_df" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -513,7 +87,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -532,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -542,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -552,7 +126,7 @@ "features = np.ones((len(spikes_bins - 1), len(clusterless_spike_features[0][0]), len(clusterless_spike_times)), dtype=float) * np.nan\n", "for n in range(len(clusterless_spike_times)):\n", " in_spikes_window = np.digitize(clusterless_spike_times[n], spikes_bins)\n", - " features[in_spikes_window, :, n] = np.nanmax([features[in_spikes_window, :, n], clusterless_spike_features[n]], axis=0)\n", + " features[in_spikes_window, :, n] = clusterless_spike_features[n]\n", "\n", "linear_position = np.ones(len(spikes_bins - 1)) * np.nan\n", "in_position_window = np.digitize(positions_df.index, spikes_bins)\n", @@ -561,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -584,502 +158,9 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Learning model parameters\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/continuous_state_transitions.py:24: RuntimeWarning: invalid value encountered in divide\n", - " x /= x.sum(axis=1, keepdims=True)\n" - ] - }, - { - "data": { - "text/html": [ - "
-    [deleted cell output: HTML widget rendering of the fitted ClusterlessDecoder; its parameters are repeated verbatim in the plain-text repr that follows]
" - ], - "text/plain": [ - "ClusterlessDecoder(clusterless_algorithm='multiunit_likelihood_integer_gpu',\n", - " clusterless_algorithm_params={'mark_std': 24.0,\n", - " 'position_std': 6.0},\n", - " environment=Environment(environment_name='',\n", - " place_bin_size=0.5,\n", - " track_graph=,\n", - " edge_order=[(3, 2), (0, 1), (1, 2),\n", - " (5, 4), (4, 2)],\n", - " edge_spacing=[16, 0, 16, 0],\n", - " is_track_interior=None,\n", - " position_range=None,\n", - " infer_track_interior=True,\n", - " fill_holes=False,\n", - " dilate=False,\n", - " bin_count_threshold=0),\n", - " infer_track_interior=True,\n", - " initial_conditions_type=UniformInitialConditions(),\n", - " transition_type=RandomWalk(environment_name='',\n", - " movement_var=0.25,\n", - " movement_mean=0.0,\n", - " use_diffusion=False))" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "print(\"Learning model parameters\")\n", "decoder.fit(linear_position, features)" @@ -1087,17 +168,9 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Saving model to ../../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int_50Hz.pkl\n" - ] - } - ], + "outputs": [], "source": [ "print(f\"Saving model to {model_filename}\")\n", "\n", @@ -1117,7 +190,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1127,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1143,28 +216,9 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Decoding positions from features\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "n_electrodes: 100%|██████████| 28/28 [00:02<00:00, 9.94it/s]\n", - "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/core.py:84: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'F', False, aligned=True), Array(float64, 1, 'A', False, aligned=True))\n", - " posterior[k] = state_transition.T @ posterior[k - 1] * likelihood[k]\n", - "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/core.py:116: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'F', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))\n", - " acausal_prior = state_transition.T @ causal_posterior[time_ind]\n" - ] - } - ], + "outputs": [], "source": [ "print(\"Decoding positions from features\")\n", "decoding_start_samples = int(decoding_start_secs * Fs)\n", @@ -1178,15 +232,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Saving decoding results to ../../../../datasets/decoder_data/clusterless_spike_decoding_results_50Hz.pkl\n" - ] - } - ], + "outputs": [], "source": [ "print(f\"Saving decoding results to {decoding_filename}\")\n", "\n", @@ -1200,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1209,21 +255,9 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, - "outputs": [ - { - 
"ename": "KeyError", - "evalue": "'linear_position'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[21], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m decoding_results \u001b[38;5;241m=\u001b[39m load_res[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdecoding_results\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 5\u001b[0m time \u001b[38;5;241m=\u001b[39m load_res[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtime\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m----> 6\u001b[0m linear_position \u001b[38;5;241m=\u001b[39m \u001b[43mload_res\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlinear_position\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n", - "\u001b[0;31mKeyError\u001b[0m: 'linear_position'" - ] - } - ], + "outputs": [], "source": [ "with open(decoding_filename, \"rb\") as f:\n", " load_res = pkl.load(f)\n", @@ -1233,6 +267,30 @@ "linear_position = load_res[\"linear_position\"]" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = go.Figure()\n", + "\n", + "trace = go.Heatmap(z=decoding_results.acausal_posterior.T,\n", + " x=decoding_results.acausal_posterior.time,\n", + " y=decoding_results.acausal_posterior.position,\n", + " zmin=0.00, zmax=0.05, showscale=False)\n", + "fig.add_trace(trace)\n", + "\n", + "trace = go.Scatter(x=time, y=linear_position,\n", + " mode=\"markers\", marker={\"color\": \"cyan\", \"size\": 5},\n", + " name=\"position\", showlegend=True)\n", + "fig.add_trace(trace)\n", + "\n", + "fig.update_xaxes(title=\"Time (sec)\")\n", + "fig.update_yaxes(title=\"Position (cm)\")\n", + "fig.update_coloraxes(showscale=False)" + ] + }, { "cell_type": "code", "execution_count": null, From c8ee1a4be2a150d69f6f3635e3d0e9e67a3adeab Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 14 Nov 2024 12:28:21 +0000 Subject: [PATCH 23/34] Added online FPS estimate --- .../ClusterlessSpikes.bonsai | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai index f323a38..05ad54a 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai @@ -138,6 +138,48 @@ from decoder import * Data + + + + + Interval.TotalMilliseconds + + + 1000/it + + + + 10 + 1 + + + + + + + Source1 + + + + 1 + + + + + + + + + + + + + + + + + + RunDecoder @@ -436,6 +478,11 @@ from decoder import * + + + + + \ No newline at end of file From b62ad61bcc4eab19b25779f5a260dd8cdb6c6806 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 14 Nov 2024 12:39:11 +0000 Subject: [PATCH 24/34] Updated to properly use 50 Hz --- .../ClusterlessSpikes.bonsai | 2 +- .../decoder/data_loader.py | 70 ++++--------------- 2 files changed, 16 insertions(+), 56 deletions(-) diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai index 05ad54a..5143e88 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai +++ 
b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai @@ -219,7 +219,7 @@ from decoder import * Position - Convert.ToDouble(it.ToString()) + it.ToString() == "nan" ? double.NaN : Convert.ToDouble(it.ToString()) Position diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_loader.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_loader.py index 18c229a..5aafca2 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_loader.py +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_loader.py @@ -75,63 +75,23 @@ def load_sorted_spike_data(cls, @classmethod def load_clusterless_spike_data(cls, dataset_path: str = "../../../datasets/decoder_data") -> dict: - - if len([file for file in os.listdir(dataset_path) if file == "position_info.pkl" or file == "clusterless_spike_times.pkl" or file == "clusterless_spike_features.pkl" or file == "clusterless_spike_decoding_results.pkl"]) != 4: - raise Exception("Dataset incorrect. Missing at least one of the following files: 'position_info.pkl', 'clusterless_spike_times.pkl', 'clusterless_spike_features.pkl', 'clusterless_spike_decoding_results.pkl'") - - position_data = pd.read_pickle(os.path.join(dataset_path, "position_info.pkl")) - position_index = position_data.index.to_numpy() - position_index = np.insert(position_index, 0, position_index[0] - (position_index[1] - position_index[0])) - position_data = position_data[["nose_x", "nose_y"]].to_numpy() - - node_positions = [(120.0, 100.0), - ( 5.0, 100.0), - ( 5.0, 55.0), - (120.0, 55.0), - ( 5.0, 8.5), - (120.0, 8.5), - ] - edges = [ - (3, 2), - (0, 1), - (1, 2), - (5, 4), - (4, 2), - ] - track_graph = rtc.make_track_graph(node_positions, edges) - - edge_order = [ - (3, 2), - (0, 1), - (1, 2), - (5, 4), - (4, 2), - ] - - edge_spacing = [16, 0, 16, 0] - - linearized_positions = tl.get_linearized_position(position_data, track_graph, edge_order=edge_order, edge_spacing=edge_spacing, use_HMM=False) - position_data = linearized_positions.linear_position - - with open(os.path.join(dataset_path, "clusterless_spike_times.pkl"), "rb") as f: - spike_times = pickle.load(f) - - with open(os.path.join(dataset_path, "clusterless_spike_features.pkl"), "rb") as f: - spike_features = pickle.load(f) - - features = np.ones((len(position_data), len(spike_features[0][0]), len(spike_times)), dtype=float) * np.nan - for n in range(len(spike_times)): - in_spikes_window = np.digitize(spike_times[n], position_index) - features[in_spikes_window, :, n] = spike_features[n] - - with open(os.path.join(dataset_path, "clusterless_spike_decoding_results.pkl"), "rb") as f: - results = pickle.load(f)["decoding_results"] - position_bins = results.position.to_numpy()[np.newaxis] - decoding_results = results.acausal_posterior.to_numpy()[:,np.newaxis] + + decoding_results_filename = os.path.join(dataset_path, "clusterless_spike_decoding_results_50Hz.pkl") + if not os.path.exists(decoding_results_filename): + raise Exception("Dataset incorrect. 
Missing 'clusterless_spike_decoding_results_50Hz.pkl'") + + with open(decoding_results_filename, "rb") as f: + results = pickle.load(f) + decoding_results = results["decoding_results"] + position_bins = decoding_results.position.to_numpy()[np.newaxis] + decoding_results = decoding_results.acausal_posterior.to_numpy()[:,np.newaxis] + features = results["features"] + time = results["time"] + linear_position = results["linear_position"] return { - "position_data": position_data, - "spike_times": spike_times, + "linear_position": linear_position, + "time": time, "features": features, "decoding_results": decoding_results, "position_bins": position_bins From 2c9cfda32bc0c12e82e7324ed0705c422e580269 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 15 Nov 2024 16:26:09 +0000 Subject: [PATCH 25/34] Remove exensions and use new Bonsai.ML.NeuralDecoding package --- .../Extensions.csproj | 1 - .../Extensions/PositionBins.cs | 19 --------- .../Extensions/Posterior.cs | 42 ------------------- 3 files changed, 62 deletions(-) delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions.csproj delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/PositionBins.cs delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/Posterior.cs diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions.csproj b/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions.csproj deleted file mode 100644 index e99fc36..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions.csproj +++ /dev/null @@ -1 +0,0 @@ -net472 \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/PositionBins.cs b/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/PositionBins.cs deleted file mode 100644 index f0e4a12..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/PositionBins.cs +++ /dev/null @@ -1,19 +0,0 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using Python.Runtime; -using Bonsai.ML.Python; - -[Combinator] -[Description("")] -[WorkflowElementCategory(ElementCategory.Transform)] -public class PositionBins -{ - public IObservable Process(IObservable source) - { - return source.Select(value => (double[])PythonHelper.ConvertPythonObjectToCSharp(value)); - } -} diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/Posterior.cs b/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/Posterior.cs deleted file mode 100644 index 21364bd..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/Extensions/Posterior.cs +++ /dev/null @@ -1,42 +0,0 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using Python.Runtime; -using Bonsai.ML.Python; - -[Combinator] -[Description("")] -[WorkflowElementCategory(ElementCategory.Transform)] -public class Posterior -{ - public IObservable Process(IObservable source) - { - return source.Select(value => { - return new PosteriorData(value); - }); - } -} - -public class PosteriorData -{ - public PosteriorData(PyObject posterior) - { - _data = (double[,])PythonHelper.ConvertPythonObjectToCSharp(posterior); - _mapEstimate = 0; - for (int i = 1; i < _data.GetLength(1); i++) - { - if (_data[0, i] > _data[0, _mapEstimate]) - { - 
_mapEstimate = i; - } - } - } - public double[,] Data => _data; - private double[,] _data; - - public int MapEstimate => _mapEstimate; - private int _mapEstimate; -} From ffd1014d6b3d8b77f92a5db2a864e9536008e830 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 18 Nov 2024 15:34:41 +0000 Subject: [PATCH 26/34] Updated to use Bonsai.ML package --- .../ClusterlessSpikes.bonsai | 309 +++--------------- .../SortedSpikes.bonsai | 304 +++++------------ 2 files changed, 125 insertions(+), 488 deletions(-) diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai index 5143e88..19878c1 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai @@ -4,9 +4,7 @@ xmlns:py="clr-namespace:Bonsai.Scripting.Python;assembly=Bonsai.Scripting.Python" xmlns:rx="clr-namespace:Bonsai.Reactive;assembly=Bonsai.Core" xmlns:scr="clr-namespace:Bonsai.Scripting.Expressions;assembly=Bonsai.Scripting.Expressions" - xmlns:p1="clr-namespace:;assembly=Extensions" - xmlns:gui="clr-namespace:Bonsai.Gui;assembly=Bonsai.Gui" - xmlns:viz="clr-namespace:Bonsai.Design.Visualizers;assembly=Bonsai.Design.Visualizers" + xmlns:p1="clr-namespace:Bonsai.ML.NeuralDecoding;assembly=Bonsai.ML.NeuralDecoding" xmlns="https://bonsai-rx.org/2018/workflow"> @@ -118,7 +116,7 @@ from decoder import * it[0] - new(it[0] as Position, it[1] as Spikes, it[2] as Features, it[3] as Decoding, it[4] as PositionBins) + new(it[0] as Position, it[1] as Features, it[2] as Decoding, it[3] as PositionBins) @@ -138,37 +136,54 @@ from decoder import * Data - - - - - Interval.TotalMilliseconds - - - 1000/it - - - - 10 - 1 - - - + + Performance Source1 - - 1 - + - - + + Interval.TotalMilliseconds + + + 1000/it - + + 10 + 1 + + + + + + + Source1 + + + + 1 + + + + + + + + + + + + + + + + + @@ -177,6 +192,8 @@ from decoder import * + + @@ -201,17 +218,13 @@ from decoder import * - + + 0 + OnlineDecoderResults - - Data - - - OnlinePosterior - Data @@ -224,68 +237,6 @@ from decoder import * Position - - Data - - - Decoding - - - - - - OfflineDecoderResults - - - Data - - - OfflinePosterior - - - Data - - - PositionBins - - - - - - PositionBins - - - PositionBins - - - OnlineDecoderResults - - - MapEstimate - - - - - - - OnlinePositionEstimate - - - PositionBins - - - OfflineDecoderResults - - - MapEstimate - - - - - - - OfflinePositionEstimate - @@ -293,29 +244,9 @@ from decoder import * - + - - - - - - - - - - - - - - - - - - - - @@ -323,151 +254,19 @@ from decoder import * VisualizeDecoder - - PositionEstimates - - - - true - true - Online Prediction - - - - true - true - Offline Acausal Prediction - - - - true - true - True Position - - - - OnlinePositionEstimate - - - - OfflinePositionEstimate - - - - Position - - - - 3 - 2 - - - - Percent - 10 - - - Percent - 90 - - - - - - - - - - - - - - - - - - - - - - - - - - Posteriors - - - - true - true - Online Posterior - - - - true - true - Offline Acausal Posterior - - - - OnlinePosterior - - - - OfflinePosterior - - - - true - true - 2 - 2 - - - - Percent - 10 - - - Percent - 90 - - - - - - - - - - - - - - - - - - + + Position - - true - true - 1 - 2 - - - + + OnlineDecoderResults - + - - @@ -479,10 +278,6 @@ from decoder import * - - - - \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai 
b/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai index 13efb12..cdaaaec 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai @@ -4,9 +4,7 @@ xmlns:py="clr-namespace:Bonsai.Scripting.Python;assembly=Bonsai.Scripting.Python" xmlns:rx="clr-namespace:Bonsai.Reactive;assembly=Bonsai.Core" xmlns:scr="clr-namespace:Bonsai.Scripting.Expressions;assembly=Bonsai.Scripting.Expressions" - xmlns:p1="clr-namespace:;assembly=Extensions" - xmlns:gui="clr-namespace:Bonsai.Gui;assembly=Bonsai.Gui" - xmlns:viz="clr-namespace:Bonsai.Design.Visualizers;assembly=Bonsai.Design.Visualizers" + xmlns:p1="clr-namespace:Bonsai.ML.NeuralDecoding;assembly=Bonsai.ML.NeuralDecoding" xmlns="https://bonsai-rx.org/2018/workflow"> @@ -137,6 +135,67 @@ from decoder import * Data + + Performance + + + + Source1 + + + + + + Interval.TotalMilliseconds + + + 1000/it + + + + 10 + 1 + + + + + + + Source1 + + + + 1 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + RunDecoder @@ -158,17 +217,13 @@ from decoder import * - + + 0 + OnlineDecoderResults - - Data - - - OnlinePosterior - Data @@ -176,73 +231,11 @@ from decoder import * Position - Convert.ToDouble(it.ToString()) + it.ToString() == "nan" ? double.NaN : Convert.ToDouble(it.ToString()) Position - - Data - - - Decoding - - - - - - OfflineDecoderResults - - - Data - - - OfflinePosterior - - - Data - - - PositionBins - - - - - - PositionBins - - - PositionBins - - - OnlineDecoderResults - - - MapEstimate - - - - - - - OnlinePositionEstimate - - - PositionBins - - - OfflineDecoderResults - - - MapEstimate - - - - - - - OfflinePositionEstimate - @@ -250,29 +243,9 @@ from decoder import * - + - - - - - - - - - - - - - - - - - - - - @@ -280,151 +253,19 @@ from decoder import * VisualizeDecoder - - PositionEstimates - - - - true - true - Online Prediction - - - - true - true - Offline Acausal Prediction - - - - true - true - True Position - - - - OnlinePositionEstimate - - - - OfflinePositionEstimate - - - - Position - - - - 3 - 2 - - - - Percent - 10 - - - Percent - 90 - - - - - - - - - - - - - - - - - - - - - - - - - - Posteriors - - - - true - true - Online Posterior - - - - true - true - Offline Acausal Posterior - - - - OnlinePosterior - - - - OfflinePosterior - - - - true - true - 2 - 2 - - - - Percent - 10 - - - Percent - 90 - - - - - - - - - - - - - - - - - - + + Position - - true - true - 1 - 2 - - - + + OnlineDecoderResults - + - - @@ -435,6 +276,7 @@ from decoder import * + \ No newline at end of file From 5cae1d2f32dcabc0c31629261177aac187453a60 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 18 Nov 2024 15:35:32 +0000 Subject: [PATCH 27/34] Removed requirements.txt file in favor of installing from git repo --- .../PositionDecodingFromHippocampus/README.md | 39 ++++++---- .../requirements.txt | 77 ------------------- 2 files changed, 25 insertions(+), 91 deletions(-) delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/requirements.txt diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md b/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md index 0799111..771ccb3 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md @@ -1,25 +1,33 @@ # Position Decoding from Hippocampal Sorted Spikes -In the following example, you can find how to use the decoder from 
[here](https://github.com/Eden-Kramer-Lab/replay_trajectory_classification/tree/master?tab=readme-ov-file) to decode an position from hippocampal activity.
+In the following example, you can find how to use the sorted spikes decoder or the clusterless spikes decoder from [here](https://github.com/Eden-Kramer-Lab/replay_trajectory_classification/tree/master?tab=readme-ov-file) to decode position from hippocampal activity.
 
 ## Dataset
 
-We thank Eric Denovellis for sharing his data and for his help with the decoder. If you use this example dataset, please cite: Eric L Denovellis, Anna K Gillespie, Michael E Coulter, Marielena Sosa, Jason E Chung, Uri T Eden, Loren M Frank (2021). Hippocampal replay of experience at real-world speeds. eLife 10:e64505.
+We thank Eric Denovellis for sharing his data and for his help with the decoder. If you use this example dataset, please consider citing the work: Joshi, A., Denovellis, E.L., Mankili, A. et al. Dynamic synchronization between hippocampal representations and stepping. Nature 617, 125–131 (2023). https://doi.org/10.1038/s41586-023-05928-6.
+
+## Algorithm
+
+The neural decoder consists of a Bayesian state-space model and point-process likelihoods that decode a latent variable (position) from neural spiking activity. To read more about the theory behind the model and how the algorithm works, we refer the reader to: Denovellis, E.L., Gillespie, A.K., Coulter, M.E., et al. Hippocampal replay of experience at real-world speeds. eLife 10, e64505 (2021). https://doi.org/10.7554/eLife.64505.
 
 ## Installation
 
 ### Python
 
-You can bootstrap the python environment by running:
+To install the package, run:
 
-```python
+```
 cd \path\to\examples\NeuralDecoding\PositionDecodingFromHippocampus
 python -m venv .venv
 .\.venv\Scripts\activate
-pip install -r requirements.txt
+pip install git+https://github.com/ncguilbeault/bayesian-neural-decoder.git
 ```
 
-You can test whether the installation was successful by launching python and running `import replay_trajectory_classification`.
+You can test whether the installation was successful by launching Python and running
+
+```python
+import bayesian_neural_decoder
+```
 
 ### Bonsai
 
@@ -32,15 +40,18 @@
 dotnet new bonsaienv --allow-scripts yes
 
 Alternatively, you can copy the `.bonsai\Bonsai.config` file into your Bonsai installation folder. You can test if it worked by opening Bonsai and searching for the `CreateRuntime` node, which should appear in the toolbox.
 
+## Usage
+
 ### Training the decoder offline
 
-You first need to train the decoder model and save it to disk. Open up the `notebooks` folder and select either `SortedSpikeDecoder.ipynb` or `ClusterlessDecoder.ipynb` depending on which model type you would like to use. Run the notebook. Once completed, this should create 2 new files: 1) `datasets\decoder_data\[ModelType]_decoder.pkl` for the trained decoder model; and 2) `datasets\decoder_data\[ModelType]_decoding_results.pkl` for the predictions.
+The package contains 2 different models: one takes sorted spike activity as input, and the other uses clusterless spike activity taken from raw ephys recordings.
+
+You first need to train the decoder model and save it to disk. Open up the `notebooks` folder and select either `SortedSpikesDecoder.ipynb` or `ClusterlessDecoder.ipynb` depending on which model type you would like to use. Run the notebook (the sketch below shows what this step amounts to).
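For orientation, the training step that these notebooks perform reduces to a fit-and-pickle pattern like the sketch below. This is a toy illustration rather than part of the patch: the arrays are random stand-ins for the real dataset, the output file name is hypothetical, and it assumes the `SortedSpikesDecoder` API from `replay_trajectory_classification` that the notebooks call.

```python
import pickle as pkl
import numpy as np
from replay_trajectory_classification import SortedSpikesDecoder

# Toy stand-ins for the linearized track position (cm) and the binned
# spike counts (time bins x neurons) that the notebooks build from the dataset.
n_time, n_neurons = 10_000, 28
linear_position = np.abs(np.sin(np.linspace(0, 20, n_time))) * 180.0
binned_spike_times = np.random.poisson(0.05, size=(n_time, n_neurons)).astype(float)

decoder = SortedSpikesDecoder()                   # default environment and transition model
decoder.fit(linear_position, binned_spike_times)  # estimate place fields and state transitions

# Pickle the fitted decoder together with the training arrays, mirroring
# the dict the notebooks save (hypothetical file name).
with open("sorted_spikes_decoder.pkl", "wb") as f:
    pkl.dump(dict(decoder=decoder, linear_position=linear_position,
                  spike_times=binned_spike_times, Fs=50.0), f)
```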
Once completed, this will create 2 new files:
+1) `datasets\decoder_data\[ModelType]_decoder.pkl` for the trained decoder model
+2) `datasets\decoder_data\[ModelType]_decoding_results.pkl` for the predictions.
+
+Both of these files are needed to run the decoder example in Bonsai.
 
-### Running the decoder online
+### Running the decoder online with Bonsai
 
-Launch the Bonsai.exe file inside of the .bonsai folder and open the workflow corresponding to the model type you used in the previous step. Press the `Start Workflow` button. The workflow may take some time to initialize and load the data. Once the workflow is running, open the `VisualizeDecoder` node to bring up the following:
-1. Online prediction of position
-2. Offline acausal prediction of position
-3. True position
-4. Latest online posterior distribution
-5. Latest offline acausal posterior distribution
+Launch the Bonsai.exe file inside the .bonsai folder and open the workflow corresponding to the model type you used in the previous step. Press the `Start Workflow` button. The workflow may take some time to initialize and load the data. Once the workflow is running, open the `VisualizeDecoder` node to see the model's inference online with respect to the true position.
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/requirements.txt b/examples/NeuralDecoding/PositionDecodingFromHippocampus/requirements.txt
deleted file mode 100644
index 98692e0..0000000
--- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/requirements.txt
+++ /dev/null
@@ -1,77 +0,0 @@
-asttokens==2.4.1
-click==8.1.7
-cloudpickle==3.1.0
-comm==0.2.2
-contourpy==1.3.0
-cupy-cuda12x==13.3.0
-cycler==0.12.1
-dask==2024.10.0
-debugpy==1.8.7
-decorator==5.1.1
-distributed==2024.10.0
-executing==2.1.0
-fastrlock==0.8.2
-fonttools==4.54.1
-fsspec==2024.10.0
-imageio==2.36.0
-ipykernel==6.29.5
-ipython==8.29.0
-jedi==0.19.1
-Jinja2==3.1.4
-joblib==1.4.2
-jupyter_client==8.6.3
-jupyter_core==5.7.2
-kiwisolver==1.4.7
-lazy_loader==0.4
-llvmlite==0.43.0
-locket==1.0.0
-MarkupSafe==3.0.2
-matplotlib==3.9.2
-matplotlib-inline==0.1.7
-msgpack==1.1.0
-nest-asyncio==1.6.0
-networkx==3.4.2
-numba==0.60.0
-numpy==2.0.2
-packaging==24.1
-pandas==2.2.3
-parso==0.8.4
-partd==1.4.2
-patsy==0.5.6
-pexpect==4.9.0
-pillow==11.0.0
-platformdirs==4.3.6
-prompt_toolkit==3.0.48
-psutil==6.1.0
-ptyprocess==0.7.0
-pure_eval==0.2.3
-Pygments==2.18.0
-pyparsing==3.2.0
-python-dateutil==2.9.0.post0
-pytz==2024.2
-PyYAML==6.0.2
-pyzmq==26.2.0
-regularized-glm==1.0.2
--e git+https://github.com/Eden-Kramer-Lab/replay_trajectory_classification.git@a61ba5c52f368ac28329806c4e081dc51f5f03fb#egg=replay_trajectory_classification
-scikit-image==0.24.0
-scikit-learn==1.5.2
-scipy==1.14.1
-seaborn==0.13.2
-setuptools==75.3.0
-six==1.16.0
-sortedcontainers==2.4.0
-stack-data==0.6.3
-statsmodels==0.14.4
-tblib==3.0.0
-threadpoolctl==3.5.0
-tifffile==2024.9.20
-toolz==1.0.0
-tornado==6.4.1
-tqdm==4.66.6
-track_linearization==2.3.2
-traitlets==5.14.3
-tzdata==2024.2
-urllib3==2.2.3
-wcwidth==0.2.13
-xarray==2024.10.0
-zict==3.0.0

From 7f6b75445be2d252a42141f481462265cf9adf1f Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Mon, 18 Nov 2024 15:36:47 +0000
Subject: [PATCH 28/34] Updated notebooks for new package

---
 .../notebooks/ClusterlessDecoder.ipynb        |   27 +-
 .../notebooks/SortedSpikesDecoder.ipynb       | 1086 ++---------------
 2 files changed, 99 insertions(+), 1014 deletions(-)

diff --git
a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb index 01f2ff9..f102110 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb @@ -43,8 +43,9 @@ "outputs": [], "source": [ "positions_df = pd.read_pickle(positions_filename)\n", - "time_start = positions_df.index.to_numpy()[0]\n", - "time_end = positions_df.index.to_numpy()[-1]\n", + "timestamps = positions_df.index.to_numpy()\n", + "time_start = timestamps[0]\n", + "time_end = timestamps[-1]\n", "dt = 0.02\n", "Fs = 1.0 / dt\n", "spikes_bins = np.arange(time_start - dt, time_end + dt, dt)" @@ -111,25 +112,18 @@ "outputs": [], "source": [ "with open(features_filename, \"rb\") as f:\n", - " clusterless_spike_features = pkl.load(f)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + " clusterless_spike_features = pkl.load(f)\n", + "\n", "with open(spikes_filename, \"rb\") as f:\n", " clusterless_spike_times = pkl.load(f)\n", "\n", - "features = np.ones((len(spikes_bins - 1), len(clusterless_spike_features[0][0]), len(clusterless_spike_times)), dtype=float) * np.nan\n", + "features = np.ones((len(spikes_bins) - 1, len(clusterless_spike_features[0][0]), len(clusterless_spike_times)), dtype=float) * np.nan\n", "for n in range(len(clusterless_spike_times)):\n", - " in_spikes_window = np.digitize(clusterless_spike_times[n], spikes_bins)\n", + " in_spikes_window = np.digitize(clusterless_spike_times[n], spikes_bins) - 1\n", " features[in_spikes_window, :, n] = clusterless_spike_features[n]\n", "\n", - "linear_position = np.ones(len(spikes_bins - 1)) * np.nan\n", - "in_position_window = np.digitize(positions_df.index, spikes_bins)\n", + "linear_position = np.ones(len(spikes_bins) - 1) * np.nan\n", + "in_position_window = np.digitize(positions_df.index, spikes_bins) - 1\n", "linear_position[in_position_window] = linearized_positions.linear_position" ] }, @@ -175,7 +169,7 @@ "print(f\"Saving model to {model_filename}\")\n", "\n", "results = dict(decoder=decoder, linear_position=linear_position,\n", - " clusterless_spike_times=clusterless_spike_times, features=features, Fs=Fs)\n", + " features=features, Fs=Fs)\n", "\n", "with open(model_filename, \"wb\") as f:\n", " pkl.dump(results, f)" @@ -209,7 +203,6 @@ " \n", "decoder = model_results[\"decoder\"]\n", "Fs = model_results[\"Fs\"]\n", - "clusterless_spike_times = model_results[\"clusterless_spike_times\"]\n", "features = model_results[\"features\"]\n", "linear_position = model_results[\"linear_position\"]" ] diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb index 6bdc182..5dde0a4 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb @@ -9,18 +9,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/likelihoods/multiunit_likelihood.py:8: TqdmWarning: IProgress not 
found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from tqdm.autonotebook import tqdm\n" - ] - } - ], + "outputs": [], "source": [ "import numpy as np\n", "import sys\n", @@ -46,446 +37,31 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "positions_df = pd.read_pickle(positions_filename)\n", "timestamps = positions_df.index.to_numpy()\n", + "time_start = timestamps[0]\n", + "time_end = timestamps[-1]\n", "dt = timestamps[1] - timestamps[0]\n", "Fs = 1.0 / dt\n", - "spikes_bins = np.append(timestamps-dt/2, timestamps[-1]+dt/2)" + "spikes_bins = np.arange(time_start - dt, time_end + dt, dt)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
nose_xnose_ynose_veltailBase_xtailBase_ytailBase_veltailMid_xtailMid_ytailMid_veltailTip_x...hindpawR_velforelimb_mid_xforelimb_mid_yforelimb_velbody_dirlinear_positiontrack_segment_idprojected_x_positionprojected_y_positionarm_name
time
22389.0828756.3026285.23117445.32077012.8560645.54749648.29090911.3679136.83302156.50818616.296467...NaN5.0806713.215346125.7070052.338472162.28501936.3451018.359380Left Arm
22389.0848756.7552765.64450345.57971613.5034485.68928648.48375112.1428277.31534556.75125617.387420...NaN5.1946843.302497124.9766022.347791162.73201436.7920558.353311Left Arm
22389.0868757.2079256.05783345.83866114.1508325.83107648.67659312.9177407.79766956.99432518.478373...NaN5.3086973.389648124.2462002.357110163.17901037.2390098.347243Left Arm
22389.0888757.6605736.47116246.09760714.7982155.97286648.86943513.6926548.27999357.23739519.569326...NaN5.4227093.476798123.5157972.366429163.62600537.6859638.341174Left Arm
22389.0908758.1132226.88449246.35655215.4455996.11465749.06227714.4675688.76231757.48046420.660279...NaN5.5367223.563949122.7853942.375748164.07300038.1329178.335106Left Arm
..................................................................
23293.7228750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7248750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7268750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7288750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
23293.7308750.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.00.0000000.0000000.000000-0.000057161.44725835.5074178.370753Left Arm
\n", - "

452325 rows × 33 columns

\n", - "
" - ], - "text/plain": [ - " nose_x nose_y nose_vel tailBase_x tailBase_y \\\n", - "time \n", - "22389.082875 6.302628 5.231174 45.320770 12.856064 5.547496 \n", - "22389.084875 6.755276 5.644503 45.579716 13.503448 5.689286 \n", - "22389.086875 7.207925 6.057833 45.838661 14.150832 5.831076 \n", - "22389.088875 7.660573 6.471162 46.097607 14.798215 5.972866 \n", - "22389.090875 8.113222 6.884492 46.356552 15.445599 6.114657 \n", - "... ... ... ... ... ... \n", - "23293.722875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", - "23293.724875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", - "23293.726875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", - "23293.728875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", - "23293.730875 0.000000 0.000000 0.000000 0.000000 0.000000 \n", - "\n", - " tailBase_vel tailMid_x tailMid_y tailMid_vel tailTip_x ... \\\n", - "time ... \n", - "22389.082875 48.290909 11.367913 6.833021 56.508186 16.296467 ... \n", - "22389.084875 48.483751 12.142827 7.315345 56.751256 17.387420 ... \n", - "22389.086875 48.676593 12.917740 7.797669 56.994325 18.478373 ... \n", - "22389.088875 48.869435 13.692654 8.279993 57.237395 19.569326 ... \n", - "22389.090875 49.062277 14.467568 8.762317 57.480464 20.660279 ... \n", - "... ... ... ... ... ... ... \n", - "23293.722875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", - "23293.724875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", - "23293.726875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", - "23293.728875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", - "23293.730875 0.000000 0.000000 0.000000 0.000000 0.000000 ... \n", - "\n", - " hindpawR_vel forelimb_mid_x forelimb_mid_y forelimb_vel \\\n", - "time \n", - "22389.082875 NaN 5.080671 3.215346 125.707005 \n", - "22389.084875 NaN 5.194684 3.302497 124.976602 \n", - "22389.086875 NaN 5.308697 3.389648 124.246200 \n", - "22389.088875 NaN 5.422709 3.476798 123.515797 \n", - "22389.090875 NaN 5.536722 3.563949 122.785394 \n", - "... ... ... ... ... \n", - "23293.722875 0.0 0.000000 0.000000 0.000000 \n", - "23293.724875 0.0 0.000000 0.000000 0.000000 \n", - "23293.726875 0.0 0.000000 0.000000 0.000000 \n", - "23293.728875 0.0 0.000000 0.000000 0.000000 \n", - "23293.730875 0.0 0.000000 0.000000 0.000000 \n", - "\n", - " body_dir linear_position track_segment_id \\\n", - "time \n", - "22389.082875 2.338472 162.285019 3 \n", - "22389.084875 2.347791 162.732014 3 \n", - "22389.086875 2.357110 163.179010 3 \n", - "22389.088875 2.366429 163.626005 3 \n", - "22389.090875 2.375748 164.073000 3 \n", - "... ... ... ... \n", - "23293.722875 -0.000057 161.447258 3 \n", - "23293.724875 -0.000057 161.447258 3 \n", - "23293.726875 -0.000057 161.447258 3 \n", - "23293.728875 -0.000057 161.447258 3 \n", - "23293.730875 -0.000057 161.447258 3 \n", - "\n", - " projected_x_position projected_y_position arm_name \n", - "time \n", - "22389.082875 6.345101 8.359380 Left Arm \n", - "22389.084875 6.792055 8.353311 Left Arm \n", - "22389.086875 7.239009 8.347243 Left Arm \n", - "22389.088875 7.685963 8.341174 Left Arm \n", - "22389.090875 8.132917 8.335106 Left Arm \n", - "... ... ... ... 
\n", - "23293.722875 5.507417 8.370753 Left Arm \n", - "23293.724875 5.507417 8.370753 Left Arm \n", - "23293.726875 5.507417 8.370753 Left Arm \n", - "23293.728875 5.507417 8.370753 Left Arm \n", - "23293.730875 5.507417 8.370753 Left Arm \n", - "\n", - "[452325 rows x 33 columns]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "positions_df" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -511,7 +87,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -530,21 +106,25 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with open(spikes_filename, \"rb\") as f:\n", " sorted_spike_times = pkl.load(f)\n", "\n", - "binned_spikes_times = np.empty((len(timestamps), len(sorted_spike_times)), dtype=float)\n", + "binned_spike_times = np.empty((len(spikes_bins) - 1, len(sorted_spike_times)), dtype=float)\n", "for n in range(len(sorted_spike_times)):\n", - " binned_spikes_times[:, n] = np.histogram(sorted_spike_times[n], spikes_bins)[0]" + " binned_spike_times[:, n] = np.histogram(sorted_spike_times[n], spikes_bins)[0]\n", + "\n", + "linear_position = np.ones(len(spikes_bins) - 1) * np.nan\n", + "in_position_window = np.digitize(positions_df.index, spikes_bins) - 1\n", + "linear_position[in_position_window] = linearized_positions.linear_position" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -566,532 +146,24 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Learning model parameters\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/continuous_state_transitions.py:24: RuntimeWarning: invalid value encountered in divide\n", - " x /= x.sum(axis=1, keepdims=True)\n", - "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/likelihoods/spiking_likelihood_kde.py:117: RuntimeWarning: divide by zero encountered in log\n", - " return np.exp(np.log(mean_rate) + np.log(marginal_density) - np.log(occupancy))\n", - "/home/nicholas/replay_trajectory_classification/replay_trajectory_classification/likelihoods/spiking_likelihood_kde.py:117: RuntimeWarning: invalid value encountered in subtract\n", - " return np.exp(np.log(mean_rate) + np.log(marginal_density) - np.log(occupancy))\n" - ] - }, - { - "data": { - "text/html": [ - "
-    [deleted cell output: HTML widget rendering of the fitted SortedSpikesDecoder; its parameters are repeated verbatim in the plain-text repr that follows]
" - ], - "text/plain": [ - "SortedSpikesDecoder(environment=Environment(environment_name='',\n", - " place_bin_size=0.5,\n", - " track_graph=,\n", - " edge_order=[(3, 2), (0, 1), (1, 2),\n", - " (5, 4), (4, 2)],\n", - " edge_spacing=[16, 0, 16, 0],\n", - " is_track_interior=None,\n", - " position_range=None,\n", - " infer_track_interior=True,\n", - " fill_holes=False,\n", - " dilate=False,\n", - " bin_count_threshold=0),\n", - " infer_track_interior=True,\n", - " initial_conditions_type=UniformInitialConditions(),\n", - " sorted_spikes_algorithm='spiking_likelihood_kde',\n", - " sorted_spikes_algorithm_params={'block_size': None,\n", - " 'position_std': 6.0,\n", - " 'use_diffusion': False},\n", - " transition_type=RandomWalk(environment_name='',\n", - " movement_var=0.25,\n", - " movement_mean=0.0,\n", - " use_diffusion=False))" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "print(\"Learning model parameters\")\n", - "decoder.fit(linearized_positions.linear_position, binned_spikes_times)" + "decoder.fit(linear_position, binned_spike_times)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Saving model to ../../../../datasets/decoder_data/sorted_spike_model.pkl\n" - ] - } - ], + "outputs": [], "source": [ "print(f\"Saving model to {model_filename}\")\n", "\n", - "results = dict(decoder=decoder, linearized_positions=linearized_positions,\n", - " binned_spikes_times=binned_spikes_times, Fs=Fs)\n", + "results = dict(decoder=decoder, linear_position=linear_position,\n", + " spike_times=binned_spike_times, Fs=Fs)\n", "\n", "with open(model_filename, \"wb\") as f:\n", " pkl.dump(results, f)" @@ -1116,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1125,73 +197,93 @@ " \n", "decoder = model_results[\"decoder\"]\n", "Fs = model_results[\"Fs\"]\n", - "binned_spikes_times = model_results[\"binned_spikes_times\"]\n", - "linearized_positions = model_results[\"linearized_positions\"]" + "spike_times = model_results[\"spike_times\"]\n", + "linear_position = model_results[\"linear_position\"]" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Decoding positions from spikes\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 0/104 [00:00 Date: Mon, 18 Nov 2024 18:41:45 +0000 Subject: [PATCH 29/34] Added new dataset to README --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 092c553..59a475f 100644 --- a/README.md +++ b/README.md @@ -12,9 +12,11 @@ All of the datasets used in these examples can be found by going to: [https://do ZebrafishExampleVid.avi - provided by Nicholas Guilbeault in the Thiele lab at the University of Toronto. If you would like to refer to this data, please cite Guilbeault, N.C., Guerguiev, J., Martin, M. et al. (2021). BonZeb: open-source, modular software tools for high-resolution zebrafish tracking and analysis. *Scientific Reports* *11*, 8148, [https://doi.org/10.1038/s41598-021-85896-x](https://doi.org/10.1038/s41598-021-85896-x). -ForagingMouseExampleVid.avi - provided by the Sainsbury Wellcome Centre Foraging Behaviour Working Group. (2023). 
Aeon: An open-source platform to study the neural basis of ethological behaviours over naturalistic timescales, [https://doi.org/10.5281/zenodo.8413142](https://doi.org/10.5281/zenodo.8413142) +ForagingMouseExampleVid.avi - provided by the Sainsbury Wellcome Centre Foraging Behaviour Working Group. (2023). Aeon: An open-source platform to study the neural basis of ethological behaviours over naturalistic timescales, [https://doi.org/10.5281/zenodo.8413142](https://doi.org/10.5281/zenodo.8413142). -ReceptiveFieldSimpleCell.zip - provided by the authors of "Touryan, J., Felsen, G., & Dan, Y. (2005). Spatial structure of complex cell receptive fields measured with natural images. Neuron, 45(5), 781-791." [https://doi.org/10.1016/j.neuron.2005.01.029](https://doi.org/10.1016/j.neuron.2005.01.029) +ReceptiveFieldSimpleCell.zip - provided by the authors of "Touryan, J., Felsen, G., & Dan, Y. (2005). Spatial structure of complex cell receptive fields measured with natural images. Neuron, 45(5), 781-791." [https://doi.org/10.1016/j.neuron.2005.01.029](https://doi.org/10.1016/j.neuron.2005.01.029). + +HippocampalTetrodeRecording.zip - provided by the authors of Joshi, A., Denovellis, E.L., Mankili, A. et al. (2023). Dynamic synchronization between hippocampal representations and stepping. Nature 617, 125–131. [https://doi.org/10.1038/s41586-023-05928-6](https://doi.org/10.1038/s41586-023-05928-6). ### Acknowledgements From 1060fee7f8a1b8fb4e7835fdcde0ddf87c133066 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 18 Nov 2024 18:45:49 +0000 Subject: [PATCH 30/34] Updated example README --- .../NeuralDecoding/PositionDecodingFromHippocampus/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md b/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md index 771ccb3..c8edcb4 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md @@ -6,6 +6,8 @@ In the following example, you can find how to use the spike sorted decoder or cl We thank Eric Denovellis for sharing his data and for his help with the decoder. If you use this example dataset, please consider citing the work: Joshi, A., Denovellis, E.L., Mankili, A. et al. Dynamic synchronization between hippocampal representations and stepping. Nature 617, 125–131 (2023). https://doi.org/10.1038/s41586-023-05928-6. +You can download the HippocampalTetrodeRecordings.zip file here: https://doi.org/10.5281/zenodo.10629221. The workflow expects the data to be placed into the datasets folder. It should be structured like this: `/path/to/machinelearning-examples/datasets/decoder_data/*` + ## Algorithm The neural decoder consists of a bayesian state-space model and point-processes to decode a latent variable (position) from neural spiking activity. To read more about the theory behind the model and how the algorithm works, we refer the reader to: Denovellis, E.L., Gillespie, A.K., Coulter, M.E., et al. Hippocampal replay of experience at real-world speeds. eLife 10, e64505 (2021). https://doi.org/10.7554/eLife.64505. 
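Since the Bonsai workflow loads its inputs from this pickle, a quick sanity check of the file layout can save a failed workflow start. The sketch below is an assumed helper, not part of the PR; the keys mirror what the PR's clusterless loader reads (see the data_loader.py changes earlier in this series), and the path follows the README.

```python
import os
import pickle

# Path layout from the README; adjust to wherever the dataset was unzipped.
dataset_path = "datasets/decoder_data"
results_file = os.path.join(dataset_path, "clusterless_spike_decoding_results_50Hz.pkl")

with open(results_file, "rb") as f:
    results = pickle.load(f)

# The loader in this PR reads exactly these keys.
for key in ("decoding_results", "features", "time", "linear_position"):
    assert key in results, f"missing key: {key}"

# acausal_posterior is an xarray object indexed by time and position bins.
posterior = results["decoding_results"].acausal_posterior
print(posterior.shape, results["features"].shape)
```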
From 1043f167881f798090fa61e650132451da90fce8 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 18 Nov 2024 18:46:21 +0000 Subject: [PATCH 31/34] Removed python package from inside example folder --- .../decoder/__init__.py | 4 - .../decoder/data_iterator.py | 25 --- .../decoder/data_loader.py | 98 ----------- .../decoder/decoder.py | 151 ----------------- .../decoder/likelihood.py | 159 ------------------ .../scripts/clusterless_spike_decoder.py | 138 --------------- .../scripts/sorted_spike_decoder.py | 129 -------------- 7 files changed, 704 deletions(-) delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/__init__.py delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_iterator.py delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_loader.py delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/decoder.py delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/likelihood.py delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/scripts/clusterless_spike_decoder.py delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/scripts/sorted_spike_decoder.py diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/__init__.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/__init__.py deleted file mode 100644 index d48ce44..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .decoder import SortedSpikeDecoder, ClusterlessSpikeDecoder -from .data_loader import DataLoader -from .data_iterator import DataIterator -from .likelihood import LIKELIHOOD_FUNCTION \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_iterator.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_iterator.py deleted file mode 100644 index 98c5eb1..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_iterator.py +++ /dev/null @@ -1,25 +0,0 @@ -import numpy as np - -class DataIterator: - - def __init__(self, - data: dict, - start_index: int = 0): - self.data = data - self.keys = list(self.data.keys()) - self.index = start_index - super().__init__() - - def next(self) -> tuple[list, list]: - - output = [] - - for key in self.keys: - try: - output.append(self.data[key][self.index]) - except IndexError: - output.append(self.data[key][0]) - - self.index += 1 - - return (output, self.keys) \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_loader.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_loader.py deleted file mode 100644 index 5aafca2..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/data_loader.py +++ /dev/null @@ -1,98 +0,0 @@ -import pickle -import pandas as pd -import numpy as np -import track_linearization as tl -import replay_trajectory_classification as rtc -import os - -class DataLoader: - def __init__(self): - super().__init__() - - @classmethod - def load_sorted_spike_data(cls, - dataset_path: str = "../../../datasets/decoder_data", - bin_spikes: bool = True) -> dict: - - if len([file for file in os.listdir(dataset_path) if file == "position_info.pkl" or file == "sorted_spike_times.pkl" or file == "sorted_spike_decoding_results.pkl"]) != 3: - raise Exception("Dataset incorrect. 
Missing at least one of the following files: 'position_info.pkl', 'sorted_spike_times.pkl', 'sorted_spike_decoding_results.pkl'") - - position_data = pd.read_pickle(os.path.join(dataset_path, "position_info.pkl")) - position_index = position_data.index.to_numpy() - position_index = np.insert(position_index, 0, position_index[0] - (position_index[1] - position_index[0])) - position_data = position_data[["nose_x", "nose_y"]].to_numpy() - - node_positions = [(120.0, 100.0), - ( 5.0, 100.0), - ( 5.0, 55.0), - (120.0, 55.0), - ( 5.0, 8.5), - (120.0, 8.5), - ] - edges = [ - (3, 2), - (0, 1), - (1, 2), - (5, 4), - (4, 2), - ] - track_graph = rtc.make_track_graph(node_positions, edges) - - edge_order = [ - (3, 2), - (0, 1), - (1, 2), - (5, 4), - (4, 2), - ] - - edge_spacing = [16, 0, 16, 0] - - linearized_positions = tl.get_linearized_position(position_data, track_graph, edge_order=edge_order, edge_spacing=edge_spacing, use_HMM=False) - position_data = linearized_positions.linear_position - - with open(os.path.join(dataset_path, "sorted_spike_times.pkl"), "rb") as f: - spike_times = pickle.load(f) - - if bin_spikes: - spike_mat = np.zeros((len(position_data), len(spike_times))) - for neuron in range(len(spike_times)): - spike_mat[:, neuron] = np.histogram(spike_times[neuron], position_index)[0] - spike_times = spike_mat - - with open(os.path.join(dataset_path, "sorted_spike_decoding_results.pkl"), "rb") as f: - results = pickle.load(f)["decoding_results"] - position_bins = results.position.to_numpy()[np.newaxis] - decoding_results = results.acausal_posterior.to_numpy()[:,np.newaxis] - - return { - "position_data": position_data, - "spike_times": spike_times, - "decoding_results": decoding_results, - "position_bins": position_bins - } - - @classmethod - def load_clusterless_spike_data(cls, - dataset_path: str = "../../../datasets/decoder_data") -> dict: - - decoding_results_filename = os.path.join(dataset_path, "clusterless_spike_decoding_results_50Hz.pkl") - if not os.path.exists(decoding_results_filename): - raise Exception("Dataset incorrect. 
Missing 'clusterless_spike_decoding_results_50Hz.pkl'") - - with open(decoding_results_filename, "rb") as f: - results = pickle.load(f) - decoding_results = results["decoding_results"] - position_bins = decoding_results.position.to_numpy()[np.newaxis] - decoding_results = decoding_results.acausal_posterior.to_numpy()[:,np.newaxis] - features = results["features"] - time = results["time"] - linear_position = results["linear_position"] - - return { - "linear_position": linear_position, - "time": time, - "features": features, - "decoding_results": decoding_results, - "position_bins": position_bins - } diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/decoder.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/decoder.py deleted file mode 100644 index f4bb48f..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/decoder.py +++ /dev/null @@ -1,151 +0,0 @@ -import replay_trajectory_classification as rtc -from replay_trajectory_classification.core import scaled_likelihood, get_centers -from replay_trajectory_classification.likelihoods import _SORTED_SPIKES_ALGORITHMS, _ClUSTERLESS_ALGORITHMS -from replay_trajectory_classification.likelihoods.spiking_likelihood_kde import combined_likelihood, poisson_log_likelihood -from replay_trajectory_classification.likelihoods.multiunit_likelihood import estimate_position_distance, estimate_log_joint_mark_intensity - -from .likelihood import LIKELIHOOD_FUNCTION - -import numpy as np -import pickle as pkl -import cupy as cp - -class Decoder(): - def __init__(self): - super().__init__() - - def decode(self, data: np.ndarray): - raise NotImplementedError - - @classmethod - def load(cls, filename: str): - with open(filename, "rb") as f: - return cls(pkl.load(f)) - -class ClusterlessSpikeDecoder(Decoder): - def __init__(self, model_dict: dict): - self.decoder = model_dict["decoder"] - self.Fs = model_dict["Fs"] - self.features = model_dict["features"] - - encoding_model = self.decoder.encoding_model_ - self.encoding_marks = encoding_model["encoding_marks"] - self.mark_std = encoding_model["mark_std"] - self.encoding_positions = encoding_model["encoding_positions"] - self.position_std = encoding_model["position_std"] - self.occupancy = encoding_model["occupancy"] - self.mean_rates = encoding_model["mean_rates"] - self.summed_ground_process_intensity = encoding_model["summed_ground_process_intensity"] - self.block_size = encoding_model["block_size"] - self.bin_diffusion_distances = encoding_model["bin_diffusion_distances"] - self.edges = encoding_model["edges"] - - self.place_bin_centers = self.decoder.environment.place_bin_centers_ - self.is_track_interior = self.decoder.environment.is_track_interior_.ravel(order="F") - self.st_interior_ind = np.ix_(self.is_track_interior, self.is_track_interior) - - self.likelihood_function = LIKELIHOOD_FUNCTION[self.decoder.clusterless_algorithm] - - if "gpu" in self.decoder.clusterless_algorithm: - self.is_track_interior_gpu = cp.asarray(self.is_track_interior) - self.occupancy = cp.asarray(self.occupancy) - self.interior_place_bin_centers = cp.asarray( - self.place_bin_centers[self.is_track_interior], dtype=cp.float32 - ) - self.interior_occupancy = cp.asarray( - self.occupancy[self.is_track_interior_gpu], dtype=cp.float32 - ) - - else: - self.is_track_interior_gpu = None - self.interior_place_bin_centers = np.asarray( - self.place_bin_centers[self.is_track_interior], dtype=np.float32 - ) - self.interior_occupancy = np.asarray( - 
self.occupancy[self.is_track_interior], dtype=np.float32 - ) - - self.n_position_bins = self.is_track_interior.shape[0] - self.n_track_bins = self.is_track_interior.sum() - - self.initial_conditions = self.decoder.initial_conditions_[self.is_track_interior].astype(float) - self.state_transition = self.decoder.state_transition_[self.st_interior_ind].astype(float) - - self.posterior = None - super().__init__() - - @classmethod - def load(cls, filename: str = "../../../datasets/decoder_data/clusterless_spike_decoder.pkl"): - return super().load(filename) - - def decode(self, - data: np.ndarray): - - likelihood = self.likelihood_function( - data, - self.summed_ground_process_intensity, - self.encoding_marks, - self.encoding_positions, - self.mean_rates, - self.is_track_interior, - self.interior_place_bin_centers, - self.position_std, - self.mark_std, - self.interior_occupancy, - self.n_track_bins - ) - - if self.posterior is None: - self.posterior = np.full((1, self.n_position_bins), np.nan, dtype=float) - self.posterior[0, self.is_track_interior] = self.initial_conditions * likelihood[0, self.is_track_interior] - - else: - self.posterior[0, self.is_track_interior] = self.state_transition.T @ self.posterior[0, self.is_track_interior] * likelihood[0, self.is_track_interior] - - norm = np.nansum(self.posterior[0]) - self.posterior[0] /= norm - - return self.posterior - -class SortedSpikeDecoder(Decoder): - def __init__(self, model_dict: dict): - self.decoder = model_dict["decoder"] - self.Fs = model_dict["Fs"] - self.spikes = model_dict["binned_spikes_times"] - self.is_track_interior = self.decoder.environment.is_track_interior_.ravel(order="F") - self.st_interior_ind = np.ix_(self.is_track_interior, self.is_track_interior) - self.n_position_bins = self.is_track_interior.shape[0] - - self.initial_conditions = self.decoder.initial_conditions_[self.is_track_interior].astype(float) - self.state_transition = self.decoder.state_transition_[self.st_interior_ind].astype(float) - self.place_fields = np.asarray(self.decoder.place_fields_) - self.position_centers = get_centers(self.decoder.environment.edges_[0]) - self.conditional_intensity = np.clip(self.place_fields, a_min=1e-15, a_max=None) - - self.likelihood_function = LIKELIHOOD_FUNCTION[self.decoder.sorted_spikes_algorithm] - - self.posterior = None - super().__init__() - - @classmethod - def load(cls, filename: str = "../../../datasets/decoder_data/sorted_spike_decoder.pkl"): - return super().load(filename) - - def decode( - self, - data: np.ndarray - ): - - likelihood = self.likelihood_function(data, self.conditional_intensity, self.is_track_interior) - - if self.posterior is None: - self.posterior = np.full((1, self.n_position_bins), np.nan, dtype=float) - self.posterior[0, self.is_track_interior] = self.initial_conditions * likelihood[0] - - else: - self.posterior[0, self.is_track_interior] = self.state_transition.T @ self.posterior[0, self.is_track_interior] * likelihood[0] - - norm = np.nansum(self.posterior[0]) - self.posterior[0] /= norm - - return self.posterior \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/likelihood.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/likelihood.py deleted file mode 100644 index ed71cd7..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/decoder/likelihood.py +++ /dev/null @@ -1,159 +0,0 @@ -from replay_trajectory_classification.core import scaled_likelihood -import 
replay_trajectory_classification.likelihoods.multiunit_likelihood as ml -import replay_trajectory_classification.likelihoods.multiunit_likelihood_gpu as mlgpu -import replay_trajectory_classification.likelihoods.multiunit_likelihood_integer_gpu as mligpu -from replay_trajectory_classification.likelihoods.spiking_likelihood_kde import poisson_log_likelihood - -import numpy as np -import cupy as cp - -def spiking_likelihood_kde(spikes, conditional_intensity, is_track_interior): - - log_likelihood = 0 - for spike, ci in zip(spikes, conditional_intensity.T): - log_likelihood += poisson_log_likelihood(spike[np.newaxis], ci) - - mask = np.ones_like(is_track_interior, dtype=float) - mask[~is_track_interior] = np.nan - - likelihood = scaled_likelihood(log_likelihood * mask) - likelihood = likelihood[:, is_track_interior].astype(float) - - return likelihood - -def multiunit_likelihood(multiunits, summed_ground_process_intensity, encoding_marks, encoding_positions, mean_rates, is_track_interior, interior_place_bin_centers, position_std, mark_std, interior_occupancy, n_track_bins): - log_likelihood = -summed_ground_process_intensity * np.ones((1,1), dtype=np.float32) - - if not np.isnan(multiunits).all(): - # multiunit_idxs = np.where(~np.isnan(multiunits, axis=0))[0] - - for multiunit, enc_marks, enc_pos, mean_rate in zip( - multiunits.T, - encoding_marks, - encoding_positions, - mean_rates, - ): - is_spike = np.any(~np.isnan(multiunit)) - if is_spike: - decoding_marks = np.asarray( - multiunit, dtype=np.float32 - )[np.newaxis] - log_joint_mark_intensity = np.zeros( - (1, n_track_bins), dtype=np.float32 - ) - position_distance = ml.estimate_position_distance( - interior_place_bin_centers, - np.asarray(enc_pos, dtype=np.float32), - position_std, - ).astype(np.float32) - log_joint_mark_intensity[0] = ml.estimate_log_joint_mark_intensity( - decoding_marks, - enc_marks, - mark_std, - interior_occupancy, - mean_rate, - position_distance=position_distance, - ) - log_likelihood[:, is_track_interior] += np.nan_to_num( - log_joint_mark_intensity - ) - - log_likelihood[:, ~is_track_interior] = np.nan - likelihood = scaled_likelihood(log_likelihood) - - return likelihood - -def multiunit_likelihood_gpu(multiunits, summed_ground_process_intensity, encoding_marks, encoding_positions, mean_rates, is_track_interior, interior_place_bin_centers, position_std, mark_std, interior_occupancy, n_track_bins): - log_likelihood = -summed_ground_process_intensity * np.ones((1,1), dtype=np.float32) - - if not np.isnan(multiunits).all(): - # multiunit_idxs = np.where(~np.isnan(multiunits, axis=0))[0] - - for multiunit, enc_marks, enc_pos, mean_rate in zip( - multiunits.T, - encoding_marks, - encoding_positions, - mean_rates, - ): - is_spike = np.any(~np.isnan(multiunit)) - if is_spike: - decoding_marks = cp.asarray( - multiunit, dtype=cp.float32 - )[cp.newaxis] - log_joint_mark_intensity = np.zeros( - (1, n_track_bins), dtype=np.float32 - ) - position_distance = mlgpu.estimate_position_distance( - interior_place_bin_centers, - cp.asarray(enc_pos, dtype=cp.float32), - position_std, - ).astype(cp.float32) - log_joint_mark_intensity[0] = mlgpu.estimate_log_joint_mark_intensity( - decoding_marks, - enc_marks, - mark_std, - interior_occupancy, - mean_rate, - position_distance=position_distance, - ) - log_likelihood[:, is_track_interior] += np.nan_to_num( - log_joint_mark_intensity - ) - - mempool = cp.get_default_memory_pool() - mempool.free_all_blocks() - - log_likelihood[:, ~is_track_interior] = np.nan - likelihood = 
scaled_likelihood(log_likelihood) - return likelihood - -def multiunit_likelihood_integer_gpu(multiunits, summed_ground_process_intensity, encoding_marks, encoding_positions, mean_rates, is_track_interior, interior_place_bin_centers, position_std, mark_std, interior_occupancy, n_track_bins): - log_likelihood = -summed_ground_process_intensity * np.ones((1,1), dtype=np.float32) - - if not np.isnan(multiunits).all(): - # multiunit_idxs = np.where(~np.isnan(multiunits, axis=0))[0] - - for multiunit, enc_marks, enc_pos, mean_rate in zip( - multiunits.T, - encoding_marks, - encoding_positions, - mean_rates, - ): - is_spike = np.any(~np.isnan(multiunit)) - if is_spike: - decoding_marks = cp.asarray( - multiunit, dtype=cp.int16 - )[cp.newaxis] - log_joint_mark_intensity = np.zeros( - (1, n_track_bins), dtype=np.float32 - ) - position_distance = mligpu.estimate_position_distance( - interior_place_bin_centers, - cp.asarray(enc_pos, dtype=cp.float32), - position_std, - ).astype(cp.float32) - log_joint_mark_intensity[0] = mligpu.estimate_log_joint_mark_intensity( - decoding_marks, - enc_marks, - mark_std, - interior_occupancy, - mean_rate, - position_distance=position_distance, - ) - log_likelihood[:, is_track_interior] += np.nan_to_num( - log_joint_mark_intensity - ) - - mempool = cp.get_default_memory_pool() - mempool.free_all_blocks() - - log_likelihood[:, ~is_track_interior] = np.nan - likelihood = scaled_likelihood(log_likelihood) - return likelihood - -LIKELIHOOD_FUNCTION = { - "multiunit_likelihood": multiunit_likelihood, - "spiking_likelihood_kde": spiking_likelihood_kde, - "multiunit_likelihood_gpu": multiunit_likelihood_gpu, - "multiunit_likelihood_integer_gpu": multiunit_likelihood_integer_gpu -} \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/scripts/clusterless_spike_decoder.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/scripts/clusterless_spike_decoder.py deleted file mode 100644 index 0fc3e05..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/scripts/clusterless_spike_decoder.py +++ /dev/null @@ -1,138 +0,0 @@ -import numpy as np -import sys -import pandas as pd -import matplotlib.pyplot as plt -import pickle as pkl - -import replay_trajectory_classification as rtc -import track_linearization as tl - -import argparse - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--positions_filename", type=str, - default="../../../../datasets/decoder_data/position_info.pkl", - help="positions filename") - parser.add_argument("--spike_times_filename", type=str, - default="../../../../datasets/decoder_data/clusterless_spike_times.pkl", - help="spikes filename") - parser.add_argument("--features_filename", type=str, - default="../../../../datasets/decoder_data/clusterless_spike_features.pkl", - help="features filename") - parser.add_argument("--model_filename", type=str, - default="../../../../datasets/decoder_data/clusterless_spike_decoder.pkl", - help="model filename") - parser.add_argument("--decoding_results_filename", type=str, - default="../../../../datasets/decoder_data/clusterless_spike_decoding_results.pkl", - help="decoding results filename") - args = parser.parse_args() - - positions_filename = args.positions_filename - spikes_filename = args.spikes_filename - features_filename = args.features_filename - model_filename = args.model_filename - decoding_results_filename = args.decoding_results_filename - - positions_df = pd.read_pickle(positions_filename) - timestamps = 
positions_df.index.to_numpy() - dt = timestamps[1] - timestamps[0] - Fs = 1.0 / dt - spikes_bins = np.append(timestamps-dt/2, timestamps[-1]+dt/2) - - x = positions_df["nose_x"].to_numpy() - y = positions_df["nose_y"].to_numpy() - positions = np.column_stack((x, y)) - node_positions = [(120.0, 100.0), - ( 5.0, 100.0), - ( 5.0, 55.0), - (120.0, 55.0), - ( 5.0, 8.5), - (120.0, 8.5), - ] - edges = [ - (3, 2), - (0, 1), - (1, 2), - (5, 4), - (4, 2), - ] - track_graph = rtc.make_track_graph(node_positions, edges) - - edge_order = [ - (3, 2), - (0, 1), - (1, 2), - (5, 4), - (4, 2), - ] - - edge_spacing = [16, 0, 16, 0] - - linearized_positions = tl.get_linearized_position(positions, track_graph, edge_order=edge_order, edge_spacing=edge_spacing, use_HMM=False) - - with open(features_filename, "rb") as f: - clusterless_spike_features = pkl.load(f) - - with open(spikes_filename, "rb") as f: - clusterless_spike_times = pkl.load(f) - - features = np.ones((len(timestamps), len(clusterless_spike_features[0][0]), len(clusterless_spike_times)), dtype=float) * np.nan - for n in range(len(clusterless_spike_times)): - in_spikes_window = np.digitize(clusterless_spike_times[n], spikes_bins) - features[in_spikes_window, :, n] = clusterless_spike_features[n] - - place_bin_size = 0.5 - movement_var = 0.25 - - environment = rtc.Environment(place_bin_size=place_bin_size, - track_graph=track_graph, - edge_order=edge_order, - edge_spacing=edge_spacing) - - transition_type = rtc.RandomWalk(movement_var=movement_var) - - decoder = rtc.ClusterlessDecoder( - environment=environment, - transition_type=transition_type, - clusterless_algorithm="multiunit_likelihood" - ) - - print("Learning model parameters") - decoder.fit(linearized_positions.linear_position, features) - - print(f"Saving model to {model_filename}") - - results = dict(decoder=decoder, linearized_positions=linearized_positions, - clusterless_spike_times=clusterless_spike_times, features=features, Fs=Fs) - - with open(model_filename, "wb") as f: - pkl.dump(results, f) - - decoding_start_secs = 0 - decoding_duration_secs = 100 - - with open(model_filename, "rb") as f: - model_results = pkl.load(f) - - decoder = model_results["decoder"] - Fs = model_results["Fs"] - clusterless_spike_times = model_results["clusterless_spike_times"] - features = model_results["features"] - linearized_positions = model_results["linearized_positions"] - - print("Decoding positions from features") - decoding_start_samples = int(decoding_start_secs * Fs) - decoding_duration_samples = int(decoding_duration_secs * Fs) - time_ind = slice(decoding_start_samples, decoding_start_samples + decoding_duration_samples) - time = np.arange(linearized_positions.linear_position.size) / Fs - decoding_results = decoder.predict(features[time_ind], time=time[time_ind]) - - print(f"Saving decoding results to {decoding_results_filename}") - - results = dict(decoding_results=decoding_results, time=time[time_ind], - linearized_positions=linearized_positions.iloc[time_ind], - features=features[time_ind]) - - with open(decoding_results_filename, "wb") as f: - pkl.dump(results, f) \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/scripts/sorted_spike_decoder.py b/examples/NeuralDecoding/PositionDecodingFromHippocampus/scripts/sorted_spike_decoder.py deleted file mode 100644 index 5733b92..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/scripts/sorted_spike_decoder.py +++ /dev/null @@ -1,129 +0,0 @@ -import numpy as np -import sys -import 
pandas as pd -import matplotlib.pyplot as plt -import pickle as pkl - -import replay_trajectory_classification as rtc -import track_linearization as tl - -import argparse - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--positions_filename", type=str, - default="../../../../datasets/decoder_data/position_info.pkl", - help="positions filename") - parser.add_argument("--spike_times_filename", type=str, - default="../../../../datasets/decoder_data/sorted_spike_times.pkl", - help="spike times filename") - parser.add_argument("--model_filename", type=str, - default="../../../../datasets/decoder_data/sorted_spike_decoder.pkl", - help="model filename") - parser.add_argument("--decoding_results_filename", type=str, - default="../../../../datasets/decoder_data/sorted_spike_decoding_results.pkl", - help="decoding results filename") - args = parser.parse_args() - - positions_filename = args.positions_filename - spikes_filename = args.spikes_filename - features_filename = args.features_filename - model_filename = args.model_filename - decoding_results_filename = args.decoding_results_filename - - positions_df = pd.read_pickle(positions_filename) - timestamps = positions_df.index.to_numpy() - dt = timestamps[1] - timestamps[0] - Fs = 1.0 / dt - spikes_bins = np.append(timestamps-dt/2, timestamps[-1]+dt/2) - - x = positions_df["nose_x"].to_numpy() - y = positions_df["nose_y"].to_numpy() - positions = np.column_stack((x, y)) - node_positions = [(120.0, 100.0), - ( 5.0, 100.0), - ( 5.0, 55.0), - (120.0, 55.0), - ( 5.0, 8.5), - (120.0, 8.5), - ] - edges = [ - (3, 2), - (0, 1), - (1, 2), - (5, 4), - (4, 2), - ] - track_graph = rtc.make_track_graph(node_positions, edges) - - edge_order = [ - (3, 2), - (0, 1), - (1, 2), - (5, 4), - (4, 2), - ] - - edge_spacing = [16, 0, 16, 0] - - linearized_positions = tl.get_linearized_position(positions, track_graph, edge_order=edge_order, edge_spacing=edge_spacing, use_HMM=False) - - with open(spikes_filename, "rb") as f: - sorted_spike_times = pkl.load(f) - - binned_spike_times = np.empty((len(timestamps), len(sorted_spike_times)), dtype=float) - for n in range(len(sorted_spike_times)): - binned_spike_times[:, n] = np.histogram(sorted_spike_times[n], spikes_bins)[0] - - place_bin_size = 0.5 - movement_var = 0.25 - - environment = rtc.Environment(place_bin_size=place_bin_size, - track_graph=track_graph, - edge_order=edge_order, - edge_spacing=edge_spacing) - - transition_type = rtc.RandomWalk(movement_var=movement_var) - - decoder = rtc.SortedSpikesDecoder( - environment=environment, - transition_type=transition_type, - ) - - print("Learning model parameters") - decoder.fit(linearized_positions.linear_position, binned_spike_times) - - print(f"Saving model to {model_filename}") - - results = dict(decoder=decoder, linearized_positions=linearized_positions, - binned_spike_times=binned_spike_times, Fs=Fs) - - with open(model_filename, "wb") as f: - pkl.dump(results, f) - - decoding_start_secs = 0 - decoding_duration_secs = 100 - - with open(model_filename, "rb") as f: - model_results = pkl.load(f) - - decoder = model_results["decoder"] - Fs = model_results["Fs"] - binned_spike_times = model_results["binned_spike_times"] - linearized_positions = model_results["linearized_positions"] - - print("Decoding positions from spikes") - decoding_start_samples = int(decoding_start_secs * Fs) - decoding_duration_samples = int(decoding_duration_secs * Fs) - time_ind = slice(decoding_start_samples, decoding_start_samples + 
decoding_duration_samples) - time = np.arange(linearized_positions.linear_position.size) / Fs - decoding_results = decoder.predict(binned_spike_times[time_ind], time=time[time_ind]) - - print(f"Saving decoding results to {decoding_results_filename}") - - results = dict(decoding_results=decoding_results, time=time[time_ind], - linearized_positions=linearized_positions.iloc[time_ind], - binned_spike_times=binned_spike_times[time_ind]) - - with open(decoding_results_filename, "wb") as f: - pkl.dump(results, f) \ No newline at end of file From ee9def48607e47a542f1c43d643bc2fa2a993775 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 19 Nov 2024 12:09:19 +0000 Subject: [PATCH 32/34] Renamed workflows --- .../ClusterlessSpikes.bonsai | 283 ------------------ .../DecodeClusterlessSpikes.bonsai | 77 +++++ .../DecodeSortedSpikes.bonsai | 77 +++++ .../SortedSpikes.bonsai | 282 ----------------- 4 files changed, 154 insertions(+), 565 deletions(-) delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai create mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeClusterlessSpikes.bonsai create mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeSortedSpikes.bonsai delete mode 100644 examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai deleted file mode 100644 index 19878c1..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/ClusterlessSpikes.bonsai +++ /dev/null @@ -1,283 +0,0 @@ - - - - - - - - - - - ImportDecoderLibrary - - - - Source1 - - - - import sys -import os -sys.path.append(os.getcwd()) -sys.path.append(os.path.join(os.getcwd(), ".venv/Scripts")) -from decoder import * - - - - - - - - - - - - LoadDecoder - - - - - - - - decoder = ClusterlessSpikeDecoder.load("../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int_50Hz.pkl") - - - - - - - - - - - - LoadData - - - - - - - - data = DataLoader.load_clusterless_spike_data() - - - - - - - - - - - - - - - IterateData - - - - - - - - PT0S - PT0.01S - - - - - - - - iterator = DataIterator(data) - - - - - - - - - - - iterator.next() - - - - it[0] - - - new(it[0] as Position, it[1] as Features, it[2] as Decoding, it[3] as PositionBins) - - - - - - - - - - - - - - - - - - Data - - - Performance - - - - Source1 - - - - - - Interval.TotalMilliseconds - - - 1000/it - - - - 10 - 1 - - - - - - - Source1 - - - - 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - RunDecoder - - - - Data - - - Features - - - - spikes - - - - - decoder.decode(spikes) - - - - - 0 - - - - OnlineDecoderResults - - - Data - - - Position - - - it.ToString() == "nan" ? 
double.NaN : Convert.ToDouble(it.ToString()) - - - Position - - - - - - - - - - - - - - - - VisualizeDecoder - - - - Position - - - - OnlineDecoderResults - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeClusterlessSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeClusterlessSpikes.bonsai new file mode 100644 index 0000000..dc0894b --- /dev/null +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeClusterlessSpikes.bonsai @@ -0,0 +1,77 @@ + + + + + + + + + + + + decoder + ../../../datasets/decoder_data/clusterless_spike_decoder.pkl + + + ../../../datasets/decoder_data/clusterless_spike_decoding_results.pkl + PT0.1S + + + Data + + + Data + + + Spikes + + + decoder + + + OnlineDecodedResults + + + VisualizeDecoder + + + + Data + + + Position + + + ConvertNaN + it.ToString() == "nan" ? double.NaN : Convert.ToDouble(it.ToString()) + + + + OnlineDecodedResults + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeSortedSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeSortedSpikes.bonsai new file mode 100644 index 0000000..20e1307 --- /dev/null +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeSortedSpikes.bonsai @@ -0,0 +1,77 @@ + + + + + + + + + + + + decoder + ../../../datasets/decoder_data/sorted_spike_decoder.pkl + + + ../../../datasets/decoder_data/sorted_spike_decoding_results.pkl + PT0.01S + + + Data + + + Data + + + Spikes + + + decoder + + + OnlineDecodedResults + + + VisualizeDecoder + + + + Data + + + Position + + + ConvertNaN + it.ToString() == "nan" ? double.NaN : Convert.ToDouble(it.ToString()) + + + + OnlineDecodedResults + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai deleted file mode 100644 index cdaaaec..0000000 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/SortedSpikes.bonsai +++ /dev/null @@ -1,282 +0,0 @@ - - - - - - - - - - - ImportDecoderLibrary - - - - Source1 - - - - import sys -import os -sys.path.append(os.getcwd()) -from decoder import * - - - - - - - - - - - - LoadDecoder - - - - - - - - model = SortedSpikeDecoder.load() - - - - - - - - - - - - LoadData - - - - - - - - data = DataLoader.load_sorted_spike_data() - - - - - - - - - - - - - - - IterateData - - - - - - - - PT0S - PT0.01S - - - - - - - - iterator = DataIterator(data) - - - - - - - - - - - iterator.next() - - - - it[0] - - - new(it[0] as Position, it[1] as Spikes, it[2] as Decoding, it[3] as PositionBins) - - - - - - - - - - - - - - - - - - Data - - - Performance - - - - Source1 - - - - - - Interval.TotalMilliseconds - - - 1000/it - - - - 10 - 1 - - - - - - - Source1 - - - - 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - RunDecoder - - - - Data - - - Spikes - - - - spikes - - - - - model.decode(spikes) - - - - - 0 - - - - OnlineDecoderResults - - - Data - - - Position - - - it.ToString() == "nan" ? 
double.NaN : Convert.ToDouble(it.ToString()) - - - Position - - - - - - - - - - - - - - - - VisualizeDecoder - - - - Position - - - - OnlineDecoderResults - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file From d3bd5789cda219e68a4b832f75d512c029a44105 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 19 Nov 2024 12:10:14 +0000 Subject: [PATCH 33/34] Updated notebooks and removed redundant loading from file --- .../PositionDecodingFromHippocampus/README.md | 2 +- ...er.ipynb => ClusterlessSpikeDecoder.ipynb} | 49 +++++-------------- ...Decoder.ipynb => SortedSpikeDecoder.ipynb} | 47 +++++------------- 3 files changed, 26 insertions(+), 72 deletions(-) rename examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/{ClusterlessDecoder.ipynb => ClusterlessSpikeDecoder.ipynb} (85%) rename examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/{SortedSpikesDecoder.ipynb => SortedSpikeDecoder.ipynb} (85%) diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md b/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md index c8edcb4..30f08dc 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md @@ -48,7 +48,7 @@ Alternatively, you can copy the `.bonsai\Bonsai.config` file into your Bonsai in The package contains 2 different models, one which takes as input sorted spike activity, and another which uses clusterless spike activity taken from raw ephys recordings. -You first need to train the decoder model and save it to disk. Open up the `notebooks` folder and select either `SortedSpikeDecoder.ipynb` or `ClusterlessDecoder.ipynb` depending on which model type you would like to use. Run the notebook. Once completed, this will create 2 new files: +You first need to train the decoder model and save it to disk. Open up the `notebooks` folder and select either `SortedSpikeDecoder.ipynb` or `ClusterlessSpikeDecoder.ipynb` depending on which model type you would like to use. Run the notebook. Once completed, this will create 2 new files: 1) `datasets\decoder_data\[ModelType]_decoder.pkl` for the trained decoder model 2) `datasets\decoder_data\[ModelType]_decoding_results.pkl` for the predictions. 
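The notebook diffs that follow trim the pickled model payload down to just the fitted decoder. As a sketch of the save/load round trip behind the two files listed above (the `dict(decoder=...)` layout is taken from the notebooks; the helper names here are assumptions):

```python
import pickle as pkl

def save_decoder(decoder, model_filename):
    # Persist only the fitted decoder; training data is no longer
    # bundled into the pickle after this patch.
    with open(model_filename, "wb") as f:
        pkl.dump(dict(decoder=decoder), f)

def load_decoder(model_filename):
    # Read the fitted decoder back, as the Bonsai workflows do at startup.
    with open(model_filename, "rb") as f:
        return pkl.load(f)["decoder"]
```

Keeping the payload small also means the Bonsai side only ever deserializes one object, so the workflow does not need to know how the training arrays were named.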
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessSpikeDecoder.ipynb similarity index 85% rename from examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessSpikeDecoder.ipynb index f102110..2561f7a 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessDecoder.ipynb +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessSpikeDecoder.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Learn" + "# Train the decoder and save the results" ] }, { @@ -14,9 +14,7 @@ "outputs": [], "source": [ "import numpy as np\n", - "import sys\n", "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", "import pickle as pkl\n", "\n", "import replay_trajectory_classification as rtc\n", @@ -32,8 +30,8 @@ "positions_filename = \"../../../../datasets/decoder_data/position_info.pkl\"\n", "spikes_filename = \"../../../../datasets/decoder_data/clusterless_spike_times.pkl\"\n", "features_filename = \"../../../../datasets/decoder_data/clusterless_spike_features.pkl\"\n", - "model_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int_50Hz.pkl\"\n", - "decoding_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoding_results_50Hz.pkl\"" + "model_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoder.pkl\"\n", + "decoding_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoding_results.pkl\"" ] }, { @@ -168,8 +166,7 @@ "source": [ "print(f\"Saving model to {model_filename}\")\n", "\n", - "results = dict(decoder=decoder, linear_position=linear_position,\n", - " features=features, Fs=Fs)\n", + "results = dict(decoder=decoder)\n", "\n", "with open(model_filename, \"wb\") as f:\n", " pkl.dump(results, f)" @@ -192,21 +189,6 @@ "decoding_duration_secs = 100" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with open(model_filename, \"rb\") as f:\n", - " model_results = pkl.load(f)\n", - " \n", - "decoder = model_results[\"decoder\"]\n", - "Fs = model_results[\"Fs\"]\n", - "features = model_results[\"features\"]\n", - "linear_position = model_results[\"linear_position\"]" - ] - }, { "cell_type": "code", "execution_count": null, @@ -227,23 +209,23 @@ "metadata": {}, "outputs": [], "source": [ - "print(f\"Saving decoding results to {decoding_filename}\")\n", + "print(f\"Saving decoded results to {decoding_filename}\")\n", "\n", - "results = dict(decoding_results=decoding_results, time=time[time_ind],\n", + "results = dict(decoding_results=decoding_results,\n", " linear_position=linear_position[time_ind],\n", - " features=features[time_ind])\n", + " spikes=features[time_ind])\n", "\n", "with open(decoding_filename, \"wb\") as f:\n", " pkl.dump(results, f)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "import plotly.graph_objects as go" + "## Optional\n", + "\n", + "Plot the decoded results" ] }, { @@ -252,12 +234,7 @@ "metadata": {}, "outputs": [], "source": [ - "with open(decoding_filename, \"rb\") as f:\n", - " load_res = pkl.load(f)\n", - "\n", - "decoding_results = load_res[\"decoding_results\"]\n", - "time = load_res[\"time\"]\n", - "linear_position = 
load_res[\"linear_position\"]" + "import plotly.graph_objects as go" ] }, { @@ -274,7 +251,7 @@ " zmin=0.00, zmax=0.05, showscale=False)\n", "fig.add_trace(trace)\n", "\n", - "trace = go.Scatter(x=time, y=linear_position,\n", + "trace = go.Scatter(x=time[time_ind], y=linear_position,\n", " mode=\"markers\", marker={\"color\": \"cyan\", \"size\": 5},\n", " name=\"position\", showlegend=True)\n", "fig.add_trace(trace)\n", diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikeDecoder.ipynb similarity index 85% rename from examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb rename to examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikeDecoder.ipynb index 5dde0a4..6090e25 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikesDecoder.ipynb +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikeDecoder.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Learn" + "# Train the decoder and save the results" ] }, { @@ -14,9 +14,7 @@ "outputs": [], "source": [ "import numpy as np\n", - "import sys\n", "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", "import pickle as pkl\n", "\n", "import replay_trajectory_classification as rtc\n", @@ -162,8 +160,7 @@ "source": [ "print(f\"Saving model to {model_filename}\")\n", "\n", - "results = dict(decoder=decoder, linear_position=linear_position,\n", - " spike_times=binned_spike_times, Fs=Fs)\n", + "results = dict(decoder=decoder)\n", "\n", "with open(model_filename, \"wb\") as f:\n", " pkl.dump(results, f)" @@ -173,7 +170,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Decode" + "# Decode position from spikes" ] }, { @@ -186,21 +183,6 @@ "decoding_duration_secs = 100" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with open(model_filename, \"rb\") as f:\n", - " model_results = pkl.load(f)\n", - " \n", - "decoder = model_results[\"decoder\"]\n", - "Fs = model_results[\"Fs\"]\n", - "spike_times = model_results[\"spike_times\"]\n", - "linear_position = model_results[\"linear_position\"]" - ] - }, { "cell_type": "code", "execution_count": null, @@ -212,7 +194,7 @@ "decoding_duration_samples = int(decoding_duration_secs * Fs)\n", "time_ind = slice(decoding_start_samples, decoding_start_samples + decoding_duration_samples)\n", "time = np.arange(linear_position.size) / Fs\n", - "decoding_results = decoder.predict(spike_times[time_ind], time=time[time_ind])" + "decoding_results = decoder.predict(binned_spike_times[time_ind], time=time[time_ind])" ] }, { @@ -223,21 +205,21 @@ "source": [ "print(f\"Saving decoding results to {decoding_filename}\")\n", "\n", - "results = dict(decoding_results=decoding_results, time=time[time_ind],\n", + "results = dict(decoding_results=decoding_results,\n", " linear_position=linear_position[time_ind],\n", - " spike_times=spike_times[time_ind])\n", + " spikes=binned_spike_times[time_ind])\n", "\n", "with open(decoding_filename, \"wb\") as f:\n", " pkl.dump(results, f)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "import plotly.graph_objects as go" + "## Optional\n", + "\n", + "Plot the decoded results" ] }, { @@ -246,12 +228,7 @@ "metadata": {}, "outputs": [], "source": [ - "with 
open(decoding_filename, \"rb\") as f:\n", - " load_res = pkl.load(f)\n", - "\n", - "decoding_results = load_res[\"decoding_results\"]\n", - "time = load_res[\"time\"]\n", - "linear_position = load_res[\"linear_position\"]" + "import plotly.graph_objects as go" ] }, { @@ -268,7 +245,7 @@ " zmin=0.00, zmax=0.05, showscale=False)\n", "fig.add_trace(trace)\n", "\n", - "trace = go.Scatter(x=time, y=linear_position,\n", + "trace = go.Scatter(x=time[time_ind], y=linear_position,\n", " mode=\"markers\", marker={\"color\": \"cyan\", \"size\": 5},\n", " name=\"position\", showlegend=True)\n", "fig.add_trace(trace)\n", From b0f1ca8776cec9add123cdb685d1723f421cb652 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 19 Nov 2024 14:23:33 +0000 Subject: [PATCH 34/34] Updated package version correctly --- .../.bonsai/Bonsai.config | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Bonsai.config b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Bonsai.config index 0f6f004..5216e96 100644 --- a/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Bonsai.config +++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Bonsai.config @@ -8,9 +8,11 @@ - - - + + + + + @@ -51,6 +53,8 @@ + + @@ -66,9 +70,11 @@ - - - + + + + +
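Taken together, the notebook edits in patch 33 settle on a single prediction pattern: slice a decoding window out of the session, then call `predict` with matching spikes and timestamps. Below is a condensed sketch under the same assumptions as the notebooks; the helper name `decode_window` is invented here, while the `decoder.predict(..., time=...)` call and the slicing arithmetic follow the notebook cells shown above.

```python
import numpy as np

def decode_window(decoder, binned_spike_times, linear_position, Fs,
                  start_secs=0.0, duration_secs=100.0):
    # Select the sample range covering [start_secs, start_secs + duration_secs).
    start = int(start_secs * Fs)
    stop = start + int(duration_secs * Fs)
    time_ind = slice(start, stop)
    # Timestamps in seconds, one per position/spike sample.
    time = np.arange(linear_position.size) / Fs
    # A fitted replay_trajectory_classification decoder takes the binned
    # spikes for the window plus the matching time vector.
    decoding_results = decoder.predict(binned_spike_times[time_ind],
                                       time=time[time_ind])
    return decoding_results, time_ind
```

Returning `time_ind` alongside the results keeps the plotting cells honest: the posterior heatmap and the overlaid `linear_position` trace are indexed with the same slice, which is exactly how the final plotly cells in both notebooks align the two.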