diff --git a/README.md b/README.md
index 092c553..59a475f 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,11 @@ All of the datasets used in these examples can be found by going to: [https://do
ZebrafishExampleVid.avi - provided by Nicholas Guilbeault in the Thiele lab at the University of Toronto. If you would like to refer to this data, please cite Guilbeault, N.C., Guerguiev, J., Martin, M. et al. (2021). BonZeb: open-source, modular software tools for high-resolution zebrafish tracking and analysis. *Scientific Reports* *11*, 8148, [https://doi.org/10.1038/s41598-021-85896-x](https://doi.org/10.1038/s41598-021-85896-x).
-ForagingMouseExampleVid.avi - provided by the Sainsbury Wellcome Centre Foraging Behaviour Working Group. (2023). Aeon: An open-source platform to study the neural basis of ethological behaviours over naturalistic timescales, [https://doi.org/10.5281/zenodo.8413142](https://doi.org/10.5281/zenodo.8413142)
+ForagingMouseExampleVid.avi - provided by the Sainsbury Wellcome Centre Foraging Behaviour Working Group. (2023). Aeon: An open-source platform to study the neural basis of ethological behaviours over naturalistic timescales, [https://doi.org/10.5281/zenodo.8413142](https://doi.org/10.5281/zenodo.8413142).
-ReceptiveFieldSimpleCell.zip - provided by the authors of "Touryan, J., Felsen, G., & Dan, Y. (2005). Spatial structure of complex cell receptive fields measured with natural images. Neuron, 45(5), 781-791." [https://doi.org/10.1016/j.neuron.2005.01.029](https://doi.org/10.1016/j.neuron.2005.01.029)
+ReceptiveFieldSimpleCell.zip - provided by the authors of "Touryan, J., Felsen, G., & Dan, Y. (2005). Spatial structure of complex cell receptive fields measured with natural images. Neuron, 45(5), 781-791." [https://doi.org/10.1016/j.neuron.2005.01.029](https://doi.org/10.1016/j.neuron.2005.01.029).
+
+HippocampalTetrodeRecording.zip - provided by the authors of Joshi, A., Denovellis, E.L., Mankili, A. et al. (2023). Dynamic synchronization between hippocampal representations and stepping. Nature 617, 125–131. [https://doi.org/10.1038/s41586-023-05928-6](https://doi.org/10.1038/s41586-023-05928-6).
### Acknowledgements
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Bonsai.config b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Bonsai.config
new file mode 100644
index 0000000..5216e96
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Bonsai.config
@@ -0,0 +1,116 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/NuGet.config b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/NuGet.config
new file mode 100644
index 0000000..97e8b73
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/NuGet.config
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Setup.ps1 b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Setup.ps1
new file mode 100644
index 0000000..76b5c46
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Setup.ps1
@@ -0,0 +1,21 @@
+Push-Location $PSScriptRoot
+if (!(Test-Path "./Bonsai.exe")) {
+ $release = "https://github.com/bonsai-rx/bonsai/releases/latest/download/Bonsai.zip"
+ $configPath = "./Bonsai.config"
+ if (Test-Path $configPath) {
+ [xml]$config = Get-Content $configPath
+ $bootstrapper = $config.PackageConfiguration.Packages.Package.where{$_.id -eq 'Bonsai'}
+ if ($bootstrapper) {
+ $version = $bootstrapper.version
+ $release = "https://github.com/bonsai-rx/bonsai/releases/download/$version/Bonsai.zip"
+ }
+ }
+ Invoke-WebRequest $release -OutFile "temp.zip"
+ Move-Item -Path "NuGet.config" "temp.config" -ErrorAction SilentlyContinue
+ Expand-Archive "temp.zip" -DestinationPath "." -Force
+ Move-Item -Path "temp.config" "NuGet.config" -Force -ErrorAction SilentlyContinue
+ Remove-Item -Path "temp.zip"
+ Remove-Item -Path "Bonsai32.exe"
+}
+& .\Bonsai.exe --no-editor
+Pop-Location
\ No newline at end of file
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Setup.sh b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Setup.sh
new file mode 100644
index 0000000..941d850
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/Setup.sh
@@ -0,0 +1,41 @@
+#! /bin/bash
+# Bootstraps a local Bonsai environment: resolves the release URL, downloads Bonsai.zip, unpacks it here.
+SETUP_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)"
+
+DEFAULT_VERSION="latest"
+VERSION="$DEFAULT_VERSION"
+
+while [[ "$#" -gt 0 ]]; do
+  case $1 in
+    --version) VERSION="$2"; shift ;;
+    *) echo "Unknown parameter passed: $1"; exit 1 ;;
+  esac
+  shift
+done
+
+echo "Setting up Bonsai v=$VERSION environment..."
+# Skip the download entirely if a Bonsai.exe is already present next to this script.
+if [ ! -f "$SETUP_SCRIPT_DIR/Bonsai.exe" ]; then
+  CONFIG="$SETUP_SCRIPT_DIR/Bonsai.config"
+  if [ -f "$CONFIG" ]; then
+    DETECTED=$(xmllint --xpath '//PackageConfiguration/Packages/Package[@id="Bonsai"]/@version' "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//')
+    echo "Version detected v=$DETECTED."
+    RELEASE="https://github.com/bonsai-rx/bonsai/releases/download/$DETECTED/Bonsai.zip"
+  else
+    if [ "$VERSION" = "latest" ]; then
+      RELEASE="https://github.com/bonsai-rx/bonsai/releases/latest/download/Bonsai.zip"
+    else
+      RELEASE="https://github.com/bonsai-rx/bonsai/releases/download/$VERSION/Bonsai.zip"
+    fi
+  fi
+  echo "Download URL: $RELEASE"
+  wget "$RELEASE" -O "$SETUP_SCRIPT_DIR/temp.zip"
+  mv -f "$SETUP_SCRIPT_DIR/NuGet.config" "$SETUP_SCRIPT_DIR/temp.config"
+  unzip -d "$SETUP_SCRIPT_DIR" -o "$SETUP_SCRIPT_DIR/temp.zip"
+  mv -f "$SETUP_SCRIPT_DIR/temp.config" "$SETUP_SCRIPT_DIR/NuGet.config"
+  rm -f "$SETUP_SCRIPT_DIR/temp.zip"
+  rm -f "$SETUP_SCRIPT_DIR/Bonsai32.exe"
+fi
+# Activate the environment and launch the workflow runner in the current shell.
+source "$SETUP_SCRIPT_DIR/activate"
+source "$SETUP_SCRIPT_DIR/run" --no-editor
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/activate b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/activate
new file mode 100644
index 0000000..ddf75f3
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/activate
@@ -0,0 +1,15 @@
+#!/bin/bash
+# activate.sh
+if [[ -v BONSAI_EXE_PATH ]]; then
+ echo "Error! Cannot have multiple bonsai environments activated at the same time. Please deactivate the current environment before activating the new one."
+ return
+fi
+BONSAI_ENV_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)"
+export BONSAI_ENV_DIR
+export BONSAI_EXE_PATH="$BONSAI_ENV_DIR/Bonsai.exe"
+export ORIGINAL_PS1="$PS1"
+export PS1="($(basename "$BONSAI_ENV_DIR")) $PS1"
+alias bonsai='source "$BONSAI_ENV_DIR"/run'
+alias bonsai-clean='GTK_DATA_PREFIX= source "$BONSAI_ENV_DIR"/run'
+alias deactivate='source "$BONSAI_ENV_DIR"/deactivate'
+echo "Activated bonsai environment in $BONSAI_ENV_DIR"
\ No newline at end of file
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/deactivate b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/deactivate
new file mode 100644
index 0000000..43233d9
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/deactivate
@@ -0,0 +1,8 @@
+#!/bin/bash
+unset BONSAI_EXE_PATH
+export PS1="$ORIGINAL_PS1"
+unset ORIGINAL_PS1
+unalias bonsai
+unalias bonsai-clean
+unalias deactivate
+echo "Deactivated bonsai environment."
\ No newline at end of file
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/run b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/run
new file mode 100644
index 0000000..bffd6cf
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/.bonsai/run
@@ -0,0 +1,58 @@
+#!/bin/bash
+# run.sh
+
+SETUP_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)"
+CONFIG="$SETUP_SCRIPT_DIR/Bonsai.config"
+
+cleanup() {
+ update_paths_to_windows
+}
+
+update_paths_to_linux() {
+ ASSEMBLYLOCATIONS=$(xmllint --xpath '//PackageConfiguration/AssemblyLocations/AssemblyLocation/@location' "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//')
+ for ASSEMBLYLOCATION in $ASSEMBLYLOCATIONS; do
+ NEWASSEMBLYLOCATION="${ASSEMBLYLOCATION//\\/\/}"
+ xmlstarlet edit --inplace --update "/PackageConfiguration/AssemblyLocations/AssemblyLocation[@location='$ASSEMBLYLOCATION']/@location" --value "$NEWASSEMBLYLOCATION" "$CONFIG"
+ done
+
+ LIBRARYFOLDERS=$(xmllint --xpath '//PackageConfiguration/LibraryFolders/LibraryFolder/@path' "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//')
+ for LIBRARYFOLDER in $LIBRARYFOLDERS; do
+ NEWLIBRARYFOLDER="${LIBRARYFOLDER//\\/\/}"
+ xmlstarlet edit --inplace --update "//PackageConfiguration/LibraryFolders/LibraryFolder[@path='$LIBRARYFOLDER']/@path" --value "$NEWLIBRARYFOLDER" "$CONFIG"
+ done
+}
+
+update_paths_to_windows() {
+ ASSEMBLYLOCATIONS=$(xmllint --xpath '//PackageConfiguration/AssemblyLocations/AssemblyLocation/@location' "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//')
+ for ASSEMBLYLOCATION in $ASSEMBLYLOCATIONS; do
+ NEWASSEMBLYLOCATION="${ASSEMBLYLOCATION//\//\\}"
+ xmlstarlet edit --inplace --update "/PackageConfiguration/AssemblyLocations/AssemblyLocation[@location='$ASSEMBLYLOCATION']/@location" --value "$NEWASSEMBLYLOCATION" "$CONFIG"
+ done
+
+ LIBRARYFOLDERS=$(xmllint --xpath '//PackageConfiguration/LibraryFolders/LibraryFolder/@path' "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//')
+ for LIBRARYFOLDER in $LIBRARYFOLDERS; do
+ NEWLIBRARYFOLDER="${LIBRARYFOLDER//\//\\}"
+ xmlstarlet edit --inplace --update "//PackageConfiguration/LibraryFolders/LibraryFolder[@path='$LIBRARYFOLDER']/@path" --value "$NEWLIBRARYFOLDER" "$CONFIG"
+ done
+}
+
+if [[ -v BONSAI_EXE_PATH ]]; then
+ if [ ! -f "$BONSAI_EXE_PATH" ]; then
+ bash "$BONSAI_ENV_DIR"/Setup.sh
+ bash "$BONSAI_ENV_DIR"/run "$@"
+ else
+ BONSAI_VERSION=$(xmllint --xpath "//PackageConfiguration/Packages/Package[@id='Bonsai']/@version" "$CONFIG" | sed -e 's/^[^"]*"//' -e 's/"$//')
+ if [[ -z ${BONSAI_VERSION+x} ]] && [ "$BONSAI_VERSION" \< "2.8.4" ]; then
+ echo "Updating paths to Linux format..."
+ trap cleanup EXIT INT TERM
+ update_paths_to_linux
+ mono "$BONSAI_EXE_PATH" "$@"
+ cleanup
+ else
+ mono "$BONSAI_EXE_PATH" "$@"
+ fi
+ fi
+else
+ echo "BONSAI_EXE_PATH is not set. Please set the path to the Bonsai executable."
+ return
+fi
\ No newline at end of file
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeClusterlessSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeClusterlessSpikes.bonsai
new file mode 100644
index 0000000..dc0894b
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeClusterlessSpikes.bonsai
@@ -0,0 +1,77 @@
+
+
+
+
+
+
+
+
+
+
+
+ decoder
+ ../../../datasets/decoder_data/clusterless_spike_decoder.pkl
+
+
+ ../../../datasets/decoder_data/clusterless_spike_decoding_results.pkl
+ PT0.1S
+
+
+ Data
+
+
+ Data
+
+
+ Spikes
+
+
+ decoder
+
+
+ OnlineDecodedResults
+
+
+ VisualizeDecoder
+
+
+
+ Data
+
+
+ Position
+
+
+ ConvertNaN
+ it.ToString() == "nan" ? double.NaN : Convert.ToDouble(it.ToString())
+
+
+
+ OnlineDecodedResults
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeSortedSpikes.bonsai b/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeSortedSpikes.bonsai
new file mode 100644
index 0000000..20e1307
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/DecodeSortedSpikes.bonsai
@@ -0,0 +1,77 @@
+
+
+
+
+
+
+
+
+
+
+
+ decoder
+ ../../../datasets/decoder_data/sorted_spike_decoder.pkl
+
+
+ ../../../datasets/decoder_data/sorted_spike_decoding_results.pkl
+ PT0.01S
+
+
+ Data
+
+
+ Data
+
+
+ Spikes
+
+
+ decoder
+
+
+ OnlineDecodedResults
+
+
+ VisualizeDecoder
+
+
+
+ Data
+
+
+ Position
+
+
+ ConvertNaN
+ it.ToString() == "nan" ? double.NaN : Convert.ToDouble(it.ToString())
+
+
+
+ OnlineDecodedResults
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md b/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md
new file mode 100644
index 0000000..30f08dc
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/README.md
@@ -0,0 +1,59 @@
+# Position Decoding from Hippocampal Sorted Spikes
+
+In the following example, you can find how to use the spike sorted decoder or clusterless spike decoder from [here](https://github.com/Eden-Kramer-Lab/replay_trajectory_classification/tree/master?tab=readme-ov-file) to decode position from hippocampal activity.
+
+## Dataset
+
+We thank Eric Denovellis for sharing his data and for his help with the decoder. If you use this example dataset, please consider citing the work: Joshi, A., Denovellis, E.L., Mankili, A. et al. Dynamic synchronization between hippocampal representations and stepping. Nature 617, 125–131 (2023). https://doi.org/10.1038/s41586-023-05928-6.
+
+You can download the HippocampalTetrodeRecording.zip file here: https://doi.org/10.5281/zenodo.10629221. The workflow expects the data to be placed into the datasets folder. It should be structured like this: `/path/to/machinelearning-examples/datasets/decoder_data/*`
+
+## Algorithm
+
+The neural decoder consists of a bayesian state-space model and point-processes to decode a latent variable (position) from neural spiking activity. To read more about the theory behind the model and how the algorithm works, we refer the reader to: Denovellis, E.L., Gillespie, A.K., Coulter, M.E., et al. Hippocampal replay of experience at real-world speeds. eLife 10, e64505 (2021). https://doi.org/10.7554/eLife.64505.
+
+## Installation
+
+### Python
+
+To install the package, run:
+
+```
+cd \path\to\examples\NeuralDecoding\PositionDecodingFromHippocampus
+python -m venv .venv
+.\.venv\Scripts\activate
+pip install git+https://github.com/ncguilbeault/bayesian-neural-decoder.git
+```
+
+You can test whether the installation was successful by launching python and running
+
+```python
+import bayesian_neural_decoder
+```
+
+### Bonsai
+
+You can bootstrap the bonsai environment using:
+
+```
+cd \path\to\examples\NeuralDecoding\PositionDecodingFromHippocampus
+dotnet new bonsaienv --allow-scripts yes
+```
+
+Alternatively, you can copy the `.bonsai\Bonsai.config` file into your Bonsai installation folder. You can test if it worked by opening Bonsai and searching for the `CreateRuntime` node, which should appear in the toolbox.
+
+## Usage
+
+### Training the decoder offline
+
+The package contains 2 different models, one which takes as input sorted spike activity, and another which uses clusterless spike activity taken from raw ephys recordings.
+
+You first need to train the decoder model and save it to disk. Open up the `notebooks` folder and select either `SortedSpikeDecoder.ipynb` or `ClusterlessSpikeDecoder.ipynb` depending on which model type you would like to use. Run the notebook. Once completed, this will create 2 new files:
+1) `datasets\decoder_data\[ModelType]_decoder.pkl` for the trained decoder model
+2) `datasets\decoder_data\[ModelType]_decoding_results.pkl` for the predictions.
+
+Both of these files are needed to run the decoder example in Bonsai.
+
+### Running the decoder online with Bonsai
+
+Launch the Bonsai.exe file inside of the .bonsai folder and open the workflow corresponding to the model type you used in the previous step. Press the `Start Workflow` button. The workflow may take some time to initialize and load the data. Once the workflow is running, open the `VisualizeDecoder` node to see the model's inference online with respect to the true position.
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessSpikeDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessSpikeDecoder.ipynb
new file mode 100644
index 0000000..2561f7a
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/ClusterlessSpikeDecoder.ipynb
@@ -0,0 +1,293 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Train the decoder and save the results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import pickle as pkl\n",
+ "\n",
+ "import replay_trajectory_classification as rtc\n",
+ "import track_linearization as tl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "positions_filename = \"../../../../datasets/decoder_data/position_info.pkl\"\n",
+ "spikes_filename = \"../../../../datasets/decoder_data/clusterless_spike_times.pkl\"\n",
+ "features_filename = \"../../../../datasets/decoder_data/clusterless_spike_features.pkl\"\n",
+ "model_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoder.pkl\"\n",
+ "decoding_filename = \"../../../../datasets/decoder_data/clusterless_spike_decoding_results.pkl\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "positions_df = pd.read_pickle(positions_filename)\n",
+ "timestamps = positions_df.index.to_numpy()\n",
+ "time_start = timestamps[0]\n",
+ "time_end = timestamps[-1]\n",
+ "dt = 0.02\n",
+ "Fs = 1.0 / dt\n",
+ "spikes_bins = np.arange(time_start - dt, time_end + dt, dt)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "positions_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = positions_df[\"nose_x\"].to_numpy()\n",
+ "y = positions_df[\"nose_y\"].to_numpy()\n",
+ "positions = np.column_stack((x, y))\n",
+ "node_positions = [(120.0, 100.0),\n",
+ " ( 5.0, 100.0),\n",
+ " ( 5.0, 55.0),\n",
+ " (120.0, 55.0),\n",
+ " ( 5.0, 8.5),\n",
+ " (120.0, 8.5),\n",
+ " ]\n",
+ "edges = [\n",
+ " (3, 2),\n",
+ " (0, 1),\n",
+ " (1, 2),\n",
+ " (5, 4),\n",
+ " (4, 2),\n",
+ " ]\n",
+ "track_graph = rtc.make_track_graph(node_positions, edges)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "edge_order = [\n",
+ " (3, 2),\n",
+ " (0, 1),\n",
+ " (1, 2),\n",
+ " (5, 4),\n",
+ " (4, 2),\n",
+ " ]\n",
+ "\n",
+ "edge_spacing = [16, 0, 16, 0]\n",
+ "\n",
+ "linearized_positions = tl.get_linearized_position(positions, track_graph, edge_order=edge_order, edge_spacing=edge_spacing, use_HMM=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(features_filename, \"rb\") as f:\n",
+ " clusterless_spike_features = pkl.load(f)\n",
+ "\n",
+ "with open(spikes_filename, \"rb\") as f:\n",
+ " clusterless_spike_times = pkl.load(f)\n",
+ "\n",
+ "features = np.ones((len(spikes_bins) - 1, len(clusterless_spike_features[0][0]), len(clusterless_spike_times)), dtype=float) * np.nan\n",
+ "for n in range(len(clusterless_spike_times)):\n",
+ " in_spikes_window = np.digitize(clusterless_spike_times[n], spikes_bins) - 1\n",
+ " features[in_spikes_window, :, n] = clusterless_spike_features[n]\n",
+ "\n",
+ "linear_position = np.ones(len(spikes_bins) - 1) * np.nan\n",
+ "in_position_window = np.digitize(positions_df.index, spikes_bins) - 1\n",
+ "linear_position[in_position_window] = linearized_positions.linear_position"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "place_bin_size = 0.5\n",
+ "movement_var = 0.25\n",
+ "\n",
+ "environment = rtc.Environment(place_bin_size=place_bin_size,\n",
+ " track_graph=track_graph,\n",
+ " edge_order=edge_order,\n",
+ " edge_spacing=edge_spacing)\n",
+ "\n",
+ "transition_type = rtc.RandomWalk(movement_var=movement_var)\n",
+ "\n",
+ "decoder = rtc.ClusterlessDecoder(\n",
+ " environment=environment,\n",
+ " transition_type=transition_type,\n",
+ " clusterless_algorithm=\"multiunit_likelihood_integer_gpu\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Learning model parameters\")\n",
+ "decoder.fit(linear_position, features)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(f\"Saving model to {model_filename}\")\n",
+ "\n",
+ "results = dict(decoder=decoder)\n",
+ "\n",
+ "with open(model_filename, \"wb\") as f:\n",
+ " pkl.dump(results, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Decode"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "decoding_start_secs = 0\n",
+ "decoding_duration_secs = 100"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Decoding positions from features\")\n",
+ "decoding_start_samples = int(decoding_start_secs * Fs)\n",
+ "decoding_duration_samples = int(decoding_duration_secs * Fs)\n",
+ "time_ind = slice(decoding_start_samples, decoding_start_samples + decoding_duration_samples)\n",
+ "time = np.arange(linear_position.size) / Fs\n",
+ "decoding_results = decoder.predict(features[time_ind], time=time[time_ind])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(f\"Saving decoded results to {decoding_filename}\")\n",
+ "\n",
+ "results = dict(decoding_results=decoding_results,\n",
+ " linear_position=linear_position[time_ind],\n",
+ " spikes=features[time_ind])\n",
+ "\n",
+ "with open(decoding_filename, \"wb\") as f:\n",
+ " pkl.dump(results, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Optional\n",
+ "\n",
+ "Plot the decoded results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import plotly.graph_objects as go"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = go.Figure()\n",
+ "\n",
+ "trace = go.Heatmap(z=decoding_results.acausal_posterior.T,\n",
+ " x=decoding_results.acausal_posterior.time,\n",
+ " y=decoding_results.acausal_posterior.position,\n",
+ " zmin=0.00, zmax=0.05, showscale=False)\n",
+ "fig.add_trace(trace)\n",
+ "\n",
+ "trace = go.Scatter(x=time[time_ind], y=linear_position,\n",
+ " mode=\"markers\", marker={\"color\": \"cyan\", \"size\": 5},\n",
+ " name=\"position\", showlegend=True)\n",
+ "fig.add_trace(trace)\n",
+ "\n",
+ "fig.update_xaxes(title=\"Time (sec)\")\n",
+ "fig.update_yaxes(title=\"Position (cm)\")\n",
+ "fig.update_coloraxes(showscale=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikeDecoder.ipynb b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikeDecoder.ipynb
new file mode 100644
index 0000000..6090e25
--- /dev/null
+++ b/examples/NeuralDecoding/PositionDecodingFromHippocampus/notebooks/SortedSpikeDecoder.ipynb
@@ -0,0 +1,287 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Train the decoder and save the results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import pickle as pkl\n",
+ "\n",
+ "import replay_trajectory_classification as rtc\n",
+ "import track_linearization as tl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "positions_filename = \"../../../../datasets/decoder_data/position_info.pkl\"\n",
+ "spikes_filename = \"../../../../datasets/decoder_data/sorted_spike_times.pkl\"\n",
+ "model_filename = \"../../../../datasets/decoder_data/sorted_spike_decoder.pkl\"\n",
+ "decoding_filename = \"../../../../datasets/decoder_data/sorted_spike_decoding_results.pkl\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "positions_df = pd.read_pickle(positions_filename)\n",
+ "timestamps = positions_df.index.to_numpy()\n",
+ "time_start = timestamps[0]\n",
+ "time_end = timestamps[-1]\n",
+ "dt = timestamps[1] - timestamps[0]\n",
+ "Fs = 1.0 / dt\n",
+ "spikes_bins = np.arange(time_start - dt, time_end + dt, dt)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "positions_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = positions_df[\"nose_x\"].to_numpy()\n",
+ "y = positions_df[\"nose_y\"].to_numpy()\n",
+ "positions = np.column_stack((x, y))\n",
+ "node_positions = [(120.0, 100.0),\n",
+ " ( 5.0, 100.0),\n",
+ " ( 5.0, 55.0),\n",
+ " (120.0, 55.0),\n",
+ " ( 5.0, 8.5),\n",
+ " (120.0, 8.5),\n",
+ " ]\n",
+ "edges = [\n",
+ " (3, 2),\n",
+ " (0, 1),\n",
+ " (1, 2),\n",
+ " (5, 4),\n",
+ " (4, 2),\n",
+ " ]\n",
+ "track_graph = rtc.make_track_graph(node_positions, edges)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "edge_order = [\n",
+ " (3, 2),\n",
+ " (0, 1),\n",
+ " (1, 2),\n",
+ " (5, 4),\n",
+ " (4, 2),\n",
+ " ]\n",
+ "\n",
+ "edge_spacing = [16, 0, 16, 0]\n",
+ "\n",
+ "linearized_positions = tl.get_linearized_position(positions, track_graph, edge_order=edge_order, edge_spacing=edge_spacing, use_HMM=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(spikes_filename, \"rb\") as f:\n",
+ " sorted_spike_times = pkl.load(f)\n",
+ "\n",
+ "binned_spike_times = np.empty((len(spikes_bins) - 1, len(sorted_spike_times)), dtype=float)\n",
+ "for n in range(len(sorted_spike_times)):\n",
+ " binned_spike_times[:, n] = np.histogram(sorted_spike_times[n], spikes_bins)[0]\n",
+ "\n",
+ "linear_position = np.ones(len(spikes_bins) - 1) * np.nan\n",
+ "in_position_window = np.digitize(positions_df.index, spikes_bins) - 1\n",
+ "linear_position[in_position_window] = linearized_positions.linear_position"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "place_bin_size = 0.5\n",
+ "movement_var = 0.25\n",
+ "\n",
+ "environment = rtc.Environment(place_bin_size=place_bin_size,\n",
+ " track_graph=track_graph,\n",
+ " edge_order=edge_order,\n",
+ " edge_spacing=edge_spacing)\n",
+ "\n",
+ "transition_type = rtc.RandomWalk(movement_var=movement_var)\n",
+ "\n",
+ "decoder = rtc.SortedSpikesDecoder(\n",
+ " environment=environment,\n",
+ " transition_type=transition_type,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Learning model parameters\")\n",
+ "decoder.fit(linear_position, binned_spike_times)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(f\"Saving model to {model_filename}\")\n",
+ "\n",
+ "results = dict(decoder=decoder)\n",
+ "\n",
+ "with open(model_filename, \"wb\") as f:\n",
+ " pkl.dump(results, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Decode position from spikes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "decoding_start_secs = 0\n",
+ "decoding_duration_secs = 100"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Decoding positions from spikes\")\n",
+ "decoding_start_samples = int(decoding_start_secs * Fs)\n",
+ "decoding_duration_samples = int(decoding_duration_secs * Fs)\n",
+ "time_ind = slice(decoding_start_samples, decoding_start_samples + decoding_duration_samples)\n",
+ "time = np.arange(linear_position.size) / Fs\n",
+ "decoding_results = decoder.predict(binned_spike_times[time_ind], time=time[time_ind])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(f\"Saving decoding results to {decoding_filename}\")\n",
+ "\n",
+ "results = dict(decoding_results=decoding_results,\n",
+ " linear_position=linear_position[time_ind],\n",
+ " spikes=binned_spike_times[time_ind])\n",
+ "\n",
+ "with open(decoding_filename, \"wb\") as f:\n",
+ " pkl.dump(results, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Optional\n",
+ "\n",
+ "Plot the decoded results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import plotly.graph_objects as go"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = go.Figure()\n",
+ "\n",
+ "trace = go.Heatmap(z=decoding_results.acausal_posterior.T,\n",
+ " x=decoding_results.acausal_posterior.time,\n",
+ " y=decoding_results.acausal_posterior.position,\n",
+ " zmin=0.00, zmax=0.05, showscale=False)\n",
+ "fig.add_trace(trace)\n",
+ "\n",
+ "trace = go.Scatter(x=time[time_ind], y=linear_position,\n",
+ " mode=\"markers\", marker={\"color\": \"cyan\", \"size\": 5},\n",
+ " name=\"position\", showlegend=True)\n",
+ "fig.add_trace(trace)\n",
+ "\n",
+ "fig.update_xaxes(title=\"Time (sec)\")\n",
+ "fig.update_yaxes(title=\"Position (cm)\")\n",
+ "fig.update_coloraxes(showscale=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}