diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index e5e6a07a..a463737f 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -1,7 +1,11 @@
name: Test
on:
push:
+ branches:
+ - main
pull_request:
+ branches:
+ - main
workflow_dispatch:
jobs:
devcontainer-build:
@@ -11,12 +15,7 @@ jobs:
strategy:
matrix:
py_ver: ["3.9", "3.10"]
- mysql_ver: ["8.0", "5.7"]
- include:
- - py_ver: "3.8"
- mysql_ver: "5.7"
- - py_ver: "3.7"
- mysql_ver: "5.7"
+ mysql_ver: ["8.0"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{matrix.py_ver}}
@@ -31,4 +30,3 @@ jobs:
run: |
python_version=${{matrix.py_ver}}
black element_array_ephys --check --verbose --target-version py${python_version//.}
-
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0d513df7..6d28ef11 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ exclude: (^.github/|^docs/|^images/)
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.4.0
+ rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
@@ -16,7 +16,7 @@ repos:
# black
- repo: https://github.com/psf/black
- rev: 22.12.0
+ rev: 24.2.0
hooks:
- id: black
- id: black-jupyter
@@ -25,7 +25,7 @@ repos:
# isort
- repo: https://github.com/pycqa/isort
- rev: 5.11.2
+ rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black"]
@@ -33,7 +33,7 @@ repos:
# flake8
- repo: https://github.com/pycqa/flake8
- rev: 4.0.1
+ rev: 7.0.0
hooks:
- id: flake8
args: # arguments to configure flake8
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 810a1ca1..6d136a29 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,14 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
+## [0.4.0] - 2025-01-28
+
++ Update - No longer support multiple variations of the ephys module; keep only the `ephys_no_curation` module, renamed to `ephys`
++ Update - Remove the other ephys modules (e.g. `ephys_acute`, `ephys_chronic`); these are moved to different branches
++ Feat - Add support for `SpikeInterface`
++ Update - Remove support for `ecephys_spike_sorting` (moved to a different branch)
++ Update - Simplify the "activate" mechanism
+
## [0.3.8] - 2025-01-16
* Fix - Revert GHA Semantic Release caller and update changelog
diff --git a/docs/docker-compose.yaml b/docs/docker-compose.yaml
index 5ba221df..bc2c2b8b 100644
--- a/docs/docker-compose.yaml
+++ b/docs/docker-compose.yaml
@@ -30,12 +30,6 @@ services:
export ELEMENT_UNDERSCORE=$$(echo $${PACKAGE} | sed 's/element_//g')
export ELEMENT_HYPHEN=$$(echo $${ELEMENT_UNDERSCORE} | sed 's/_/-/g')
export PATCH_VERSION=$$(cat /main/$${PACKAGE}/version.py | grep -oE '\d+\.\d+\.[a-z0-9]+')
- git clone https://github.com/datajoint/workflow-$${ELEMENT_HYPHEN}.git /main/delete || true
- if [ -d /main/delete/ ]; then
- mv /main/delete/workflow_$${ELEMENT_UNDERSCORE} /main/
- mv /main/delete/notebooks/*ipynb /main/docs/src/tutorials/
- rm -fR /main/delete
- fi
if echo "$${MODE}" | grep -i live &>/dev/null; then
mkdocs serve --config-file ./docs/mkdocs.yaml -a 0.0.0.0:80 2>&1 | tee docs/temp_mkdocs.log
elif echo "$${MODE}" | grep -iE "qa|push" &>/dev/null; then
diff --git a/docs/mkdocs.yaml b/docs/mkdocs.yaml
index 5fdbffd2..e211069a 100644
--- a/docs/mkdocs.yaml
+++ b/docs/mkdocs.yaml
@@ -9,18 +9,7 @@ nav:
- Concepts: concepts.md
- Tutorials:
- Overview: tutorials/index.md
- - Data Download: tutorials/00-data-download-optional.ipynb
- - Configure: tutorials/01-configure.ipynb
- - Workflow Structure: tutorials/02-workflow-structure-optional.ipynb
- - Process: tutorials/03-process.ipynb
- - Automate: tutorials/04-automate-optional.ipynb
- - Explore: tutorials/05-explore.ipynb
- - Drop: tutorials/06-drop-optional.ipynb
- - Downstream Analysis: tutorials/07-downstream-analysis.ipynb
- - Visualizations: tutorials/10-data_visualization.ipynb
- - Electrode Localization: tutorials/08-electrode-localization.ipynb
- - NWB Export: tutorials/09-NWB-export.ipynb
- - Quality Metrics: tutorials/quality_metrics.ipynb
+ - Tutorial: tutorials/tutorial.ipynb
- Citation: citation.md
- API: api/ # defer to gen-files + literate-nav
- Changelog: changelog.md
diff --git a/docs/src/concepts.md b/docs/src/concepts.md
index f864b306..b5da5081 100644
--- a/docs/src/concepts.md
+++ b/docs/src/concepts.md
@@ -59,12 +59,16 @@ significant community uptake:
Kilosort provides most automation and has gained significant popularity, being adopted
as one of the key spike sorting methods in the majority of the teams/collaborations we
have worked with. As part of our Year-1 NIH U24 effort, we provide support for data
-ingestion of spike sorting results from Kilosort. Further effort will be devoted for the
+ingestion of spike sorting results from Kilosort.
+
+Further effort has been devoted to the
ingestion support of other spike sorting methods. On this end, a framework for unifying
existing spike sorting methods, named
[SpikeInterface](https://github.com/SpikeInterface/spikeinterface), has been developed
by Alessio Buccino, et al. SpikeInterface provides a convenient Python-based wrapper to
-invoke, extract, compare spike sorting results from different sorting algorithms.
+invoke, extract, and compare spike sorting results from different sorting algorithms.
+As of version `0.4.0`, SpikeInterface is the primary spike sorting tool supported by
+Element Array Electrophysiology.
## Key Partnerships
@@ -95,22 +99,10 @@ Each of the DataJoint Elements creates a set of tables for common neuroscience d
modalities to organize, preprocess, and analyze data. Each node in the following diagram
is a table within the Element or a table connected to the Element.
-### `ephys_acute` module
+### `ephys` module

-### `ephys_chronic` module
-
-
-
-### `ephys_precluster` module
-
-
-
-### `ephys_no_curation` module
-
-
-
### `subject` schema ([API docs](https://datajoint.com/docs/elements/element-animal/api/element_animal/subject))
Although not required, most choose to connect the `Session` table to a `Subject` table.
@@ -181,12 +173,11 @@ Major features of the Array Electrophysiology Element include:
+ Probe-insertion, ephys-recordings, LFP extraction, clusterings, curations, sorted
units and the associated data (e.g. spikes, waveforms, etc.).
- + Store/track/manage different curations of the spike sorting results - supporting
- both curated clustering and kilosort triggered clustering (i.e., `no_curation`).
+ + Store/track/manage the spike sorting results.
+ Ingestion support for data acquired with SpikeGLX and OpenEphys acquisition systems.
-+ Ingestion support for spike sorting outputs from Kilosort.
-+ Triggering support for workflow integrated Kilosort processing.
++ Ingestion support for spike sorting outputs from SpikeInterface.
++ Triggering support for workflow-integrated SpikeInterface processing.
+ Sample data and complete test suite for quality assurance.
## Data Export and Publishing
@@ -208,8 +199,7 @@ pip install element-array-ephys[nwb]
## Roadmap
-Incorporation of SpikeInterface into the Array Electrophysiology Element will be
-on DataJoint Elements development roadmap. Dr. Loren Frank has led a development
+Dr. Loren Frank has led a development
effort of a DataJoint pipeline with SpikeInterface framework and
NeurodataWithoutBorders format integrated
[https://github.com/LorenFrankLab/nwb_datajoint](https://github.com/LorenFrankLab/nwb_datajoint).
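For orientation, here is a minimal SpikeInterface sketch of the load-and-sort flow this Element now wraps. This is not Element code; the data path, stream id, sorter name, and output folder are illustrative assumptions.

```python
import spikeinterface.full as si

# Read a raw SpikeGLX recording (path and stream id are hypothetical).
recording = si.read_spikeglx("/data/subject1/session1", stream_id="imec0.ap")

# Run any installed sorter through SpikeInterface's unified API;
# `folder` is the output location in recent SpikeInterface versions.
sorting = si.run_sorter("kilosort2_5", recording, folder="/data/processed/ks25")
print(sorting.unit_ids)
```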
diff --git a/docs/src/index.md b/docs/src/index.md
index b21edcfc..5d9b7f19 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -1,29 +1,23 @@
# Element Array Electrophysiology
This Element features DataJoint schemas for analyzing extracellular array
-electrophysiology data acquired with Neuropixels probes and spike sorted using Kilosort
-spike sorter. Each Element is a modular pipeline for data storage and processing with
+electrophysiology data acquired with Neuropixels probes and spike sorted using [SpikeInterface](https://github.com/SpikeInterface/spikeinterface).
+Each Element is a modular pipeline for data storage and processing with
corresponding database tables that can be combined with other Elements to assemble a
fully functional pipeline.

-The Element is comprised of `probe` and `ephys` schemas. Several `ephys` schemas are
-developed to handle various use cases of this pipeline and workflow:
-
-+ `ephys_acute`: A probe is inserted into a new location during each session.
-
-+ `ephys_chronic`: A probe is inserted once and used to record across multiple
- sessions.
-
-+ `ephys_precluster`: A probe is inserted into a new location during each session.
- Pre-clustering steps are performed on the data from each probe prior to Kilosort
- analysis.
-
-+ `ephys_no_curation`: A probe is inserted into a new location during each session and
- Kilosort-triggered clustering is performed without the option to manually curate the
- results.
-
-Visit the [Concepts page](./concepts.md) for more information about the use cases of
+The Element is composed of `probe` and `ephys` schemas. Visit the
+[Concepts page](./concepts.md) for more information about the `probe` and
`ephys` schemas and an explanation of the tables. To get started with building your own
data pipeline, visit the [Tutorials page](./tutorials/index.md).
+
+Prior to version `0.4.0`, several `ephys` schemas were developed and supported to
+handle various use cases of this pipeline and workflow. These are now deprecated but
+still available on their own branches within the repository:
+
+* [`ephys_acute`](https://github.com/datajoint/element-array-ephys/tree/main_ephys_acute)
+* [`ephys_chronic`](https://github.com/datajoint/element-array-ephys/tree/main_ephys_chronic)
+* [`ephys_precluster`](https://github.com/datajoint/element-array-ephys/tree/main_ephys_precluster)
+* [`ephys_no_curation`](https://github.com/datajoint/element-array-ephys/tree/main_ephys_no_curation)
diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md
index 5f367cd9..ff0bd1f5 100644
--- a/docs/src/tutorials/index.md
+++ b/docs/src/tutorials/index.md
@@ -1,14 +1,18 @@
# Tutorials
+## Executing the Tutorial Notebooks
+
+The tutorials are set up to run in GitHub Codespaces. To launch them, click the
+"Open in Codespaces" button in the GitHub repository. This opens a pre-configured
+environment with a VSCode IDE in your browser. The environment contains all the
+necessary dependencies and sample data needed to run the tutorials.
+
## Installation
Installation of the Element requires an integrated development environment and database.
Instructions to set up each of the components can be found on the
-[User Instructions](https://datajoint.com/docs/elements/user-guide/) page. These
-instructions use the example
-[workflow for Element Array Ephys](https://github.com/datajoint/workflow-array-ephys),
-which can be modified for a user's specific experimental requirements. This example
-workflow uses several Elements (Lab, Animal, Session, Event, and Electrophysiology) to construct
+[User Instructions](https://datajoint.com/docs/elements/user-guide/) page. The example
+tutorial uses several Elements (Lab, Animal, Session, Event, and Electrophysiology) to construct
a complete pipeline, and is able to ingest experimental metadata and run spike
sorting and downstream analysis.
@@ -23,32 +27,10 @@ Electrophysiology.
### Notebooks
Each of the notebooks in the workflow
-([download here](https://github.com/datajoint/workflow-array-ephys/tree/main/notebooks)
+([download here](https://github.com/datajoint/workflow-array-ephys/tree/main/notebooks))
steps through ways to interact with the Element itself. For convenience, these notebooks
are also rendered as part of this site. To try out the Elements notebooks in an online
Jupyter environment with access to example data, visit
[CodeBook](https://codebook.datajoint.io/). (Electrophysiology notebooks coming soon!)
-- [Data Download](./00-data-download-optional.ipynb) highlights how to use DataJoint
- tools to download a sample model for trying out the Element.
-- [Configure](./01-configure.ipynb) helps configure your local DataJoint installation to
- point to the correct database.
-- [Workflow Structure](./02-workflow-structure-optional.ipynb) demonstrates the table
- architecture of the Element and key DataJoint basics for interacting with these
- tables.
-- [Process](./03-process.ipynb) steps through adding data to these tables and launching
- key Electrophysiology features, like model training.
-- [Automate](./04-automate-optional.ipynb) highlights the same steps as above, but
- utilizing all built-in automation tools.
-- [Explore](./05-explore.ipynb) demonstrates how to fetch data from the Element.
-- [Drop schemas](./06-drop-optional.ipynb) provides the steps for dropping all the
- tables to start fresh.
-- [Downstream Analysis](./07-downstream-analysis.ipynb) highlights how to link
- this Element to Element Event for event-based analyses.
-- [Visualizations](./10-data_visualization.ipynb) highlights how to use a built-in module
- for visualizing units, probes and quality metrics.
-- [Electrode Localization](./08-electrode-localization.ipynb) demonstrates how to link
- this Element to
- [Element Electrode Localization](https://datajoint.com/docs/elements/element-electrode-localization/).
-- [NWB Export](./09-NWB-export.ipynb) highlights the export functionality available for the
- `no-curation` schema.
+- [Tutorial](../../../notebooks/tutorial.ipynb)
diff --git a/element_array_ephys/__init__.py b/element_array_ephys/__init__.py
index 1c0c7285..079950b4 100644
--- a/element_array_ephys/__init__.py
+++ b/element_array_ephys/__init__.py
@@ -1 +1,3 @@
-from . import ephys_acute as ephys
+from . import ephys
+
+ephys_no_curation = ephys # alias for backward compatibility
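A quick usage sketch of the alias above (illustrative only):

```python
from element_array_ephys import ephys, ephys_no_curation

# Both names refer to the same module after the 0.4.0 rename.
assert ephys is ephys_no_curation
```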
diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys.py
similarity index 65%
rename from element_array_ephys/ephys_no_curation.py
rename to element_array_ephys/ephys.py
index 1265445a..efa377bf 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys.py
@@ -10,10 +10,10 @@
import pandas as pd
from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
-from . import ephys_report, probe
+from . import probe
from .readers import kilosort, openephys, spikeglx
-log = dj.logger
+logger = dj.logger
schema = dj.schema()
@@ -22,7 +22,6 @@
def activate(
ephys_schema_name: str,
- probe_schema_name: str = None,
*,
create_schema: bool = True,
create_tables: bool = True,
@@ -32,7 +31,6 @@ def activate(
Args:
ephys_schema_name (str): A string containing the name of the ephys schema.
- probe_schema_name (str): A string containing the name of the probe schema.
create_schema (bool): If True, schema will be created in the database.
create_tables (bool): If True, tables related to the schema will be created in the database.
linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
@@ -46,7 +44,6 @@ def activate(
get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
get_session_directory(session_key: dict): Returns the path to the electrophysiology data for a particular session as a string.
get_processed_root_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory.
-
"""
if isinstance(linking_module, str):
@@ -58,16 +55,15 @@ def activate(
global _linking_module
_linking_module = linking_module
- probe.activate(
- probe_schema_name, create_schema=create_schema, create_tables=create_tables
- )
+ if not probe.schema.is_activated():
+ raise RuntimeError("Please activate the `probe` schema first.")
+
schema.activate(
ephys_schema_name,
create_schema=create_schema,
create_tables=create_tables,
add_objects=_linking_module.__dict__,
)
- ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
# -------------- Functions required by the elements-ephys ---------------
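With the implicit `probe.activate()` call removed, activation is now an explicit two-step sequence. A minimal sketch, assuming hypothetical schema names and a `workflow.pipeline` linking module:

```python
from element_array_ephys import probe, ephys

# Activate `probe` first; `ephys.activate()` now raises a RuntimeError otherwise.
probe.activate("my_probe_schema", create_schema=True, create_tables=True)
ephys.activate("my_ephys_schema", linking_module="workflow.pipeline")
```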
@@ -128,7 +124,7 @@ class AcquisitionSoftware(dj.Lookup):
"""
definition = """ # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys
- acq_software: varchar(24)
+ acq_software: varchar(24)
"""
contents = zip(["SpikeGLX", "Open Ephys"])
@@ -181,10 +177,7 @@ def auto_generate_entries(cls, session_key):
"probe_type": spikeglx_meta.probe_model,
"probe": spikeglx_meta.probe_SN,
}
- if (
- probe_key["probe"] not in [p["probe"] for p in probe_list]
- and probe_key not in probe.Probe()
- ):
+ if probe_key["probe"] not in [p["probe"] for p in probe_list]:
probe_list.append(probe_key)
probe_dir = meta_filepath.parent
@@ -208,10 +201,7 @@ def auto_generate_entries(cls, session_key):
"probe_type": oe_probe.probe_model,
"probe": oe_probe.probe_SN,
}
- if (
- probe_key["probe"] not in [p["probe"] for p in probe_list]
- and probe_key not in probe.Probe()
- ):
+ if probe_key["probe"] not in [p["probe"] for p in probe_list]:
probe_list.append(probe_key)
probe_insertion_list.append(
{
@@ -271,15 +261,24 @@ class EphysRecording(dj.Imported):
definition = """
# Ephys recording from a probe insertion for a given session.
- -> ProbeInsertion
+ -> ProbeInsertion
---
-> probe.ElectrodeConfig
-> AcquisitionSoftware
- sampling_rate: float # (Hz)
+ sampling_rate: float # (Hz)
recording_datetime: datetime # datetime of the recording from this probe
recording_duration: float # (seconds) duration of the recording from this probe
"""
+ class Channel(dj.Part):
+ definition = """
+ -> master
+ channel_idx: int # channel index (index of the raw data)
+ ---
+ -> probe.ElectrodeConfig.Electrode
+ channel_name="": varchar(64) # alias of the channel
+ """
+
class EphysFile(dj.Part):
"""Paths of electrophysiology recording files for each insertion.
@@ -303,7 +302,7 @@ def make(self, key):
"probe"
)
- # search session dir and determine acquisition software
+ # Search session dir and determine acquisition software
for ephys_pattern, ephys_acq_type in (
("*.ap.meta", "SpikeGLX"),
("*.oebin", "Open Ephys"),
@@ -314,9 +313,13 @@ def make(self, key):
break
else:
raise FileNotFoundError(
- f"Ephys recording data not found!"
- f" Neither SpikeGLX nor Open Ephys recording files found"
- f" in {session_dir}"
+ f"Ephys recording data not found in {session_dir}."
+ "Neither SpikeGLX nor Open Ephys recording files found"
+ )
+
+ if acq_software not in AcquisitionSoftware.fetch("acq_software"):
+ raise NotImplementedError(
+ f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented."
)
supported_probe_types = probe.ProbeType.fetch("probe_type")
@@ -325,51 +328,79 @@ def make(self, key):
for meta_filepath in ephys_meta_filepaths:
spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
+ spikeglx_meta_filepath = meta_filepath
break
else:
raise FileNotFoundError(
"No SpikeGLX data found for probe insertion: {}".format(key)
)
- if spikeglx_meta.probe_model in supported_probe_types:
- probe_type = spikeglx_meta.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
+ if spikeglx_meta.probe_model not in supported_probe_types:
+ raise NotImplementedError(
+ f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented."
+ )
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
+ probe_type = spikeglx_meta.probe_model
+ electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
- electrode_group_members = [
- probe_electrodes[(shank, shank_col, shank_row)]
- for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels probe model"
- " {} not yet implemented".format(spikeglx_meta.probe_model)
+ probe_electrodes = {
+ (shank, shank_col, shank_row): key
+ for key, shank, shank_col, shank_row in zip(
+ *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
)
+ } # electrode configuration
+ electrode_group_members = [
+ probe_electrodes[(shank, shank_col, shank_row)]
+ for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
+ ] # recording session-specific electrode configuration
+
+ econfig_entry, econfig_electrodes = generate_electrode_config_entry(
+ probe_type, electrode_group_members
+ )
- self.insert1(
+ ephys_recording_entry = {
+ **key,
+ "electrode_config_hash": econfig_entry["electrode_config_hash"],
+ "acq_software": acq_software,
+ "sampling_rate": spikeglx_meta.meta["imSampRate"],
+ "recording_datetime": spikeglx_meta.recording_time,
+ "recording_duration": (
+ spikeglx_meta.recording_duration
+ or spikeglx.retrieve_recording_duration(spikeglx_meta_filepath)
+ ),
+ }
+
+ root_dir = find_root_directory(
+ get_ephys_root_data_dir(), spikeglx_meta_filepath
+ )
+
+ ephys_file_entries = [
{
**key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": spikeglx_meta.meta["imSampRate"],
- "recording_datetime": spikeglx_meta.recording_time,
- "recording_duration": (
- spikeglx_meta.recording_duration
- or spikeglx.retrieve_recording_duration(meta_filepath)
- ),
+ "file_path": spikeglx_meta_filepath.relative_to(
+ root_dir
+ ).as_posix(),
}
- )
+ ]
- root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath)
- self.EphysFile.insert1(
- {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()}
- )
+ # Insert channel information
+ # Get channel and electrode-site mapping
+ channel2electrode_map = {
+ recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
+ for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
+ spikeglx_meta.shankmap["data"]
+ )
+ }
+
+ ephys_channel_entries = [
+ {
+ **key,
+ "electrode_config_hash": econfig_entry["electrode_config_hash"],
+ "channel_idx": channel_idx,
+ **channel_info,
+ }
+ for channel_idx, channel_info in channel2electrode_map.items()
+ ]
elif acq_software == "Open Ephys":
dataset = openephys.OpenEphys(session_dir)
for serial_number, probe_data in dataset.probes.items():
@@ -385,60 +416,84 @@ def make(self, key):
'No analog signals found - check "structure.oebin" file or "continuous" directory'
)
- if probe_data.probe_model in supported_probe_types:
- probe_type = probe_data.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_group_members = [
- probe_electrodes[channel_idx]
- for channel_idx in probe_data.ap_meta["channels_indices"]
- ]
- else:
+ if probe_data.probe_model not in supported_probe_types:
raise NotImplementedError(
- "Processing for neuropixels"
- " probe model {} not yet implemented".format(probe_data.probe_model)
+ f"Processing for neuropixels probe model {probe_data.probe_model} not yet implemented."
)
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": probe_data.ap_meta["sample_rate"],
- "recording_datetime": probe_data.recording_info[
- "recording_datetimes"
- ][0],
- "recording_duration": np.sum(
- probe_data.recording_info["recording_durations"]
- ),
- }
+ probe_type = probe_data.probe_model
+ electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
+
+ probe_electrodes = {
+ key["electrode"]: key for key in electrode_query.fetch("KEY")
+ } # electrode configuration
+
+ electrode_group_members = [
+ probe_electrodes[channel_idx]
+ for channel_idx in probe_data.ap_meta["channels_indices"]
+ ] # recording session-specific electrode configuration
+
+ econfig_entry, econfig_electrodes = generate_electrode_config_entry(
+ probe_type, electrode_group_members
)
+ ephys_recording_entry = {
+ **key,
+ "electrode_config_hash": econfig_entry["electrode_config_hash"],
+ "acq_software": acq_software,
+ "sampling_rate": probe_data.ap_meta["sample_rate"],
+ "recording_datetime": probe_data.recording_info["recording_datetimes"][
+ 0
+ ],
+ "recording_duration": np.sum(
+ probe_data.recording_info["recording_durations"]
+ ),
+ }
+
root_dir = find_root_directory(
get_ephys_root_data_dir(),
probe_data.recording_info["recording_files"][0],
)
- self.EphysFile.insert(
- [
- {**key, "file_path": fp.relative_to(root_dir).as_posix()}
- for fp in probe_data.recording_info["recording_files"]
- ]
- )
- # explicitly garbage collect "dataset"
- # as these may have large memory footprint and may not be cleared fast enough
+
+ ephys_file_entries = [
+ {**key, "file_path": fp.relative_to(root_dir).as_posix()}
+ for fp in probe_data.recording_info["recording_files"]
+ ]
+
+ channel2electrode_map = {
+ channel_idx: probe_electrodes[channel_idx]
+ for channel_idx in probe_data.ap_meta["channels_indices"]
+ }
+
+ ephys_channel_entries = [
+ {
+ **key,
+ "electrode_config_hash": econfig_entry["electrode_config_hash"],
+ "channel_idx": channel_idx,
+ **channel_info,
+ }
+ for channel_idx, channel_info in channel2electrode_map.items()
+ ]
+
+ # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough
del probe_data, dataset
gc.collect()
else:
raise NotImplementedError(
- f"Processing ephys files from"
- f" acquisition software of type {acq_software} is"
- f" not yet implemented"
+ f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented."
)
+ # Insert into probe.ElectrodeConfig (recording configuration)
+ if not probe.ElectrodeConfig & {
+ "electrode_config_hash": econfig_entry["electrode_config_hash"]
+ }:
+ probe.ElectrodeConfig.insert1(econfig_entry)
+ probe.ElectrodeConfig.Electrode.insert(econfig_electrodes)
+
+ self.insert1(ephys_recording_entry)
+ self.EphysFile.insert(ephys_file_entries)
+ self.Channel.insert(ephys_channel_entries)
+
@schema
class LFP(dj.Imported):
@@ -471,9 +526,9 @@ class Electrode(dj.Part):
definition = """
-> master
- -> probe.ElectrodeConfig.Electrode
+ -> probe.ElectrodeConfig.Electrode
---
- lfp: longblob # (uV) recorded lfp at this electrode
+ lfp: longblob # (uV) recorded lfp at this electrode
"""
# Only store LFP for every 9th channel, due to high channel density,
@@ -614,14 +669,14 @@ class ClusteringParamSet(dj.Lookup):
ClusteringMethod (dict): ClusteringMethod primary key.
paramset_desc (varchar(128) ): Description of the clustering parameter set.
param_set_hash (uuid): UUID hash for the parameter set.
- params (longblob): Set of clustering parameters
+ params (longblob): Set of clustering parameters.
"""
definition = """
# Parameter set to be used in a clustering procedure
paramset_idx: smallint
---
- -> ClusteringMethod
+ -> ClusteringMethod
paramset_desc: varchar(128)
param_set_hash: uuid
unique index (param_set_hash)
@@ -700,6 +755,7 @@ class ClusterQualityLabel(dj.Lookup):
("ok", "probably a single unit, but could be contaminated"),
("mua", "multi-unit activity"),
("noise", "bad unit"),
+ ("n.a.", "not available"),
]
@@ -724,18 +780,15 @@ class ClusteringTask(dj.Manual):
"""
@classmethod
- def infer_output_dir(
- cls, key, relative: bool = False, mkdir: bool = False
- ) -> pathlib.Path:
+ def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False):
"""Infer output directory if it is not provided.
Args:
key (dict): ClusteringTask primary key.
Returns:
- Expected clustering_output_dir based on the following convention:
- processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
- e.g.: sub4/sess1/probe_2/kilosort2_0
+ pathlib.Path: Expected clustering_output_dir based on the following convention: processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
+ e.g.: sub4/sess1/probe_2/kilosort2_0
"""
processed_dir = pathlib.Path(get_processed_root_data_dir())
session_dir = find_full_path(
@@ -758,7 +811,7 @@ def infer_output_dir(
if mkdir:
output_dir.mkdir(parents=True, exist_ok=True)
- log.info(f"{output_dir} created!")
+ logger.info(f"{output_dir} created!")
return output_dir.relative_to(processed_dir) if relative else output_dir
@@ -809,7 +862,7 @@ class Clustering(dj.Imported):
# Clustering Procedure
-> ClusteringTask
---
- clustering_time: datetime # time of generation of this set of clustering results
+ clustering_time: datetime # time of generation of this set of clustering results
package_version='': varchar(16)
"""
@@ -838,7 +891,7 @@ def make(self, key):
).fetch1("acq_software", "clustering_method", "params")
if "kilosort" in clustering_method:
- from element_array_ephys.readers import kilosort_triggering
+ from .spike_sorting import kilosort_triggering
# add additional probe-recording and channels details into `params`
params = {**params, **get_recording_channels_details(key)}
@@ -850,10 +903,6 @@ def make(self, key):
spikeglx_meta_filepath.parent
)
spikeglx_recording.validate_file("ap")
- run_CatGT = (
- params.pop("run_CatGT", True)
- and "_tcat." not in spikeglx_meta_filepath.stem
- )
if clustering_method.startswith("pykilosort"):
kilosort_triggering.run_pykilosort(
@@ -874,7 +923,7 @@ def make(self, key):
ks_output_dir=kilosort_dir,
params=params,
KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}',
- run_CatGT=run_CatGT,
+ run_CatGT=True,
)
run_kilosort.run_modules()
elif acq_software == "Open Ephys":
@@ -929,7 +978,7 @@ class CuratedClustering(dj.Imported):
definition = """
# Clustering results of the spike sorting step.
- -> Clustering
+ -> Clustering
"""
class Unit(dj.Part):
@@ -946,7 +995,7 @@ class Unit(dj.Part):
spike_depths (longblob): Array of depths associated with each spike, relative to each spike.
"""
- definition = """
+ definition = """
# Properties of a given unit from a round of clustering (and curation)
-> master
unit: int
@@ -956,85 +1005,187 @@ class Unit(dj.Part):
spike_count: int # how many spikes in this recording for this unit
spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording
spike_sites : longblob # array of electrode associated with each spike
- spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
+ spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
"""
def make(self, key):
"""Automated population of Unit information."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+ clustering_method, output_dir = (
+ ClusteringTask * ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir")
+ output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+
+ # Get channel and electrode-site mapping
+ electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
+ channel2electrode_map: dict[int, dict] = {
+ chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
+ }
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
- acq_software, sample_rate = (EphysRecording & key).fetch1(
- "acq_software", "sampling_rate"
- )
+ # Get sorter name and locate the sorting-analyzer output directory.
+ sorter_name = clustering_method.replace(".", "_")
+ si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
- sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate)
+ if si_sorting_analyzer_dir.exists(): # Read from spikeinterface outputs
+ import spikeinterface as si
+ from spikeinterface import sorters
- # ---------- Unit ----------
- # -- Remove 0-spike units
- withspike_idx = [
- i
- for i, u in enumerate(kilosort_dataset.data["cluster_ids"])
- if (kilosort_dataset.data["spike_clusters"] == u).any()
- ]
- valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx]
- valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx]
- # -- Get channel and electrode-site mapping
- channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software)
-
- # -- Spike-times --
- # spike_times_sec_adj > spike_times_sec > spike_times
- spike_time_key = (
- "spike_times_sec_adj"
- if "spike_times_sec_adj" in kilosort_dataset.data
- else (
- "spike_times_sec"
- if "spike_times_sec" in kilosort_dataset.data
- else "spike_times"
+ sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
+ si_sorting_: si.BaseSorting = si.load_extractor(
+ sorting_file, base_folder=output_dir
)
- )
- spike_times = kilosort_dataset.data[spike_time_key]
- kilosort_dataset.extract_spike_depths()
+ if si_sorting_.unit_ids.size == 0:
+ logger.info(
+ f"No units found in {sorting_file}. Skipping Unit ingestion..."
+ )
+ self.insert1(key)
+ return
- # -- Spike-sites and Spike-depths --
- spike_sites = np.array(
- [
- channel2electrodes[s]["electrode"]
- for s in kilosort_dataset.data["spike_sites"]
- ]
- )
- spike_depths = kilosort_dataset.data["spike_depths"]
-
- # -- Insert unit, label, peak-chn
- units = []
- for unit, unit_lbl in zip(valid_units, valid_unit_labels):
- if (kilosort_dataset.data["spike_clusters"] == unit).any():
- unit_channel, _ = kilosort_dataset.get_best_channel(unit)
- unit_spike_times = (
- spike_times[kilosort_dataset.data["spike_clusters"] == unit]
- / sample_rate
+ sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+ si_sorting = sorting_analyzer.sorting
+
+ # Find representative channel for each unit
+ unit_peak_channel: dict[int, np.ndarray] = (
+ si.ChannelSparsity.from_best_channels(
+ sorting_analyzer,
+ 1,
+ ).unit_id_to_channel_indices
+ )
+ unit_peak_channel: dict[int, int] = {
+ u: chn[0] for u, chn in unit_peak_channel.items()
+ }
+
+ spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
+ # {unit: spike_count}
+
+ # update channel2electrode_map to match with probe's channel index
+ channel2electrode_map = {
+ idx: channel2electrode_map[int(chn_idx)]
+ for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
+ }
+
+ # Get unit id to quality label mapping
+ cluster_quality_label_map = {
+ int(unit_id): (
+ si_sorting.get_unit_property(unit_id, "KSLabel")
+ if "KSLabel" in si_sorting.get_property_keys()
+ else "n.a."
)
- spike_count = len(unit_spike_times)
+ for unit_id in si_sorting.unit_ids
+ }
+
+ spike_locations = sorting_analyzer.get_extension("spike_locations")
+ extremum_channel_inds = si.template_tools.get_template_extremum_channel(
+ sorting_analyzer, outputs="index"
+ )
+ spikes_df = pd.DataFrame(
+ sorting_analyzer.sorting.to_spike_vector(
+ extremum_channel_inds=extremum_channel_inds
+ )
+ )
+
+ units = []
+ for unit_idx, unit_id in enumerate(si_sorting.unit_ids):
+ unit_id = int(unit_id)
+ unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx]
+ spike_sites = np.array(
+ [
+ channel2electrode_map[chn_idx]["electrode"]
+ for chn_idx in unit_spikes_df.channel_index
+ ]
+ )
+ unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
+ _, spike_depths = zip(*unit_spikes_loc) # x-coordinates, y-coordinates
+ spike_times = si_sorting.get_unit_spike_train(
+ unit_id, return_times=True
+ )
+
+ assert len(spike_times) == len(spike_sites) == len(spike_depths)
units.append(
{
- "unit": unit,
- "cluster_quality_label": unit_lbl,
- **channel2electrodes[unit_channel],
- "spike_times": unit_spike_times,
- "spike_count": spike_count,
- "spike_sites": spike_sites[
- kilosort_dataset.data["spike_clusters"] == unit
- ],
- "spike_depths": spike_depths[
- kilosort_dataset.data["spike_clusters"] == unit
- ],
+ **key,
+ **channel2electrode_map[unit_peak_channel[unit_id]],
+ "unit": unit_id,
+ "cluster_quality_label": cluster_quality_label_map[unit_id],
+ "spike_times": spike_times,
+ "spike_count": spike_count_dict[unit_id],
+ "spike_sites": spike_sites,
+ "spike_depths": spike_depths,
}
)
+ else: # read from kilosort outputs
+ kilosort_dataset = kilosort.Kilosort(output_dir)
+ acq_software, sample_rate = (EphysRecording & key).fetch1(
+ "acq_software", "sampling_rate"
+ )
+
+ sample_rate = kilosort_dataset.data["params"].get(
+ "sample_rate", sample_rate
+ )
+
+ # ---------- Unit ----------
+ # -- Remove 0-spike units
+ withspike_idx = [
+ i
+ for i, u in enumerate(kilosort_dataset.data["cluster_ids"])
+ if (kilosort_dataset.data["spike_clusters"] == u).any()
+ ]
+ valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx]
+ valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx]
+
+ # -- Spike-times --
+ # spike_times_sec_adj > spike_times_sec > spike_times
+ spike_time_key = (
+ "spike_times_sec_adj"
+ if "spike_times_sec_adj" in kilosort_dataset.data
+ else (
+ "spike_times_sec"
+ if "spike_times_sec" in kilosort_dataset.data
+ else "spike_times"
+ )
+ )
+ spike_times = kilosort_dataset.data[spike_time_key]
+ kilosort_dataset.extract_spike_depths()
+
+ # -- Spike-sites and Spike-depths --
+ spike_sites = np.array(
+ [
+ channel2electrode_map[s]["electrode"]
+ for s in kilosort_dataset.data["spike_sites"]
+ ]
+ )
+ spike_depths = kilosort_dataset.data["spike_depths"]
+
+ # -- Insert unit, label, peak-chn
+ units = []
+ for unit, unit_lbl in zip(valid_units, valid_unit_labels):
+ if (kilosort_dataset.data["spike_clusters"] == unit).any():
+ unit_channel, _ = kilosort_dataset.get_best_channel(unit)
+ unit_spike_times = (
+ spike_times[kilosort_dataset.data["spike_clusters"] == unit]
+ / sample_rate
+ )
+ spike_count = len(unit_spike_times)
+
+ units.append(
+ {
+ **key,
+ "unit": unit,
+ "cluster_quality_label": unit_lbl,
+ **channel2electrode_map[unit_channel],
+ "spike_times": unit_spike_times,
+ "spike_count": spike_count,
+ "spike_sites": spike_sites[
+ kilosort_dataset.data["spike_clusters"] == unit
+ ],
+ "spike_depths": spike_depths[
+ kilosort_dataset.data["spike_clusters"] == unit
+ ],
+ }
+ )
self.insert1(key)
- self.Unit.insert([{**key, **u} for u in units])
+ self.Unit.insert(units, ignore_extra_fields=True)
@schema
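Whichever branch runs, units land in `CuratedClustering.Unit`. A hypothetical read-back, assuming a `key` restriction:

```python
# One row per unit; spike trains are stored as arrays in each row.
units_df = (ephys.CuratedClustering.Unit & key).fetch(format="frame")
print(units_df[["cluster_quality_label", "spike_count"]])
```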
@@ -1082,116 +1233,180 @@ class Waveform(dj.Part):
# Spike waveforms and their mean across spikes for the given unit
-> master
-> CuratedClustering.Unit
- -> probe.ElectrodeConfig.Electrode
- ---
+ -> probe.ElectrodeConfig.Electrode
+ ---
waveform_mean: longblob # (uV) mean waveform across spikes of the given unit
waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit
"""
def make(self, key):
"""Populates waveform tables."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+ clustering_method, output_dir = (
+ ClusteringTask * ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir")
+ output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+ sorter_name = clustering_method.replace(".", "_")
+
+ self.insert1(key)
+ if not len(CuratedClustering.Unit & key):
+ logger.info(
+ f"No CuratedClustering.Unit found for {key}, skipping Waveform ingestion."
+ )
+ return
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
+ # Get channel and electrode-site mapping
+ electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
+ channel2electrode_map: dict[int, dict] = {
+ chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
+ }
- acq_software, probe_serial_number = (
- EphysRecording * ProbeInsertion & key
- ).fetch1("acq_software", "probe")
+ si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
+ if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
+ import spikeinterface as si
- # -- Get channel and electrode-site mapping
- recording_key = (EphysRecording & key).fetch1("KEY")
- channel2electrodes = get_neuropixels_channel2electrode_map(
- recording_key, acq_software
- )
+ sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
- # Get all units
- units = {
- u["unit"]: u
- for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit")
- }
+ # Find representative channel for each unit
+ unit_peak_channel: dict[int, np.ndarray] = (
+ si.ChannelSparsity.from_best_channels(
+ sorting_analyzer, 1
+ ).unit_id_to_channel_indices
+ ) # {unit: peak_channel_index}
+ unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
- if (kilosort_dir / "mean_waveforms.npy").exists():
- unit_waveforms = np.load(
- kilosort_dir / "mean_waveforms.npy"
- ) # unit x channel x sample
+ # update channel2electrode_map to match with probe's channel index
+ channel2electrode_map = {
+ idx: channel2electrode_map[int(chn_idx)]
+ for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
+ }
+
+ templates = sorting_analyzer.get_extension("templates")
def yield_unit_waveforms():
- for unit_no, unit_waveform in zip(
- kilosort_dataset.data["cluster_ids"], unit_waveforms
+ for unit in (CuratedClustering.Unit & key).fetch(
+ "KEY", order_by="unit"
):
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
- if unit_no in units:
+ # Get mean waveform for this unit from all channels - (sample x channel)
+ unit_waveforms = templates.get_unit_template(
+ unit_id=unit["unit"], operator="average"
+ )
+ unit_peak_waveform = {
+ **unit,
+ "peak_electrode_waveform": unit_waveforms[
+ :, unit_peak_channel[unit["unit"]]
+ ],
+ }
+
+ unit_electrode_waveforms = [
+ {
+ **unit,
+ **channel2electrode_map[chn_idx],
+ "waveform_mean": unit_waveforms[:, chn_idx],
+ }
+ for chn_idx in channel2electrode_map
+ ]
+
+ yield unit_peak_waveform, unit_electrode_waveforms
+
+ else: # read from kilosort outputs (ecephys pipeline)
+ kilosort_dataset = kilosort.Kilosort(output_dir)
+
+ acq_software, probe_serial_number = (
+ EphysRecording * ProbeInsertion & key
+ ).fetch1("acq_software", "probe")
+
+ # Get all units
+ units = {
+ u["unit"]: u
+ for u in (CuratedClustering.Unit & key).fetch(
+ as_dict=True, order_by="unit"
+ )
+ }
+
+ if (output_dir / "mean_waveforms.npy").exists():
+ unit_waveforms = np.load(
+ output_dir / "mean_waveforms.npy"
+ ) # unit x channel x sample
+
+ def yield_unit_waveforms():
+ for unit_no, unit_waveform in zip(
+ kilosort_dataset.data["cluster_ids"], unit_waveforms
+ ):
+ unit_peak_waveform = {}
+ unit_electrode_waveforms = []
+ if unit_no in units:
+ for channel, channel_waveform in zip(
+ kilosort_dataset.data["channel_map"], unit_waveform
+ ):
+ unit_electrode_waveforms.append(
+ {
+ **units[unit_no],
+ **channel2electrode_map[channel],
+ "waveform_mean": channel_waveform,
+ }
+ )
+ if (
+ channel2electrode_map[channel]["electrode"]
+ == units[unit_no]["electrode"]
+ ):
+ unit_peak_waveform = {
+ **units[unit_no],
+ "peak_electrode_waveform": channel_waveform,
+ }
+ yield unit_peak_waveform, unit_electrode_waveforms
+
+ else:
+ if acq_software == "SpikeGLX":
+ spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
+ neuropixels_recording = spikeglx.SpikeGLX(
+ spikeglx_meta_filepath.parent
+ )
+ elif acq_software == "Open Ephys":
+ session_dir = find_full_path(
+ get_ephys_root_data_dir(), get_session_directory(key)
+ )
+ openephys_dataset = openephys.OpenEphys(session_dir)
+ neuropixels_recording = openephys_dataset.probes[
+ probe_serial_number
+ ]
+
+ def yield_unit_waveforms():
+ for unit_dict in units.values():
+ unit_peak_waveform = {}
+ unit_electrode_waveforms = []
+
+ spikes = unit_dict["spike_times"]
+ waveforms = neuropixels_recording.extract_spike_waveforms(
+ spikes, kilosort_dataset.data["channel_map"]
+ ) # (sample x channel x spike)
+ waveforms = waveforms.transpose(
+ (1, 2, 0)
+ ) # (channel x spike x sample)
for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], unit_waveform
+ kilosort_dataset.data["channel_map"], waveforms
):
unit_electrode_waveforms.append(
{
- **units[unit_no],
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform,
+ **unit_dict,
+ **channel2electrode_map[channel],
+ "waveform_mean": channel_waveform.mean(axis=0),
+ "waveforms": channel_waveform,
}
)
if (
- channel2electrodes[channel]["electrode"]
- == units[unit_no]["electrode"]
+ channel2electrode_map[channel]["electrode"]
+ == unit_dict["electrode"]
):
unit_peak_waveform = {
- **units[unit_no],
- "peak_electrode_waveform": channel_waveform,
+ **unit_dict,
+ "peak_electrode_waveform": channel_waveform.mean(
+ axis=0
+ ),
}
- yield unit_peak_waveform, unit_electrode_waveforms
- else:
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
- openephys_dataset = openephys.OpenEphys(session_dir)
- neuropixels_recording = openephys_dataset.probes[probe_serial_number]
-
- def yield_unit_waveforms():
- for unit_dict in units.values():
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
-
- spikes = unit_dict["spike_times"]
- waveforms = neuropixels_recording.extract_spike_waveforms(
- spikes, kilosort_dataset.data["channel_map"]
- ) # (sample x channel x spike)
- waveforms = waveforms.transpose(
- (1, 2, 0)
- ) # (channel x spike x sample)
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], waveforms
- ):
- unit_electrode_waveforms.append(
- {
- **unit_dict,
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform.mean(axis=0),
- "waveforms": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == unit_dict["electrode"]
- ):
- unit_peak_waveform = {
- **unit_dict,
- "peak_electrode_waveform": channel_waveform.mean(
- axis=0
- ),
- }
-
- yield unit_peak_waveform, unit_electrode_waveforms
+ yield unit_peak_waveform, unit_electrode_waveforms
# insert waveform on a per-unit basis to mitigate potential memory issue
- self.insert1(key)
for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms():
if unit_peak_waveform:
self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True)
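A hypothetical read-back of a unit's peak waveform, assuming the master table keeps its prior name `WaveformSet` (the name is not shown in this hunk):

```python
# (uV) mean waveform at the unit's peak electrode; `unit_key` is illustrative.
peak_wf = (ephys.WaveformSet.PeakWaveform & unit_key).fetch1(
    "peak_electrode_waveform"
)
```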
@@ -1209,7 +1424,7 @@ class QualityMetrics(dj.Imported):
definition = """
# Clusters and waveforms metrics
- -> CuratedClustering
+ -> CuratedClustering
"""
class Cluster(dj.Part):
@@ -1234,26 +1449,26 @@ class Cluster(dj.Part):
contamination_rate (float): Frequency of spikes in the refractory period.
"""
- definition = """
+ definition = """
# Cluster metrics for a particular unit
-> master
-> CuratedClustering.Unit
---
- firing_rate=null: float # (Hz) firing rate for a unit
+ firing_rate=null: float # (Hz) firing rate for a unit
snr=null: float # signal-to-noise ratio for a unit
presence_ratio=null: float # fraction of time in which spikes are present
isi_violation=null: float # rate of ISI violation as a fraction of overall rate
number_violation=null: int # total number of ISI violations
amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram
isolation_distance=null: float # distance to nearest cluster in Mahalanobis space
- l_ratio=null: float #
+ l_ratio=null: float #
d_prime=null: float # Classification accuracy based on LDA
nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster
nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster
silhouette_score=null: float # Standard metric for cluster overlap
max_drift=null: float # Maximum change in spike depth throughout recording
- cumulative_drift=null: float # Cumulative change in spike depth throughout recording
- contamination_rate=null: float #
+ cumulative_drift=null: float # Cumulative change in spike depth throughout recording
+ contamination_rate=null: float #
"""
class Waveform(dj.Part):
@@ -1273,13 +1488,13 @@ class Waveform(dj.Part):
velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe.
"""
- definition = """
+ definition = """
# Waveform metrics for a particular unit
-> master
-> CuratedClustering.Unit
---
- amplitude: float # (uV) absolute difference between waveform peak and trough
- duration: float # (ms) time between waveform peak and trough
+ amplitude=null: float # (uV) absolute difference between waveform peak and trough
+ duration=null: float # (ms) time between waveform peak and trough
halfwidth=null: float # (ms) spike width at half max amplitude
pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0
repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak
@@ -1291,30 +1506,75 @@ class Waveform(dj.Part):
def make(self, key):
"""Populates tables with quality metrics data."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+ # Locate the clustering output directory
+ clustering_method, output_dir = (
+ ClusteringTask * ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir")
+ output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+ sorter_name = clustering_method.replace(".", "_")
- metric_fp = kilosort_dir / "metrics.csv"
- rename_dict = {
- "isi_viol": "isi_violation",
- "num_viol": "number_violation",
- "contam_rate": "contamination_rate",
- }
+ self.insert1(key)
+ if not len(CuratedClustering.Unit & key):
+ logger.info(
+ f"No CuratedClustering.Unit found for {key}, skipping QualityMetrics ingestion."
+ )
+ return
+
+ si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
+ if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
+ import spikeinterface as si
+
+ sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+ qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
+ template_metrics = sorting_analyzer.get_extension(
+ "template_metrics"
+ ).get_data()
+ metrics_df = pd.concat([qc_metrics, template_metrics], axis=1)
+
+ metrics_df.rename(
+ columns={
+ "amplitude_median": "amplitude",
+ "isi_violations_ratio": "isi_violation",
+ "isi_violations_count": "number_violation",
+ "silhouette": "silhouette_score",
+ "rp_contamination": "contamination_rate",
+ "drift_ptp": "max_drift",
+ "drift_mad": "cumulative_drift",
+ "half_width": "halfwidth",
+ "peak_trough_ratio": "pt_ratio",
+ "peak_to_valley": "duration",
+ },
+ inplace=True,
+ )
+ else: # read from kilosort outputs (ecephys pipeline)
+ # find metric_fp
+ for metric_fp in [
+ output_dir / "metrics.csv",
+ ]:
+ if metric_fp.exists():
+ break
+ else:
+ raise FileNotFoundError(f"QC metrics file not found in: {output_dir}")
+
+ metrics_df = pd.read_csv(metric_fp)
- if not metric_fp.exists():
- raise FileNotFoundError(f"QC metrics file not found: {metric_fp}")
+ # Conform the dataframe to match the table definition
+ if "cluster_id" in metrics_df.columns:
+ metrics_df.set_index("cluster_id", inplace=True)
+ else:
+ metrics_df.rename(
+ columns={metrics_df.columns[0]: "cluster_id"}, inplace=True
+ )
+ metrics_df.set_index("cluster_id", inplace=True)
+
+ metrics_df.columns = metrics_df.columns.str.lower()
- metrics_df = pd.read_csv(metric_fp)
- metrics_df.set_index("cluster_id", inplace=True)
metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
- metrics_df.columns = metrics_df.columns.str.lower()
- metrics_df.rename(columns=rename_dict, inplace=True)
metrics_list = [
dict(metrics_df.loc[unit_key["unit"]], **unit_key)
for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
]
- self.insert1(key)
self.Cluster.insert(metrics_list, ignore_extra_fields=True)
self.Waveform.insert(metrics_list, ignore_extra_fields=True)
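Once populated, the conformed metrics can be pulled per unit. An illustrative query, assuming a `key` restriction:

```python
# One row per unit; column names follow the renames applied above.
metrics_df = (ephys.QualityMetrics.Cluster & key).fetch(format="frame")
print(metrics_df[["firing_rate", "snr", "isi_violation"]].describe())
```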
@@ -1382,99 +1642,6 @@ def get_openephys_probe_data(ephys_recording_key: dict) -> list:
return probe_data
-def get_neuropixels_channel2electrode_map(
- ephys_recording_key: dict, acq_software: str
-) -> dict:
- """Get the channel map for neuropixels probe."""
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath)
- electrode_config_key = (
- EphysRecording * probe.ElectrodeConfig & ephys_recording_key
- ).fetch1("KEY")
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
- & electrode_config_key
- )
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- channel2electrode_map = {
- recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
- for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
- spikeglx_meta.shankmap["data"]
- )
- }
- elif acq_software == "Open Ephys":
- probe_dataset = get_openephys_probe_data(ephys_recording_key)
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording
- & ephys_recording_key
- )
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- channel2electrode_map = {
- channel_idx: probe_electrodes[channel_idx]
- for channel_idx in probe_dataset.ap_meta["channels_indices"]
- }
-
- return channel2electrode_map
-
-
-def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict:
- """Generate and insert new ElectrodeConfig
-
- Args:
- probe_type (str): probe type (e.g. neuropixels 2.0 - SS)
- electrode_keys (list): list of keys of the probe.ProbeType.Electrode table
-
- Returns:
- dict: representing a key of the probe.ElectrodeConfig table
- """
- # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode)
- electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
-
- electrode_list = sorted([k["electrode"] for k in electrode_keys])
- electrode_gaps = (
- [-1]
- + np.where(np.diff(electrode_list) > 1)[0].tolist()
- + [len(electrode_list) - 1]
- )
- electrode_config_name = "; ".join(
- [
- f"{electrode_list[start + 1]}-{electrode_list[end]}"
- for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])
- ]
- )
-
- electrode_config_key = {"electrode_config_hash": electrode_config_hash}
-
- # ---- make new ElectrodeConfig if needed ----
- if not probe.ElectrodeConfig & electrode_config_key:
- probe.ElectrodeConfig.insert1(
- {
- **electrode_config_key,
- "probe_type": probe_type,
- "electrode_config_name": electrode_config_name,
- }
- )
- probe.ElectrodeConfig.Electrode.insert(
- {**electrode_config_key, **electrode} for electrode in electrode_keys
- )
-
- return electrode_config_key
-
-
def get_recording_channels_details(ephys_recording_key: dict) -> np.array:
"""Get details of recording channels for a given recording."""
channels_details = {}
@@ -1530,3 +1697,41 @@ def get_recording_channels_details(ephys_recording_key: dict) -> np.array:
)
return channels_details
+
+
+def generate_electrode_config_entry(probe_type: str, electrode_keys: list) -> tuple:
+ """Generate new ElectrodeConfig entries without inserting them.
+
+ Args:
+ probe_type (str): probe type (e.g. neuropixels 2.0 - SS)
+ electrode_keys (list): list of keys of the probe.ProbeType.Electrode table
+
+ Returns:
+ tuple: an ElectrodeConfig entry (dict keyed for the probe.ElectrodeConfig table) and the list of its probe.ElectrodeConfig.Electrode entries.
+ """
+ # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode)
+ electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
+
+ electrode_list = sorted([k["electrode"] for k in electrode_keys])
+ electrode_gaps = (
+ [-1]
+ + np.where(np.diff(electrode_list) > 1)[0].tolist()
+ + [len(electrode_list) - 1]
+ )
+ electrode_config_name = "; ".join(
+ [
+ f"{electrode_list[start + 1]}-{electrode_list[end]}"
+ for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])
+ ]
+ )
+ electrode_config_key = {"electrode_config_hash": electrode_config_hash}
+ econfig_entry = {
+ **electrode_config_key,
+ "probe_type": probe_type,
+ "electrode_config_name": electrode_config_name,
+ }
+ econfig_electrodes = [
+ {**electrode, **electrode_config_key} for electrode in electrode_keys
+ ]
+
+ return econfig_entry, econfig_electrodes
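Since insertion now lives with the caller, usage mirrors the `make()` logic earlier in this file (values illustrative):

```python
econfig_entry, econfig_electrodes = generate_electrode_config_entry(
    probe_type, electrode_group_members
)
# Insert only if this electrode configuration has not been seen before.
if not probe.ElectrodeConfig & {
    "electrode_config_hash": econfig_entry["electrode_config_hash"]
}:
    probe.ElectrodeConfig.insert1(econfig_entry)
    probe.ElectrodeConfig.Electrode.insert(econfig_electrodes)
```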
diff --git a/element_array_ephys/ephys_acute.py b/element_array_ephys/ephys_acute.py
deleted file mode 100644
index f93a66a4..00000000
--- a/element_array_ephys/ephys_acute.py
+++ /dev/null
@@ -1,1593 +0,0 @@
-import gc
-import importlib
-import inspect
-import pathlib
-import re
-from decimal import Decimal
-
-import datajoint as dj
-import numpy as np
-import pandas as pd
-from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
-
-from . import ephys_report, probe
-from .readers import kilosort, openephys, spikeglx
-
-log = dj.logger
-
-schema = dj.schema()
-
-_linking_module = None
-
-
-def activate(
- ephys_schema_name: str,
- probe_schema_name: str = None,
- *,
- create_schema: bool = True,
- create_tables: bool = True,
- linking_module: str = None,
-):
- """Activates the `ephys` and `probe` schemas.
-
- Args:
- ephys_schema_name (str): A string containing the name of the ephys schema.
- probe_schema_name (str): A string containing the name of the probe schema.
- create_schema (bool): If True, schema will be created in the database.
- create_tables (bool): If True, tables related to the schema will be created in the database.
- linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
-
- Dependencies:
- Upstream tables:
- Session: A parent table to ProbeInsertion
- Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported.
-
- Functions:
- get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
-            get_session_directory(session_key: dict): Returns the path to the electrophysiology data for a particular session as a string.
-            get_processed_root_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory.
- """
-
- if isinstance(linking_module, str):
- linking_module = importlib.import_module(linking_module)
-    assert inspect.ismodule(
-        linking_module
-    ), "The argument 'linking_module' must be a module or a module's name"
-
- global _linking_module
- _linking_module = linking_module
-
- probe.activate(
- probe_schema_name, create_schema=create_schema, create_tables=create_tables
- )
- schema.activate(
- ephys_schema_name,
- create_schema=create_schema,
- create_tables=create_tables,
- add_objects=_linking_module.__dict__,
- )
- ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
-
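
For context, `activate` is meant to be called from a workflow's linking module; a compressed, hypothetical sketch of such a module (`my_workflow`, `SessionDirectory`, and the config key are assumptions):

    import datajoint as dj
    from element_array_ephys import ephys_acute as ephys
    # Hypothetical upstream module providing the required tables.
    from my_workflow.upstream import Session, SessionDirectory, SkullReference

    def get_ephys_root_data_dir() -> list:
        return dj.config["custom"]["ephys_root_data_dir"]

    def get_session_directory(session_key: dict) -> str:
        return (SessionDirectory & session_key).fetch1("session_dir")

    # Session, SkullReference, and the path functions above are resolved
    # from this module's namespace via `linking_module=__name__`.
    ephys.activate("ephys", "probe", linking_module=__name__)
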
-
-# -------------- Functions required by the elements-ephys ---------------
-
-
-def get_ephys_root_data_dir() -> list:
- """Fetches absolute data path to ephys data directories.
-
- The absolute path here is used as a reference for all downstream relative paths used in DataJoint.
-
- Returns:
- A list of the absolute path(s) to ephys data directories.
- """
- root_directories = _linking_module.get_ephys_root_data_dir()
- if isinstance(root_directories, (str, pathlib.Path)):
- root_directories = [root_directories]
-
- if hasattr(_linking_module, "get_processed_root_data_dir"):
- root_directories.append(_linking_module.get_processed_root_data_dir())
-
- return root_directories
-
-
-def get_session_directory(session_key: dict) -> str:
- """Retrieve the session directory with Neuropixels for the given session.
-
- Args:
- session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database.
-
- Returns:
- A string for the path to the session directory.
- """
- return _linking_module.get_session_directory(session_key)
-
-
-def get_processed_root_data_dir() -> str:
- """Retrieve the root directory for all processed data.
-
- Returns:
- A string for the full path to the root directory for processed data.
- """
-
- if hasattr(_linking_module, "get_processed_root_data_dir"):
- return _linking_module.get_processed_root_data_dir()
- else:
- return get_ephys_root_data_dir()[0]
-
-
-# ----------------------------- Table declarations ----------------------
-
-
-@schema
-class AcquisitionSoftware(dj.Lookup):
- """Name of software used for recording electrophysiological data.
-
- Attributes:
-        acq_software ( varchar(24) ): Acquisition software, e.g., SpikeGLX, Open Ephys
- """
-
- definition = """ # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys
- acq_software: varchar(24)
- """
- contents = zip(["SpikeGLX", "Open Ephys"])
-
-
-@schema
-class ProbeInsertion(dj.Manual):
- """Information about probe insertion across subjects and sessions.
-
- Attributes:
- Session (foreign key): Session primary key.
-        insertion_number (foreign key, int): Unique insertion number for each probe insertion for a given session.
- probe.Probe (str): probe.Probe primary key.
- """
-
- definition = """
- # Probe insertion implanted into an animal for a given session.
- -> Session
- insertion_number: tinyint unsigned
- ---
- -> probe.Probe
- """
-
- @classmethod
- def auto_generate_entries(cls, session_key):
- """Automatically populate entries in ProbeInsertion table for a session."""
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(session_key)
- )
- # search session dir and determine acquisition software
- for ephys_pattern, ephys_acq_type in (
- ("*.ap.meta", "SpikeGLX"),
- ("*.oebin", "Open Ephys"),
- ):
- ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern))
- if ephys_meta_filepaths:
- acq_software = ephys_acq_type
- break
- else:
- raise FileNotFoundError(
- f"Ephys recording data not found!"
- f" Neither SpikeGLX nor Open Ephys recording files found in: {session_dir}"
- )
-
- probe_list, probe_insertion_list = [], []
- if acq_software == "SpikeGLX":
- for meta_fp_idx, meta_filepath in enumerate(ephys_meta_filepaths):
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
-
- probe_key = {
- "probe_type": spikeglx_meta.probe_model,
- "probe": spikeglx_meta.probe_SN,
- }
- if probe_key["probe"] not in [p["probe"] for p in probe_list]:
- probe_list.append(probe_key)
-
- probe_dir = meta_filepath.parent
- try:
- probe_number = re.search(r"(imec)?\d{1}$", probe_dir.name).group()
- probe_number = int(probe_number.replace("imec", ""))
- except AttributeError:
- probe_number = meta_fp_idx
-
- probe_insertion_list.append(
- {
- **session_key,
- "probe": spikeglx_meta.probe_SN,
- "insertion_number": int(probe_number),
- }
- )
- elif acq_software == "Open Ephys":
- loaded_oe = openephys.OpenEphys(session_dir)
- for probe_idx, oe_probe in enumerate(loaded_oe.probes.values()):
- probe_key = {
- "probe_type": oe_probe.probe_model,
- "probe": oe_probe.probe_SN,
- }
- if probe_key["probe"] not in [p["probe"] for p in probe_list]:
- probe_list.append(probe_key)
- probe_insertion_list.append(
- {
- **session_key,
- "probe": oe_probe.probe_SN,
- "insertion_number": probe_idx,
- }
- )
- else:
- raise NotImplementedError(f"Unknown acquisition software: {acq_software}")
-
- probe.Probe.insert(probe_list, skip_duplicates=True)
- cls.insert(probe_insertion_list, skip_duplicates=True)
-
-
-@schema
-class InsertionLocation(dj.Manual):
- """Stereotaxic location information for each probe insertion.
-
- Attributes:
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- SkullReference (dict): SkullReference primary key.
- ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive.
- ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive.
- depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative.
-        theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to the positive z-axis.
-        phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis.
-        beta (decimal (5, 2) ): rotation about the shank of the probe in degrees; clockwise is increasing, and zero is the probe-front facing anterior.
- """
-
- definition = """
- # Brain Location of a given probe insertion.
- -> ProbeInsertion
- ---
- -> SkullReference
- ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive
- ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive
- depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative
- theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis
- phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis
- beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior
- """
-
-
-@schema
-class EphysRecording(dj.Imported):
- """Automated table with electrophysiology recording information for each probe inserted during an experimental session.
-
- Attributes:
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key.
- AcquisitionSoftware (dict): AcquisitionSoftware primary key.
- sampling_rate (float): sampling rate of the recording in Hertz (Hz).
- recording_datetime (datetime): datetime of the recording from this probe.
- recording_duration (float): duration of the entire recording from this probe in seconds.
- """
-
- definition = """
- # Ephys recording from a probe insertion for a given session.
- -> ProbeInsertion
- ---
- -> probe.ElectrodeConfig
- -> AcquisitionSoftware
- sampling_rate: float # (Hz)
- recording_datetime: datetime # datetime of the recording from this probe
- recording_duration: float # (seconds) duration of the recording from this probe
- """
-
- class EphysFile(dj.Part):
- """Paths of electrophysiology recording files for each insertion.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- file_path (varchar(255) ): relative file path for electrophysiology recording.
- """
-
- definition = """
- # Paths of files of a given EphysRecording round.
- -> master
- file_path: varchar(255) # filepath relative to root data directory
- """
-
- def make(self, key):
- """Populates table with electrophysiology recording information."""
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
-
- inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1(
- "probe"
- )
-
- # search session dir and determine acquisition software
- for ephys_pattern, ephys_acq_type in (
- ("*.ap.meta", "SpikeGLX"),
- ("*.oebin", "Open Ephys"),
- ):
- ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern))
- if ephys_meta_filepaths:
- acq_software = ephys_acq_type
- break
- else:
- raise FileNotFoundError(
- f"Ephys recording data not found!"
- f" Neither SpikeGLX nor Open Ephys recording files found"
- f" in {session_dir}"
- )
-
- supported_probe_types = probe.ProbeType.fetch("probe_type")
-
- if acq_software == "SpikeGLX":
- for meta_filepath in ephys_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(key)
- )
-
- if spikeglx_meta.probe_model in supported_probe_types:
- probe_type = spikeglx_meta.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- electrode_group_members = [
- probe_electrodes[(shank, shank_col, shank_row)]
- for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels probe model"
- " {} not yet implemented".format(spikeglx_meta.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": spikeglx_meta.meta["imSampRate"],
- "recording_datetime": spikeglx_meta.recording_time,
- "recording_duration": (
- spikeglx_meta.recording_duration
- or spikeglx.retrieve_recording_duration(meta_filepath)
- ),
- }
- )
-
- root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath)
- self.EphysFile.insert1(
- {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()}
- )
- elif acq_software == "Open Ephys":
- dataset = openephys.OpenEphys(session_dir)
- for serial_number, probe_data in dataset.probes.items():
- if str(serial_number) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- "No Open Ephys data found for probe insertion: {}".format(key)
- )
-
- if not probe_data.ap_meta:
- raise IOError(
- 'No analog signals found - check "structure.oebin" file or "continuous" directory'
- )
-
- if probe_data.probe_model in supported_probe_types:
- probe_type = probe_data.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_group_members = [
- probe_electrodes[channel_idx]
- for channel_idx in probe_data.ap_meta["channels_indices"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels"
- " probe model {} not yet implemented".format(probe_data.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": probe_data.ap_meta["sample_rate"],
- "recording_datetime": probe_data.recording_info[
- "recording_datetimes"
- ][0],
- "recording_duration": np.sum(
- probe_data.recording_info["recording_durations"]
- ),
- }
- )
-
- root_dir = find_root_directory(
- get_ephys_root_data_dir(),
- probe_data.recording_info["recording_files"][0],
- )
- self.EphysFile.insert(
- [
- {**key, "file_path": fp.relative_to(root_dir).as_posix()}
- for fp in probe_data.recording_info["recording_files"]
- ]
- )
- # explicitly garbage collect "dataset"
- # as these may have large memory footprint and may not be cleared fast enough
- del probe_data, dataset
- gc.collect()
- else:
- raise NotImplementedError(
- f"Processing ephys files from"
- f" acquisition software of type {acq_software} is"
- f" not yet implemented"
- )
-
-
-@schema
-class LFP(dj.Imported):
- """Extracts local field potentials (LFP) from an electrophysiology recording.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- lfp_sampling_rate (float): Sampling rate for LFPs in Hz.
- lfp_time_stamps (longblob): Time stamps with respect to the start of the recording.
- lfp_mean (longblob): Overall mean LFP across electrodes.
- """
-
- definition = """
- # Acquired local field potential (LFP) from a given Ephys recording.
- -> EphysRecording
- ---
- lfp_sampling_rate: float # (Hz)
- lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp)
- lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,)
- """
-
- class Electrode(dj.Part):
- """Saves local field potential data for each electrode.
-
- Attributes:
- LFP (foreign key): LFP primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- lfp (longblob): LFP recording at this electrode in microvolts.
- """
-
- definition = """
- -> master
- -> probe.ElectrodeConfig.Electrode
- ---
- lfp: longblob # (uV) recorded lfp at this electrode
- """
-
-    # Only store LFP for every 9th channel; due to the high channel density,
-    # close-by channels exhibit highly similar LFP
- _skip_channel_counts = 9
-
- def make(self, key):
- """Populates the LFP tables."""
- acq_software = (EphysRecording * ProbeInsertion & key).fetch1("acq_software")
-
- electrode_keys, lfp = [], []
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
-
- lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[
- -1 :: -self._skip_channel_counts
- ]
-
- # Extract LFP data at specified channels and convert to uV
- lfp = spikeglx_recording.lf_timeseries[
- :, lfp_channel_ind
- ] # (sample x channel)
- lfp = (
- lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind]
- ).T # (channel x sample)
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"],
- lfp_time_stamps=(
- np.arange(lfp.shape[1])
- / spikeglx_recording.lfmeta.meta["imSampRate"]
- ),
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- for recorded_site in lfp_channel_ind:
- shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[
- "data"
- ][recorded_site]
- electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)])
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(key)
-
- lfp_channel_ind = np.r_[
- len(oe_probe.lfp_meta["channels_indices"])
- - 1 : 0 : -self._skip_channel_counts
- ]
-
- lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind]
- lfp = (
- lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind]
- ).T # (channel x sample)
- lfp_timestamps = oe_probe.lfp_timestamps
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"],
- lfp_time_stamps=lfp_timestamps,
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_keys.extend(
- probe_electrodes[channel_idx] for channel_idx in lfp_channel_ind
- )
- else:
- raise NotImplementedError(
- f"LFP extraction from acquisition software"
- f" of type {acq_software} is not yet implemented"
- )
-
- # single insert in loop to mitigate potential memory issue
- for electrode_key, lfp_trace in zip(electrode_keys, lfp):
- self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace})
-
-
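
The `_skip_channel_counts` subsampling above walks backwards from the last recorded channel; a small self-contained illustration of the SpikeGLX slice with a toy 20-channel map:

    import numpy as np

    recording_channels = np.arange(20)  # toy channel map: 0..19
    skip = 9
    # Same slice as in LFP.make: start at the last channel, step back by `skip`.
    lfp_channel_ind = recording_channels[-1::-skip]
    print(lfp_channel_ind)  # [19 10  1]
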
-# ------------ Clustering --------------
-
-
-@schema
-class ClusteringMethod(dj.Lookup):
- """Kilosort clustering method.
-
- Attributes:
- clustering_method (foreign key, varchar(16) ): Kilosort clustering method.
-        clustering_method_desc (varchar(1000) ): Additional description of the clustering method.
- """
-
- definition = """
- # Method for clustering
- clustering_method: varchar(16)
- ---
- clustering_method_desc: varchar(1000)
- """
-
- contents = [
- ("kilosort2", "kilosort2 clustering method"),
- ("kilosort2.5", "kilosort2.5 clustering method"),
- ("kilosort3", "kilosort3 clustering method"),
- ]
-
-
-@schema
-class ClusteringParamSet(dj.Lookup):
- """Parameters to be used in clustering procedure for spike sorting.
-
- Attributes:
- paramset_idx (foreign key): Unique ID for the clustering parameter set.
- ClusteringMethod (dict): ClusteringMethod primary key.
- paramset_desc (varchar(128) ): Description of the clustering parameter set.
- param_set_hash (uuid): UUID hash for the parameter set.
- params (longblob): Parameters for clustering with Kilosort.
- """
-
- definition = """
- # Parameter set to be used in a clustering procedure
- paramset_idx: smallint
- ---
- -> ClusteringMethod
- paramset_desc: varchar(128)
- param_set_hash: uuid
- unique index (param_set_hash)
- params: longblob # dictionary of all applicable parameters
- """
-
- @classmethod
- def insert_new_params(
- cls,
- clustering_method: str,
- paramset_desc: str,
- params: dict,
- paramset_idx: int = None,
- ):
- """Inserts new parameters into the ClusteringParamSet table.
-
- Args:
- clustering_method (str): name of the clustering method.
- paramset_desc (str): description of the parameter set
- params (dict): clustering parameters
- paramset_idx (int, optional): Unique parameter set ID. Defaults to None.
- """
- if paramset_idx is None:
- paramset_idx = (
- dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n") or 0
- ) + 1
-
- param_dict = {
- "clustering_method": clustering_method,
- "paramset_idx": paramset_idx,
- "paramset_desc": paramset_desc,
- "params": params,
- "param_set_hash": dict_to_uuid(
- {**params, "clustering_method": clustering_method}
- ),
- }
- param_query = cls & {"param_set_hash": param_dict["param_set_hash"]}
-
- if param_query: # If the specified param-set already exists
- existing_paramset_idx = param_query.fetch1("paramset_idx")
- if (
- existing_paramset_idx == paramset_idx
- ): # If the existing set has the same paramset_idx: job done
- return
-            else:  # same params already exist under a different paramset_idx: human error
- raise dj.DataJointError(
- f"The specified param-set already exists"
- f" - with paramset_idx: {existing_paramset_idx}"
- )
- else:
- if {"paramset_idx": paramset_idx} in cls.proj():
- raise dj.DataJointError(
- f"The specified paramset_idx {paramset_idx} already exists,"
- f" please pick a different one."
- )
- cls.insert1(param_dict)
-
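
A hedged usage sketch with toy parameters (real Kilosort parameter sets are much larger); re-inserting an identical set is a no-op, while the same parameters under a different `paramset_idx` raise a `DataJointError`:

    ClusteringParamSet.insert_new_params(
        clustering_method="kilosort2.5",
        paramset_desc="toy kilosort2.5 parameter set",
        params={"Th": [10, 4], "minFR": 0.02},  # hypothetical values
        paramset_idx=1,
    )
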
-
-@schema
-class ClusterQualityLabel(dj.Lookup):
- """Quality label for each spike sorted cluster.
-
- Attributes:
- cluster_quality_label (foreign key, varchar(100) ): Cluster quality type.
- cluster_quality_description ( varchar(4000) ): Description of the cluster quality type.
- """
-
- definition = """
- # Quality
- cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc.
- ---
- cluster_quality_description: varchar(4000)
- """
- contents = [
- ("good", "single unit"),
- ("ok", "probably a single unit, but could be contaminated"),
- ("mua", "multi-unit activity"),
- ("noise", "bad unit"),
- ]
-
-
-@schema
-class ClusteringTask(dj.Manual):
- """A clustering task to spike sort electrophysiology datasets.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- ClusteringParamSet (foreign key): ClusteringParamSet primary key.
- clustering_output_dir ( varchar (255) ): Relative path to output clustering results.
-        task_mode (enum): `trigger` computes clustering; `load` imports existing results.
- """
-
- definition = """
- # Manual table for defining a clustering task ready to be run
- -> EphysRecording
- -> ClusteringParamSet
- ---
- clustering_output_dir='': varchar(255) # clustering output directory relative to the clustering root data directory
- task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation
- """
-
- @classmethod
- def infer_output_dir(
- cls, key: dict, relative: bool = False, mkdir: bool = False
- ) -> pathlib.Path:
- """Infer output directory if it is not provided.
-
- Args:
- key (dict): ClusteringTask primary key.
-
- Returns:
- Expected clustering_output_dir based on the following convention:
- processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
- e.g.: sub4/sess1/probe_2/kilosort2_0
- """
- processed_dir = pathlib.Path(get_processed_root_data_dir())
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
- root_dir = find_root_directory(get_ephys_root_data_dir(), session_dir)
-
- method = (
- (ClusteringParamSet * ClusteringMethod & key)
- .fetch1("clustering_method")
- .replace(".", "-")
- )
-
- output_dir = (
- processed_dir
- / session_dir.relative_to(root_dir)
- / f'probe_{key["insertion_number"]}'
- / f'{method}_{key["paramset_idx"]}'
- )
-
- if mkdir:
- output_dir.mkdir(parents=True, exist_ok=True)
- log.info(f"{output_dir} created!")
-
- return output_dir.relative_to(processed_dir) if relative else output_dir
-
- @classmethod
- def auto_generate_entries(cls, ephys_recording_key: dict, paramset_idx: int = 0):
- """Autogenerate entries based on a particular ephys recording.
-
- Args:
- ephys_recording_key (dict): EphysRecording primary key.
- paramset_idx (int, optional): Parameter index to use for clustering task. Defaults to 0.
- """
- key = {**ephys_recording_key, "paramset_idx": paramset_idx}
-
- processed_dir = get_processed_root_data_dir()
- output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True)
-
- try:
- kilosort.Kilosort(
- output_dir
- ) # check if the directory is a valid Kilosort output
- except FileNotFoundError:
- task_mode = "trigger"
- else:
- task_mode = "load"
-
- cls.insert1(
- {
- **key,
- "clustering_output_dir": output_dir.relative_to(
- processed_dir
- ).as_posix(),
- "task_mode": task_mode,
- }
- )
-
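
A hedged driver sketch for bulk task registration, assuming `EphysRecording` has already been populated:

    # Register a clustering task (paramset 0) for each recording without one.
    for rec_key in (EphysRecording - ClusteringTask).fetch("KEY"):
        ClusteringTask.auto_generate_entries(rec_key, paramset_idx=0)
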
-
-@schema
-class Clustering(dj.Imported):
- """A processing table to handle each clustering task.
-
- Attributes:
- ClusteringTask (foreign key): ClusteringTask primary key.
- clustering_time (datetime): Time when clustering results are generated.
- package_version ( varchar(16) ): Package version used for a clustering analysis.
- """
-
- definition = """
- # Clustering Procedure
- -> ClusteringTask
- ---
- clustering_time: datetime # time of generation of this set of clustering results
- package_version='': varchar(16)
- """
-
- def make(self, key):
- """Triggers or imports clustering analysis."""
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
-
- if not output_dir:
- output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True)
- # update clustering_output_dir
- ClusteringTask.update1(
- {**key, "clustering_output_dir": output_dir.as_posix()}
- )
-
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- if task_mode == "load":
- kilosort.Kilosort(
- kilosort_dir
- ) # check if the directory is a valid Kilosort output
- elif task_mode == "trigger":
- acq_software, clustering_method, params = (
- ClusteringTask * EphysRecording * ClusteringParamSet & key
- ).fetch1("acq_software", "clustering_method", "params")
-
- if "kilosort" in clustering_method:
- from element_array_ephys.readers import kilosort_triggering
-
- # add additional probe-recording and channels details into `params`
- params = {**params, **get_recording_channels_details(key)}
- params["fs"] = params["sample_rate"]
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- spikeglx_recording = spikeglx.SpikeGLX(
- spikeglx_meta_filepath.parent
- )
- spikeglx_recording.validate_file("ap")
- run_CatGT = (
- params.pop("run_CatGT", True)
- and "_tcat." not in spikeglx_meta_filepath.stem
- )
-
- if clustering_method.startswith("pykilosort"):
- kilosort_triggering.run_pykilosort(
- continuous_file=spikeglx_recording.root_dir
- / (spikeglx_recording.root_name + ".ap.bin"),
- kilosort_output_directory=kilosort_dir,
- channel_ind=params.pop("channel_ind"),
- x_coords=params.pop("x_coords"),
- y_coords=params.pop("y_coords"),
- shank_ind=params.pop("shank_ind"),
- connected=params.pop("connected"),
- sample_rate=params.pop("sample_rate"),
- params=params,
- )
- else:
- run_kilosort = kilosort_triggering.SGLXKilosortPipeline(
- npx_input_dir=spikeglx_meta_filepath.parent,
- ks_output_dir=kilosort_dir,
- params=params,
- KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}',
- run_CatGT=run_CatGT,
- )
- run_kilosort.run_modules()
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(key)
-
- assert len(oe_probe.recording_info["recording_files"]) == 1
-
- # run kilosort
- if clustering_method.startswith("pykilosort"):
- kilosort_triggering.run_pykilosort(
- continuous_file=pathlib.Path(
- oe_probe.recording_info["recording_files"][0]
- )
- / "continuous.dat",
- kilosort_output_directory=kilosort_dir,
- channel_ind=params.pop("channel_ind"),
- x_coords=params.pop("x_coords"),
- y_coords=params.pop("y_coords"),
- shank_ind=params.pop("shank_ind"),
- connected=params.pop("connected"),
- sample_rate=params.pop("sample_rate"),
- params=params,
- )
- else:
- run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline(
- npx_input_dir=oe_probe.recording_info["recording_files"][0],
- ks_output_dir=kilosort_dir,
- params=params,
- KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}',
- )
- run_kilosort.run_modules()
- else:
- raise NotImplementedError(
- f"Automatic triggering of {clustering_method}"
- f" clustering analysis is not yet supported"
- )
-
- else:
- raise ValueError(f"Unknown task mode: {task_mode}")
-
- creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir)
- self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
- """Curation procedure table.
-
- Attributes:
- Clustering (foreign key): Clustering primary key.
- curation_id (foreign key, int): Unique curation ID.
- curation_time (datetime): Time when curation results are generated.
- curation_output_dir ( varchar(255) ): Output directory of the curated results.
- quality_control (bool): If True, this clustering result has undergone quality control.
- manual_curation (bool): If True, manual curation has been performed on this clustering result.
- curation_note ( varchar(2000) ): Notes about the curation task.
- """
-
- definition = """
- # Manual curation procedure
- -> Clustering
- curation_id: int
- ---
- curation_time: datetime # time of generation of this set of curated clustering results
- curation_output_dir: varchar(255) # output directory of the curated results, relative to root data directory
- quality_control: bool # has this clustering result undergone quality control?
- manual_curation: bool # has manual curation been performed on this clustering result?
- curation_note='': varchar(2000)
- """
-
- def create1_from_clustering_task(self, key, curation_note=""):
- """
- A function to create a new corresponding "Curation" for a particular
- "ClusteringTask"
- """
- if key not in Clustering():
- raise ValueError(
- f"No corresponding entry in Clustering available"
- f" for: {key}; do `Clustering.populate(key)`"
- )
-
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
- kilosort_dir
- )
- # Synthesize curation_id
- curation_id = (
- dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n")
- )
- self.insert1(
- {
- **key,
- "curation_id": curation_id,
- "curation_time": creation_time,
- "curation_output_dir": output_dir,
- "quality_control": is_qc,
- "manual_curation": is_curated,
- "curation_note": curation_note,
- }
- )
-
-
-@schema
-class CuratedClustering(dj.Imported):
- """Clustering results after curation.
-
- Attributes:
- Curation (foreign key): Curation primary key.
- """
-
- definition = """
- # Clustering results of a curation.
- -> Curation
- """
-
- class Unit(dj.Part):
- """Single unit properties after clustering and curation.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- unit (foreign key, int): Unique integer identifying a single unit.
- probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key.
-            ClusterQualityLabel (dict): ClusterQualityLabel primary key.
- spike_count (int): Number of spikes in this recording for this unit.
- spike_times (longblob): Spike times of this unit, relative to start time of EphysRecording.
- spike_sites (longblob): Array of electrode associated with each spike.
-            spike_depths (longblob): Array of depths associated with each spike, relative to the (0, 0) of the probe.
- """
-
- definition = """
- # Properties of a given unit from a round of clustering (and curation)
- -> master
- unit: int
- ---
- -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit
- -> ClusterQualityLabel
- spike_count: int # how many spikes in this recording for this unit
- spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording
- spike_sites : longblob # array of electrode associated with each spike
- spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
- """
-
- def make(self, key):
- """Automated population of Unit information."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
- acq_software, sample_rate = (EphysRecording & key).fetch1(
- "acq_software", "sampling_rate"
- )
-
- sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate)
-
- # ---------- Unit ----------
- # -- Remove 0-spike units
- withspike_idx = [
- i
- for i, u in enumerate(kilosort_dataset.data["cluster_ids"])
- if (kilosort_dataset.data["spike_clusters"] == u).any()
- ]
- valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx]
- valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx]
- # -- Get channel and electrode-site mapping
- channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software)
-
- # -- Spike-times --
- # spike_times_sec_adj > spike_times_sec > spike_times
- spike_time_key = (
- "spike_times_sec_adj"
- if "spike_times_sec_adj" in kilosort_dataset.data
- else (
- "spike_times_sec"
- if "spike_times_sec" in kilosort_dataset.data
- else "spike_times"
- )
- )
- spike_times = kilosort_dataset.data[spike_time_key]
- kilosort_dataset.extract_spike_depths()
-
- # -- Spike-sites and Spike-depths --
- spike_sites = np.array(
- [
- channel2electrodes[s]["electrode"]
- for s in kilosort_dataset.data["spike_sites"]
- ]
- )
- spike_depths = kilosort_dataset.data["spike_depths"]
-
- # -- Insert unit, label, peak-chn
- units = []
- for unit, unit_lbl in zip(valid_units, valid_unit_labels):
- if (kilosort_dataset.data["spike_clusters"] == unit).any():
- unit_channel, _ = kilosort_dataset.get_best_channel(unit)
- unit_spike_times = (
- spike_times[kilosort_dataset.data["spike_clusters"] == unit]
- / sample_rate
- )
- spike_count = len(unit_spike_times)
-
- units.append(
- {
- "unit": unit,
- "cluster_quality_label": unit_lbl,
- **channel2electrodes[unit_channel],
- "spike_times": unit_spike_times,
- "spike_count": spike_count,
- "spike_sites": spike_sites[
- kilosort_dataset.data["spike_clusters"] == unit
- ],
- "spike_depths": (
- spike_depths[
- kilosort_dataset.data["spike_clusters"] == unit
- ]
- if spike_depths is not None
- else None
- ),
- }
- )
-
- self.insert1(key)
- self.Unit.insert([{**key, **u} for u in units])
-
-
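
A hedged query sketch against the populated table (the restriction values are hypothetical):

    # Spike times of all curated single units ("good") for one clustering result.
    good_units = (
        CuratedClustering.Unit
        & {"insertion_number": 0, "curation_id": 1}
        & 'cluster_quality_label = "good"'
    )
    unit_spike_times = good_units.fetch("spike_times", order_by="unit")
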
-@schema
-class WaveformSet(dj.Imported):
- """A set of spike waveforms for units out of a given CuratedClustering.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # A set of spike waveforms for units out of a given CuratedClustering
- -> CuratedClustering
- """
-
- class PeakWaveform(dj.Part):
- """Mean waveform across spikes for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode.
- """
-
- definition = """
- # Mean waveform across spikes for a given unit at its representative electrode
- -> master
- -> CuratedClustering.Unit
- ---
- peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode
- """
-
- class Waveform(dj.Part):
- """Spike waveforms for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- waveform_mean (longblob): mean waveform across spikes of the unit in microvolts.
- waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit.
- """
-
- definition = """
- # Spike waveforms and their mean across spikes for the given unit
- -> master
- -> CuratedClustering.Unit
- -> probe.ElectrodeConfig.Electrode
- ---
- waveform_mean: longblob # (uV) mean waveform across spikes of the given unit
- waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit
- """
-
- def make(self, key):
- """Populates waveform tables."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
-
- acq_software, probe_serial_number = (
- EphysRecording * ProbeInsertion & key
- ).fetch1("acq_software", "probe")
-
- # -- Get channel and electrode-site mapping
- recording_key = (EphysRecording & key).fetch1("KEY")
- channel2electrodes = get_neuropixels_channel2electrode_map(
- recording_key, acq_software
- )
-
- is_qc = (Curation & key).fetch1("quality_control")
-
- # Get all units
- units = {
- u["unit"]: u
- for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit")
- }
-
- if is_qc:
- unit_waveforms = np.load(
- kilosort_dir / "mean_waveforms.npy"
- ) # unit x channel x sample
-
- def yield_unit_waveforms():
- for unit_no, unit_waveform in zip(
- kilosort_dataset.data["cluster_ids"], unit_waveforms
- ):
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
- if unit_no in units:
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], unit_waveform
- ):
- unit_electrode_waveforms.append(
- {
- **units[unit_no],
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == units[unit_no]["electrode"]
- ):
- unit_peak_waveform = {
- **units[unit_no],
- "peak_electrode_waveform": channel_waveform,
- }
- yield unit_peak_waveform, unit_electrode_waveforms
-
- else:
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
- openephys_dataset = openephys.OpenEphys(session_dir)
- neuropixels_recording = openephys_dataset.probes[probe_serial_number]
-
- def yield_unit_waveforms():
- for unit_dict in units.values():
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
-
- spikes = unit_dict["spike_times"]
- waveforms = neuropixels_recording.extract_spike_waveforms(
- spikes, kilosort_dataset.data["channel_map"]
- ) # (sample x channel x spike)
- waveforms = waveforms.transpose(
- (1, 2, 0)
- ) # (channel x spike x sample)
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], waveforms
- ):
- unit_electrode_waveforms.append(
- {
- **unit_dict,
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform.mean(axis=0),
- "waveforms": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == unit_dict["electrode"]
- ):
- unit_peak_waveform = {
- **unit_dict,
- "peak_electrode_waveform": channel_waveform.mean(
- axis=0
- ),
- }
-
- yield unit_peak_waveform, unit_electrode_waveforms
-
- # insert waveform on a per-unit basis to mitigate potential memory issue
- self.insert1(key)
- for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms():
- if unit_peak_waveform:
- self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True)
- if unit_electrode_waveforms:
- self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
-
-
-@schema
-class QualityMetrics(dj.Imported):
- """Clustering and waveform quality metrics.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # Clusters and waveforms metrics
- -> CuratedClustering
- """
-
- class Cluster(dj.Part):
- """Cluster metrics for a unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- firing_rate (float): Firing rate of the unit.
- snr (float): Signal-to-noise ratio for a unit.
- presence_ratio (float): Fraction of time where spikes are present.
- isi_violation (float): rate of ISI violation as a fraction of overall rate.
- number_violation (int): Total ISI violations.
- amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram.
- isolation_distance (float): Distance to nearest cluster.
- l_ratio (float): Amount of empty space between a cluster and other spikes in dataset.
- d_prime (float): Classification accuracy based on LDA.
- nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster.
- nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster.
-            silhouette_score (float): Standard metric for cluster overlap.
-            max_drift (float): Maximum change in spike depth throughout recording.
- cumulative_drift (float): Cumulative change in spike depth throughout recording.
- contamination_rate (float): Frequency of spikes in the refractory period.
- """
-
- definition = """
- # Cluster metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- firing_rate=null: float # (Hz) firing rate for a unit
- snr=null: float # signal-to-noise ratio for a unit
- presence_ratio=null: float # fraction of time in which spikes are present
- isi_violation=null: float # rate of ISI violation as a fraction of overall rate
- number_violation=null: int # total number of ISI violations
- amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram
- isolation_distance=null: float # distance to nearest cluster in Mahalanobis space
- l_ratio=null: float #
- d_prime=null: float # Classification accuracy based on LDA
- nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster
- nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster
- silhouette_score=null: float # Standard metric for cluster overlap
- max_drift=null: float # Maximum change in spike depth throughout recording
- cumulative_drift=null: float # Cumulative change in spike depth throughout recording
- contamination_rate=null: float #
- """
-
- class Waveform(dj.Part):
- """Waveform metrics for a particular unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- amplitude (float): Absolute difference between waveform peak and trough in microvolts.
- duration (float): Time between waveform peak and trough in milliseconds.
- halfwidth (float): Spike width at half max amplitude.
- pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0.
- repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak.
- recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail.
- spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe.
- velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe.
- velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe.
- """
-
- definition = """
- # Waveform metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- amplitude: float # (uV) absolute difference between waveform peak and trough
- duration: float # (ms) time between waveform peak and trough
- halfwidth=null: float # (ms) spike width at half max amplitude
- pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0
- repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak
- recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail
- spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe
- velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe
- velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe
- """
-
- def make(self, key):
- """Populates tables with quality metrics data."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- metric_fp = kilosort_dir / "metrics.csv"
- rename_dict = {
- "isi_viol": "isi_violation",
- "num_viol": "number_violation",
- "contam_rate": "contamination_rate",
- }
-
- if not metric_fp.exists():
- raise FileNotFoundError(f"QC metrics file not found: {metric_fp}")
-
- metrics_df = pd.read_csv(metric_fp)
- metrics_df.set_index("cluster_id", inplace=True)
- metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
- metrics_df.columns = metrics_df.columns.str.lower()
- metrics_df.rename(columns=rename_dict, inplace=True)
- metrics_list = [
- dict(metrics_df.loc[unit_key["unit"]], **unit_key)
- for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
- ]
-
- self.insert1(key)
- self.Cluster.insert(metrics_list, ignore_extra_fields=True)
- self.Waveform.insert(metrics_list, ignore_extra_fields=True)
-
-
-# ---------------- HELPER FUNCTIONS ----------------
-
-
-def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> pathlib.Path:
- """Get spikeGLX data filepath."""
- # attempt to retrieve from EphysRecording.EphysFile
- spikeglx_meta_filepath = pathlib.Path(
- (
- EphysRecording.EphysFile
- & ephys_recording_key
- & 'file_path LIKE "%.ap.meta"'
- ).fetch1("file_path")
- )
-
- try:
- spikeglx_meta_filepath = find_full_path(
- get_ephys_root_data_dir(), spikeglx_meta_filepath
- )
- except FileNotFoundError:
- # if not found, search in session_dir again
- if not spikeglx_meta_filepath.exists():
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- inserted_probe_serial_number = (
- ProbeInsertion * probe.Probe & ephys_recording_key
- ).fetch1("probe")
-
- spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")]
- for meta_filepath in spikeglx_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- spikeglx_meta_filepath = meta_filepath
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(
- ephys_recording_key
- )
- )
-
- return spikeglx_meta_filepath
-
-
-def get_openephys_probe_data(ephys_recording_key: dict) -> list:
- """Get OpenEphys probe data from file."""
- inserted_probe_serial_number = (
- ProbeInsertion * probe.Probe & ephys_recording_key
- ).fetch1("probe")
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- loaded_oe = openephys.OpenEphys(session_dir)
- probe_data = loaded_oe.probes[inserted_probe_serial_number]
-
- # explicitly garbage collect "loaded_oe"
- # as these may have large memory footprint and may not be cleared fast enough
- del loaded_oe
- gc.collect()
-
- return probe_data
-
-
-def get_neuropixels_channel2electrode_map(
- ephys_recording_key: dict, acq_software: str
-) -> dict:
- """Get the channel map for neuropixels probe."""
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath)
- electrode_config_key = (
- EphysRecording * probe.ElectrodeConfig & ephys_recording_key
- ).fetch1("KEY")
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
- & electrode_config_key
- )
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- channel2electrode_map = {
- recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
- for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
- spikeglx_meta.shankmap["data"]
- )
- }
- elif acq_software == "Open Ephys":
- probe_dataset = get_openephys_probe_data(ephys_recording_key)
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording
- & ephys_recording_key
- )
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- channel2electrode_map = {
- channel_idx: probe_electrodes[channel_idx]
- for channel_idx in probe_dataset.ap_meta["channels_indices"]
- }
-
- return channel2electrode_map
-
-
-def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict:
- """Generate and insert new ElectrodeConfig
-
- Args:
- probe_type (str): probe type (e.g. neuropixels 2.0 - SS)
- electrode_keys (list): list of keys of the probe.ProbeType.Electrode table
-
- Returns:
- dict: representing a key of the probe.ElectrodeConfig table
- """
- # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode)
- electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
-
- electrode_list = sorted([k["electrode"] for k in electrode_keys])
- electrode_gaps = (
- [-1]
- + np.where(np.diff(electrode_list) > 1)[0].tolist()
- + [len(electrode_list) - 1]
- )
- electrode_config_name = "; ".join(
- [
- f"{electrode_list[start + 1]}-{electrode_list[end]}"
- for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])
- ]
- )
-
- electrode_config_key = {"electrode_config_hash": electrode_config_hash}
-
- # ---- make new ElectrodeConfig if needed ----
- if not probe.ElectrodeConfig & electrode_config_key:
- probe.ElectrodeConfig.insert1(
- {
- **electrode_config_key,
- "probe_type": probe_type,
- "electrode_config_name": electrode_config_name,
- }
- )
- probe.ElectrodeConfig.Electrode.insert(
- {**electrode_config_key, **electrode} for electrode in electrode_keys
- )
-
- return electrode_config_key
-
-
-def get_recording_channels_details(ephys_recording_key: dict) -> dict:
- """Get details of recording channels for a given recording."""
- channels_details = {}
-
- acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1(
- "acq_software", "sampling_rate"
- )
-
- probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1(
- "probe_type"
- )
- channels_details["probe_type"] = {
- "neuropixels 1.0 - 3A": "3A",
- "neuropixels 1.0 - 3B": "NP1",
- "neuropixels UHD": "NP1100",
- "neuropixels 2.0 - SS": "NP21",
- "neuropixels 2.0 - MS": "NP24",
- }[probe_type]
-
- electrode_config_key = (
- probe.ElectrodeConfig * EphysRecording & ephys_recording_key
- ).fetch1("KEY")
- (
- channels_details["channel_ind"],
- channels_details["x_coords"],
- channels_details["y_coords"],
- channels_details["shank_ind"],
- ) = (
- probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode
- & electrode_config_key
- ).fetch(
- "electrode", "x_coord", "y_coord", "shank"
- )
- channels_details["sample_rate"] = sample_rate
- channels_details["num_channels"] = len(channels_details["channel_ind"])
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- channels_details["uVPerBit"] = spikeglx_recording.get_channel_bit_volts("ap")[0]
- channels_details["connected"] = np.array(
- [v for *_, v in spikeglx_recording.apmeta.shankmap["data"]]
- )
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(ephys_recording_key)
- channels_details["uVPerBit"] = oe_probe.ap_meta["channels_gains"][0]
- channels_details["connected"] = np.array(
- [
- int(v == 1)
- for c, v in oe_probe.channels_connected.items()
- if c in channels_details["channel_ind"]
- ]
- )
-
- return channels_details
diff --git a/element_array_ephys/ephys_chronic.py b/element_array_ephys/ephys_chronic.py
deleted file mode 100644
index 772e885f..00000000
--- a/element_array_ephys/ephys_chronic.py
+++ /dev/null
@@ -1,1523 +0,0 @@
-import gc
-import importlib
-import inspect
-import pathlib
-from decimal import Decimal
-
-import datajoint as dj
-import numpy as np
-import pandas as pd
-from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
-
-from . import ephys_report, probe
-from .readers import kilosort, openephys, spikeglx
-
-log = dj.logger
-
-schema = dj.schema()
-
-_linking_module = None
-
-
-def activate(
- ephys_schema_name: str,
- probe_schema_name: str = None,
- *,
- create_schema: bool = True,
- create_tables: bool = True,
- linking_module: str = None,
-):
- """Activates the `ephys` and `probe` schemas.
-
- Args:
- ephys_schema_name (str): A string containing the name of the ephys schema.
- probe_schema_name (str): A string containing the name of the probe schema.
- create_schema (bool): If True, schema will be created in the database.
- create_tables (bool): If True, tables related to the schema will be created in the database.
- linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
-
- Dependencies:
- Upstream tables:
- Session: A parent table to ProbeInsertion
- Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported.
-
- Functions:
- get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
-            get_session_directory(session_key: dict): Returns the path to the electrophysiology data for a particular session as a string.
-            get_processed_root_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory.
- """
-
- if isinstance(linking_module, str):
- linking_module = importlib.import_module(linking_module)
-    assert inspect.ismodule(
-        linking_module
-    ), "The argument 'linking_module' must be a module or a module's name"
-
- global _linking_module
- _linking_module = linking_module
-
- probe.activate(
- probe_schema_name, create_schema=create_schema, create_tables=create_tables
- )
- schema.activate(
- ephys_schema_name,
- create_schema=create_schema,
- create_tables=create_tables,
- add_objects=_linking_module.__dict__,
- )
- ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
-
-
-# -------------- Functions required by the elements-ephys ---------------
-
-
-def get_ephys_root_data_dir() -> list:
- """Fetches absolute data path to ephys data directories.
-
- The absolute path here is used as a reference for all downstream relative paths used in DataJoint.
-
- Returns:
- A list of the absolute path(s) to ephys data directories.
- """
- root_directories = _linking_module.get_ephys_root_data_dir()
- if isinstance(root_directories, (str, pathlib.Path)):
- root_directories = [root_directories]
-
- if hasattr(_linking_module, "get_processed_root_data_dir"):
- root_directories.append(_linking_module.get_processed_root_data_dir())
-
- return root_directories
-
-
-def get_session_directory(session_key: dict) -> str:
- """Retrieve the session directory with Neuropixels for the given session.
-
- Args:
- session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database.
-
- Returns:
- A string for the path to the session directory.
- """
- return _linking_module.get_session_directory(session_key)
-
-
-def get_processed_root_data_dir() -> str:
- """Retrieve the root directory for all processed data.
-
- Returns:
- A string for the full path to the root directory for processed data.
- """
-
- if hasattr(_linking_module, "get_processed_root_data_dir"):
- return _linking_module.get_processed_root_data_dir()
- else:
- return get_ephys_root_data_dir()[0]
-
-
-# ----------------------------- Table declarations ----------------------
-
-
-@schema
-class AcquisitionSoftware(dj.Lookup):
- """Name of software used for recording electrophysiological data.
-
- Attributes:
-        acq_software ( varchar(24) ): Acquisition software, e.g., SpikeGLX, Open Ephys
- """
-
- definition = """ # Software used for recording of neuropixels probes
- acq_software: varchar(24)
- """
- contents = zip(["SpikeGLX", "Open Ephys"])
-
-
-@schema
-class ProbeInsertion(dj.Manual):
- """Information about probe insertion across subjects and sessions.
-
- Attributes:
- Session (foreign key): Session primary key.
- insertion_number (foreign key, str): Unique insertion number for each probe insertion for a given session.
- probe.Probe (str): probe.Probe primary key.
- """
-
- definition = """
- # Probe insertion chronically implanted into an animal.
- -> Subject
- insertion_number: tinyint unsigned
- ---
- -> probe.Probe
- insertion_datetime=null: datetime
- """
-
-
-@schema
-class InsertionLocation(dj.Manual):
- """Stereotaxic location information for each probe insertion.
-
- Attributes:
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- SkullReference (dict): SkullReference primary key.
- ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive.
- ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive.
- depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative.
- theta (decimal (5, 2) ): Elevation - rotation about the ml-axis in degrees relative to the positive z-axis.
- phi (decimal (5, 2) ): Azimuth - rotation about the dv-axis in degrees relative to the positive x-axis.
- beta (decimal (5, 2) ): Optional. Rotation about the shank of the probe in degrees; zero is the probe front facing anterior, clockwise is increasing.
- """
-
- definition = """
- # Brain Location of a given probe insertion.
- -> ProbeInsertion
- ---
- -> SkullReference
- ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive
- ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive
- depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative
- theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis
- phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis
- beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior
- """
-
-
-@schema
-class EphysRecording(dj.Imported):
- """Automated table with electrophysiology recording information for each probe inserted during an experimental session.
-
- Attributes:
- Session (foreign key): Session primary key.
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key.
- AcquisitionSoftware (dict): AcquisitionSoftware primary key.
- sampling_rate (float): sampling rate of the recording in Hertz (Hz).
- recording_datetime (datetime): datetime of the recording from this probe.
- recording_duration (float): duration of the entire recording from this probe in seconds.
- """
-
- definition = """
- # Ephys recording from a probe insertion for a given session.
- -> Session
- -> ProbeInsertion
- ---
- -> probe.ElectrodeConfig
- -> AcquisitionSoftware
- sampling_rate: float # (Hz)
- recording_datetime: datetime # datetime of the recording from this probe
- recording_duration: float # (seconds) duration of the recording from this probe
- """
-
- class EphysFile(dj.Part):
- """Paths of electrophysiology recording files for each insertion.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- file_path (varchar(255) ): relative file path for electrophysiology recording.
- """
-
- definition = """
- # Paths of files of a given EphysRecording round.
- -> master
- file_path: varchar(255) # filepath relative to root data directory
- """
-
- def make(self, key):
- """Populates table with electrophysiology recording information."""
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
-
- inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1(
- "probe"
- )
-
- # search session dir and determine acquisition software
- for ephys_pattern, ephys_acq_type in (
- ("*.ap.meta", "SpikeGLX"),
- ("*.oebin", "Open Ephys"),
- ):
- ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern))
- if ephys_meta_filepaths:
- acq_software = ephys_acq_type
- break
- else:
- raise FileNotFoundError(
- f"Ephys recording data not found!"
- f" Neither SpikeGLX nor Open Ephys recording files found"
- f" in {session_dir}"
- )
-
- supported_probe_types = probe.ProbeType.fetch("probe_type")
-
- if acq_software == "SpikeGLX":
- for meta_filepath in ephys_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- f"No SpikeGLX data found for probe insertion: {key}"
- + " The probe serial number does not match."
- )
-
- if spikeglx_meta.probe_model in supported_probe_types:
- probe_type = spikeglx_meta.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- electrode_group_members = [
- probe_electrodes[(shank, shank_col, shank_row)]
- for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels probe model"
- " {} not yet implemented".format(spikeglx_meta.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": spikeglx_meta.meta["imSampRate"],
- "recording_datetime": spikeglx_meta.recording_time,
- "recording_duration": (
- spikeglx_meta.recording_duration
- or spikeglx.retrieve_recording_duration(meta_filepath)
- ),
- }
- )
-
- root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath)
- self.EphysFile.insert1(
- {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()}
- )
- elif acq_software == "Open Ephys":
- dataset = openephys.OpenEphys(session_dir)
- for serial_number, probe_data in dataset.probes.items():
- if str(serial_number) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- "No Open Ephys data found for probe insertion: {}".format(key)
- )
-
- if not probe_data.ap_meta:
- raise IOError(
- 'No analog signals found - check "structure.oebin" file or "continuous" directory'
- )
-
- if probe_data.probe_model in supported_probe_types:
- probe_type = probe_data.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_group_members = [
- probe_electrodes[channel_idx]
- for channel_idx in probe_data.ap_meta["channels_indices"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels"
- " probe model {} not yet implemented".format(probe_data.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": probe_data.ap_meta["sample_rate"],
- "recording_datetime": probe_data.recording_info[
- "recording_datetimes"
- ][0],
- "recording_duration": np.sum(
- probe_data.recording_info["recording_durations"]
- ),
- }
- )
-
- root_dir = find_root_directory(
- get_ephys_root_data_dir(),
- probe_data.recording_info["recording_files"][0],
- )
- self.EphysFile.insert(
- [
- {**key, "file_path": fp.relative_to(root_dir).as_posix()}
- for fp in probe_data.recording_info["recording_files"]
- ]
- )
- # explicitly garbage collect "dataset"
- # as these may have large memory footprint and may not be cleared fast enough
- del probe_data, dataset
- gc.collect()
- else:
- raise NotImplementedError(
- f"Processing ephys files from"
- f" acquisition software of type {acq_software} is"
- f" not yet implemented"
- )
-
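-# Editor's note: the for/else blocks in EphysRecording.make() above rely on
-# Python's loop-else semantics - the `else` clause runs only if the loop
-# completes without `break`, i.e. only when no matching file or probe is found.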
-
-@schema
-class LFP(dj.Imported):
- """Extracts local field potentials (LFP) from an electrophysiology recording.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- lfp_sampling_rate (float): Sampling rate for LFPs in Hz.
- lfp_time_stamps (longblob): Time stamps with respect to the start of the recording.
- lfp_mean (longblob): Overall mean LFP across electrodes.
- """
-
- definition = """
- # Acquired local field potential (LFP) from a given Ephys recording.
- -> EphysRecording
- ---
- lfp_sampling_rate: float # (Hz)
- lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp)
- lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,)
- """
-
- class Electrode(dj.Part):
- """Saves local field potential data for each electrode.
-
- Attributes:
- LFP (foreign key): LFP primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- lfp (longblob): LFP recording at this electrode in microvolts.
- """
-
- definition = """
- -> master
- -> probe.ElectrodeConfig.Electrode
- ---
- lfp: longblob # (uV) recorded lfp at this electrode
- """
-
- # Only store LFP for every 9th channel: due to the high channel density,
- # nearby channels exhibit highly similar LFP.
- _skip_channel_counts = 9
-
- def make(self, key):
- """Populates the LFP tables."""
- acq_software = (EphysRecording * ProbeInsertion & key).fetch1("acq_software")
-
- electrode_keys, lfp = [], []
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
-
- lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[
- -1 :: -self._skip_channel_counts
- ]
-
- # Extract LFP data at specified channels and convert to uV
- lfp = spikeglx_recording.lf_timeseries[
- :, lfp_channel_ind
- ] # (sample x channel)
- lfp = (
- lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind]
- ).T # (channel x sample)
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"],
- lfp_time_stamps=(
- np.arange(lfp.shape[1])
- / spikeglx_recording.lfmeta.meta["imSampRate"]
- ),
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- for recorded_site in lfp_channel_ind:
- shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[
- "data"
- ][recorded_site]
- electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)])
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(key)
-
- lfp_channel_ind = np.r_[
- len(oe_probe.lfp_meta["channels_indices"])
- - 1 : 0 : -self._skip_channel_counts
- ]
-
- # (sample x channel)
- lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind]
- lfp = (
- lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind]
- ).T # (channel x sample)
- lfp_timestamps = oe_probe.lfp_timestamps
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"],
- lfp_time_stamps=lfp_timestamps,
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_keys.extend(
- probe_electrodes[channel_idx] for channel_idx in lfp_channel_ind
- )
- else:
- raise NotImplementedError(
- f"LFP extraction from acquisition software"
- f" of type {acq_software} is not yet implemented"
- )
-
- # single insert in loop to mitigate potential memory issue
- for electrode_key, lfp_trace in zip(electrode_keys, lfp):
- self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace})
-
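-# Editor's note: with _skip_channel_counts = 9, LFP.make() keeps every 9th
-# channel counting back from the last recorded channel; e.g. on a 384-channel
-# recording that stores 43 traces (channels 383, 374, ..., 5).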
-
-# ------------ Clustering --------------
-
-
-@schema
-class ClusteringMethod(dj.Lookup):
- """Kilosort clustering method.
-
- Attributes:
- clustering_method (foreign key, varchar(16) ): Kilosort clustering method.
- clustering_method_desc (varchar(1000) ): Additional description of the clustering method.
- """
-
- definition = """
- # Method for clustering
- clustering_method: varchar(16)
- ---
- clustering_method_desc: varchar(1000)
- """
-
- contents = [
- ("kilosort2", "kilosort2 clustering method"),
- ("kilosort2.5", "kilosort2.5 clustering method"),
- ("kilosort3", "kilosort3 clustering method"),
- ]
-
-
-@schema
-class ClusteringParamSet(dj.Lookup):
- """Parameters to be used in clustering procedure for spike sorting.
-
- Attributes:
- paramset_idx (foreign key): Unique ID for the clustering parameter set.
- ClusteringMethod (dict): ClusteringMethod primary key.
- paramset_desc (varchar(128) ): Description of the clustering parameter set.
- param_set_hash (uuid): UUID hash for the parameter set.
- params (longblob): Parameters for clustering with Kilosort.
- """
-
- definition = """
- # Parameter set to be used in a clustering procedure
- paramset_idx: smallint
- ---
- -> ClusteringMethod
- paramset_desc: varchar(128)
- param_set_hash: uuid
- unique index (param_set_hash)
- params: longblob # dictionary of all applicable parameters
- """
-
- @classmethod
- def insert_new_params(
- cls,
- clustering_method: str,
- paramset_desc: str,
- params: dict,
- paramset_idx: int = None,
- ):
- """Inserts new parameters into the ClusteringParamSet table.
-
- Args:
- clustering_method (str): name of the clustering method.
- paramset_desc (str): description of the parameter set
- params (dict): clustering parameters
- paramset_idx (int, optional): Unique parameter set ID. Defaults to None.
- """
- if paramset_idx is None:
- paramset_idx = (
- dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n") or 0
- ) + 1
-
- param_dict = {
- "clustering_method": clustering_method,
- "paramset_idx": paramset_idx,
- "paramset_desc": paramset_desc,
- "params": params,
- "param_set_hash": dict_to_uuid(
- {**params, "clustering_method": clustering_method}
- ),
- }
- param_query = cls & {"param_set_hash": param_dict["param_set_hash"]}
-
- if param_query: # If the specified param-set already exists
- existing_paramset_idx = param_query.fetch1("paramset_idx")
- if (
- existing_paramset_idx == paramset_idx
- ): # If the existing set has the same paramset_idx: job done
- return
- else: # human error: trying to add the same paramset under a different paramset_idx
- raise dj.DataJointError(
- f"The specified param-set already exists"
- f" - with paramset_idx: {existing_paramset_idx}"
- )
- else:
- if {"paramset_idx": paramset_idx} in cls.proj():
- raise dj.DataJointError(
- f"The specified paramset_idx {paramset_idx} already exists,"
- f" please pick a different one."
- )
- cls.insert1(param_dict)
-
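-# Editor's sketch: registering a parameter set (all values hypothetical).
-#
-#     ClusteringParamSet.insert_new_params(
-#         clustering_method="kilosort2.5",
-#         paramset_desc="default kilosort2.5 parameters",
-#         params={"Th": [10, 4], "CAR": 1, "nblocks": 5},
-#     )  # paramset_idx defaults to max(existing paramset_idx) + 1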
-
-@schema
-class ClusterQualityLabel(dj.Lookup):
- """Quality label for each spike sorted cluster.
-
- Attributes:
- cluster_quality_label (foreign key, varchar(100) ): Cluster quality type.
- cluster_quality_description (varchar(4000) ): Description of the cluster quality type.
- """
-
- definition = """
- # Quality label of a spike sorted cluster
- cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc.
- ---
- cluster_quality_description: varchar(4000)
- """
- contents = [
- ("good", "single unit"),
- ("ok", "probably a single unit, but could be contaminated"),
- ("mua", "multi-unit activity"),
- ("noise", "bad unit"),
- ]
-
-
-@schema
-class ClusteringTask(dj.Manual):
- """A clustering task to spike sort electrophysiology datasets.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- ClusteringParamSet (foreign key): ClusteringParamSet primary key.
- clustering_output_dir (varchar (255) ): Relative path to output clustering results.
- task_mode (enum): `trigger` computes clustering; `load` imports existing results.
- """
-
- definition = """
- # Manual table for defining a clustering task ready to be run
- -> EphysRecording
- -> ClusteringParamSet
- ---
- clustering_output_dir='': varchar(255) # clustering output directory relative to the clustering root data directory
- task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation
- """
-
- @classmethod
- def infer_output_dir(cls, key, relative=False, mkdir=False) -> pathlib.Path:
- """Infer output directory if it is not provided.
-
- Args:
- key (dict): ClusteringTask primary key.
- relative (bool): If True, return the directory relative to the processed root data directory. Defaults to False.
- mkdir (bool): If True, create the output directory if it does not yet exist. Defaults to False.
-
- Returns:
- Expected clustering_output_dir based on the following convention:
- processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
- e.g.: sub4/sess1/probe_2/kilosort2_0
- """
- processed_dir = pathlib.Path(get_processed_root_data_dir())
- sess_dir = find_full_path(get_ephys_root_data_dir(), get_session_directory(key))
- root_dir = find_root_directory(get_ephys_root_data_dir(), sess_dir)
-
- method = (
- (ClusteringParamSet * ClusteringMethod & key)
- .fetch1("clustering_method")
- .replace(".", "-")
- )
-
- output_dir = (
- processed_dir
- / sess_dir.relative_to(root_dir)
- / f'probe_{key["insertion_number"]}'
- / f'{method}_{key["paramset_idx"]}'
- )
-
- if mkdir:
- output_dir.mkdir(parents=True, exist_ok=True)
- log.info(f"{output_dir} created!")
-
- return output_dir.relative_to(processed_dir) if relative else output_dir
-
- @classmethod
- def auto_generate_entries(cls, ephys_recording_key: dict, paramset_idx: int = 0):
- """Autogenerate entries based on a particular ephys recording.
-
- Args:
- ephys_recording_key (dict): EphysRecording primary key.
- paramset_idx (int, optional): Parameter index to use for clustering task. Defaults to 0.
- """
- key = {**ephys_recording_key, "paramset_idx": paramset_idx}
-
- processed_dir = get_processed_root_data_dir()
- output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True)
-
- try:
- kilosort.Kilosort(
- output_dir
- ) # check if the directory is a valid Kilosort output
- except FileNotFoundError:
- task_mode = "trigger"
- else:
- task_mode = "load"
-
- cls.insert1(
- {
- **key,
- "clustering_output_dir": output_dir.relative_to(
- processed_dir
- ).as_posix(),
- "task_mode": task_mode,
- }
- )
-
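-# Editor's sketch: creating one ClusteringTask per recording that has none
-# yet (paramset_idx 0 is a hypothetical default).
-#
-#     for rec_key in (EphysRecording - ClusteringTask).fetch("KEY"):
-#         ClusteringTask.auto_generate_entries(rec_key, paramset_idx=0)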
-
-@schema
-class Clustering(dj.Imported):
- """A processing table to handle each clustering task.
-
- Attributes:
- ClusteringTask (foreign key): ClusteringTask primary key.
- clustering_time (datetime): Time when clustering results are generated.
- package_version (varchar(16) ): Package version used for a clustering analysis.
- """
-
- definition = """
- # Clustering Procedure
- -> ClusteringTask
- ---
- clustering_time: datetime # time of generation of this set of clustering results
- package_version='': varchar(16)
- """
-
- def make(self, key):
- """Triggers or imports clustering analysis."""
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
-
- if not output_dir:
- output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True)
- # update clustering_output_dir
- ClusteringTask.update1(
- {**key, "clustering_output_dir": output_dir.as_posix()}
- )
-
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- if task_mode == "load":
- kilosort.Kilosort(
- kilosort_dir
- ) # check if the directory is a valid Kilosort output
- elif task_mode == "trigger":
- acq_software, clustering_method, params = (
- ClusteringTask * EphysRecording * ClusteringParamSet & key
- ).fetch1("acq_software", "clustering_method", "params")
-
- if "kilosort" in clustering_method:
- from element_array_ephys.readers import kilosort_triggering
-
- # add additional probe-recording and channels details into `params`
- params = {**params, **get_recording_channels_details(key)}
- params["fs"] = params["sample_rate"]
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- spikeglx_recording = spikeglx.SpikeGLX(
- spikeglx_meta_filepath.parent
- )
- spikeglx_recording.validate_file("ap")
- run_CatGT = (
- params.pop("run_CatGT", True)
- and "_tcat." not in spikeglx_meta_filepath.stem
- )
-
- if clustering_method.startswith("pykilosort"):
- kilosort_triggering.run_pykilosort(
- continuous_file=spikeglx_recording.root_dir
- / (spikeglx_recording.root_name + ".ap.bin"),
- kilosort_output_directory=kilosort_dir,
- channel_ind=params.pop("channel_ind"),
- x_coords=params.pop("x_coords"),
- y_coords=params.pop("y_coords"),
- shank_ind=params.pop("shank_ind"),
- connected=params.pop("connected"),
- sample_rate=params.pop("sample_rate"),
- params=params,
- )
- else:
- run_kilosort = kilosort_triggering.SGLXKilosortPipeline(
- npx_input_dir=spikeglx_meta_filepath.parent,
- ks_output_dir=kilosort_dir,
- params=params,
- KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}',
- run_CatGT=run_CatGT,
- )
- run_kilosort.run_modules()
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(key)
-
- assert len(oe_probe.recording_info["recording_files"]) == 1, "Expected a single Open Ephys recording file per probe"
-
- # run kilosort
- if clustering_method.startswith("pykilosort"):
- kilosort_triggering.run_pykilosort(
- continuous_file=pathlib.Path(
- oe_probe.recording_info["recording_files"][0]
- )
- / "continuous.dat",
- kilosort_output_directory=kilosort_dir,
- channel_ind=params.pop("channel_ind"),
- x_coords=params.pop("x_coords"),
- y_coords=params.pop("y_coords"),
- shank_ind=params.pop("shank_ind"),
- connected=params.pop("connected"),
- sample_rate=params.pop("sample_rate"),
- params=params,
- )
- else:
- run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline(
- npx_input_dir=oe_probe.recording_info["recording_files"][0],
- ks_output_dir=kilosort_dir,
- params=params,
- KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}',
- )
- run_kilosort.run_modules()
- else:
- raise NotImplementedError(
- f"Automatic triggering of {clustering_method}"
- f" clustering analysis is not yet supported"
- )
-
- else:
- raise ValueError(f"Unknown task mode: {task_mode}")
-
- creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir)
- self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
- """Curation procedure table.
-
- Attributes:
- Clustering (foreign key): Clustering primary key.
- curation_id (foreign key, int): Unique curation ID.
- curation_time (datetime): Time when curation results are generated.
- curation_output_dir (varchar(255) ): Output directory of the curated results.
- quality_control (bool): If True, this clustering result has undergone quality control.
- manual_curation (bool): If True, manual curation has been performed on this clustering result.
- curation_note (varchar(2000) ): Notes about the curation task.
- """
-
- definition = """
- # Manual curation procedure
- -> Clustering
- curation_id: int
- ---
- curation_time: datetime # time of generation of this set of curated clustering results
- curation_output_dir: varchar(255) # output directory of the curated results, relative to root data directory
- quality_control: bool # has this clustering result undergone quality control?
- manual_curation: bool # has manual curation been performed on this clustering result?
- curation_note='': varchar(2000)
- """
-
- def create1_from_clustering_task(self, key, curation_note: str = ""):
- """
- A function to create a new corresponding "Curation" for a particular
- "ClusteringTask"
- """
- if key not in Clustering():
- raise ValueError(
- f"No corresponding entry in Clustering available"
- f" for: {key}; do `Clustering.populate(key)`"
- )
-
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
- kilosort_dir
- )
- # Synthesize curation_id
- curation_id = (
- dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n")
- )
- self.insert1(
- {
- **key,
- "curation_id": curation_id,
- "curation_time": creation_time,
- "curation_output_dir": output_dir,
- "quality_control": is_qc,
- "manual_curation": is_curated,
- "curation_note": curation_note,
- }
- )
-
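-# Editor's sketch of the usual curation round-trip (`key` is hypothetical):
-#
-#     Clustering.populate(key)
-#     Curation().create1_from_clustering_task(key)  # curation_id auto-assigned
-#     CuratedClustering.populate(key)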
-
-@schema
-class CuratedClustering(dj.Imported):
- """Clustering results after curation.
-
- Attributes:
- Curation (foreign key): Curation primary key.
- """
-
- definition = """
- # Clustering results of a curation.
- -> Curation
- """
-
- class Unit(dj.Part):
- """Single unit properties after clustering and curation.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- unit (foreign key, int): Unique integer identifying a single unit.
- probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key.
- ClusterQualityLabel (dict): ClusterQualityLabel primary key.
- spike_count (int): Number of spikes in this recording for this unit.
- spike_times (longblob): Spike times of this unit, relative to start time of EphysRecording.
- spike_sites (longblob): Array of electrode associated with each spike.
- spike_depths (longblob): Array of depths associated with each spike, relative to the (0, 0) of the probe.
- """
-
- definition = """
- # Properties of a given unit from a round of clustering (and curation)
- -> master
- unit: int
- ---
- -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit
- -> ClusterQualityLabel
- spike_count: int # how many spikes in this recording for this unit
- spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording
- spike_sites : longblob # array of electrode associated with each spike
- spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
- """
-
- def make(self, key):
- """Automated population of Unit information."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
- acq_software, sample_rate = (EphysRecording & key).fetch1(
- "acq_software", "sampling_rate"
- )
-
- sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate)
-
- # ---------- Unit ----------
- # -- Remove 0-spike units
- withspike_idx = [
- i
- for i, u in enumerate(kilosort_dataset.data["cluster_ids"])
- if (kilosort_dataset.data["spike_clusters"] == u).any()
- ]
- valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx]
- valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx]
- # -- Get channel and electrode-site mapping
- channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software)
-
- # -- Spike-times --
- # preferred source, in order: spike_times_sec_adj > spike_times_sec > spike_times
- spike_time_key = (
- "spike_times_sec_adj"
- if "spike_times_sec_adj" in kilosort_dataset.data
- else (
- "spike_times_sec"
- if "spike_times_sec" in kilosort_dataset.data
- else "spike_times"
- )
- )
- spike_times = kilosort_dataset.data[spike_time_key]
- kilosort_dataset.extract_spike_depths()
-
- # -- Spike-sites and Spike-depths --
- spike_sites = np.array(
- [
- channel2electrodes[s]["electrode"]
- for s in kilosort_dataset.data["spike_sites"]
- ]
- )
- spike_depths = kilosort_dataset.data["spike_depths"]
-
- # -- Insert unit, label, peak-chn
- units = []
- for unit, unit_lbl in zip(valid_units, valid_unit_labels):
- if (kilosort_dataset.data["spike_clusters"] == unit).any():
- unit_channel, _ = kilosort_dataset.get_best_channel(unit)
- unit_spike_times = (
- spike_times[kilosort_dataset.data["spike_clusters"] == unit]
- / sample_rate
- )
- spike_count = len(unit_spike_times)
-
- units.append(
- {
- "unit": unit,
- "cluster_quality_label": unit_lbl,
- **channel2electrodes[unit_channel],
- "spike_times": unit_spike_times,
- "spike_count": spike_count,
- "spike_sites": spike_sites[
- kilosort_dataset.data["spike_clusters"] == unit
- ],
- "spike_depths": (
- spike_depths[
- kilosort_dataset.data["spike_clusters"] == unit
- ]
- if spike_depths is not None
- else None
- ),
- }
- )
-
- self.insert1(key)
- self.Unit.insert([{**key, **u} for u in units])
-
-
-@schema
-class WaveformSet(dj.Imported):
- """A set of spike waveforms for units out of a given CuratedClustering.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # A set of spike waveforms for units out of a given CuratedClustering
- -> CuratedClustering
- """
-
- class PeakWaveform(dj.Part):
- """Mean waveform across spikes for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode.
- """
-
- definition = """
- # Mean waveform across spikes for a given unit at its representative electrode
- -> master
- -> CuratedClustering.Unit
- ---
- peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode
- """
-
- class Waveform(dj.Part):
- """Spike waveforms for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- waveform_mean (longblob): mean waveform across spikes of the unit in microvolts.
- waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit.
- """
-
- definition = """
- # Spike waveforms and their mean across spikes for the given unit
- -> master
- -> CuratedClustering.Unit
- -> probe.ElectrodeConfig.Electrode
- ---
- waveform_mean: longblob # (uV) mean waveform across spikes of the given unit
- waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit
- """
-
- def make(self, key):
- """Populates waveform tables."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
-
- acq_software, probe_serial_number = (
- EphysRecording * ProbeInsertion & key
- ).fetch1("acq_software", "probe")
-
- # -- Get channel and electrode-site mapping
- recording_key = (EphysRecording & key).fetch1("KEY")
- channel2electrodes = get_neuropixels_channel2electrode_map(
- recording_key, acq_software
- )
-
- is_qc = (Curation & key).fetch1("quality_control")
-
- # Get all units
- units = {
- u["unit"]: u
- for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit")
- }
-
- if is_qc:
- unit_waveforms = np.load(
- kilosort_dir / "mean_waveforms.npy"
- ) # unit x channel x sample
-
- def yield_unit_waveforms():
- for unit_no, unit_waveform in zip(
- kilosort_dataset.data["cluster_ids"], unit_waveforms
- ):
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
- if unit_no in units:
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], unit_waveform
- ):
- unit_electrode_waveforms.append(
- {
- **units[unit_no],
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == units[unit_no]["electrode"]
- ):
- unit_peak_waveform = {
- **units[unit_no],
- "peak_electrode_waveform": channel_waveform,
- }
- yield unit_peak_waveform, unit_electrode_waveforms
-
- else:
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
- openephys_dataset = openephys.OpenEphys(session_dir)
- neuropixels_recording = openephys_dataset.probes[probe_serial_number]
-
- def yield_unit_waveforms():
- for unit_dict in units.values():
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
-
- spikes = unit_dict["spike_times"]
- waveforms = neuropixels_recording.extract_spike_waveforms(
- spikes, kilosort_dataset.data["channel_map"]
- ) # (sample x channel x spike)
- waveforms = waveforms.transpose(
- (1, 2, 0)
- ) # (channel x spike x sample)
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], waveforms
- ):
- unit_electrode_waveforms.append(
- {
- **unit_dict,
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform.mean(axis=0),
- "waveforms": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == unit_dict["electrode"]
- ):
- unit_peak_waveform = {
- **unit_dict,
- "peak_electrode_waveform": channel_waveform.mean(
- axis=0
- ),
- }
-
- yield unit_peak_waveform, unit_electrode_waveforms
-
- # insert waveform on a per-unit basis to mitigate potential memory issue
- self.insert1(key)
- for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms():
- if unit_peak_waveform:
- self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True)
- if unit_electrode_waveforms:
- self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
-
-
-@schema
-class QualityMetrics(dj.Imported):
- """Clustering and waveform quality metrics.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # Clusters and waveforms metrics
- -> CuratedClustering
- """
-
- class Cluster(dj.Part):
- """Cluster metrics for a unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- firing_rate (float): Firing rate of the unit.
- snr (float): Signal-to-noise ratio for a unit.
- presence_ratio (float): Fraction of time where spikes are present.
- isi_violation (float): Rate of ISI violation as a fraction of overall rate.
- number_violation (int): Total ISI violations.
- amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram.
- isolation_distance (float): Distance to nearest cluster.
- l_ratio (float): Amount of empty space between a cluster and other spikes in dataset.
- d_prime (float): Classification accuracy based on LDA.
- nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster.
- nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster.
- silhouette_score (float): Standard metric for cluster overlap.
- max_drift (float): Maximum change in spike depth throughout recording.
- cumulative_drift (float): Cumulative change in spike depth throughout recording.
- contamination_rate (float): Frequency of spikes in the refractory period.
- """
-
- definition = """
- # Cluster metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- firing_rate=null: float # (Hz) firing rate for a unit
- snr=null: float # signal-to-noise ratio for a unit
- presence_ratio=null: float # fraction of time in which spikes are present
- isi_violation=null: float # rate of ISI violation as a fraction of overall rate
- number_violation=null: int # total number of ISI violations
- amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram
- isolation_distance=null: float # distance to nearest cluster in Mahalanobis space
- l_ratio=null: float # amount of empty space between a cluster and other spikes in dataset
- d_prime=null: float # Classification accuracy based on LDA
- nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster
- nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster
- silhouette_score=null: float # Standard metric for cluster overlap
- max_drift=null: float # Maximum change in spike depth throughout recording
- cumulative_drift=null: float # Cumulative change in spike depth throughout recording
- contamination_rate=null: float # frequency of spikes in the refractory period
- """
-
- class Waveform(dj.Part):
- """Waveform metrics for a particular unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- amplitude (float): Absolute difference between waveform peak and trough in microvolts.
- duration (float): Time between waveform peak and trough in milliseconds.
- halfwidth (float): Spike width at half max amplitude.
- pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0.
- repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak.
- recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail.
- spread (float): The range in micrometers with amplitude above 12 percent of the maximum amplitude along the probe.
- velocity_above (float): Inverse velocity of waveform propagation from the soma toward the top of the probe.
- velocity_below (float): Inverse velocity of waveform propagation from the soma toward the bottom of the probe.
- """
-
- definition = """
- # Waveform metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- amplitude: float # (uV) absolute difference between waveform peak and trough
- duration: float # (ms) time between waveform peak and trough
- halfwidth=null: float # (ms) spike width at half max amplitude
- pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0
- repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak
- recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail
- spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe
- velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe
- velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe
- """
-
- def make(self, key):
- """Populates tables with quality metrics data."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- metric_fp = kilosort_dir / "metrics.csv"
- rename_dict = {
- "isi_viol": "isi_violation",
- "num_viol": "number_violation",
- "contam_rate": "contamination_rate",
- }
-
- if not metric_fp.exists():
- raise FileNotFoundError(f"QC metrics file not found: {metric_fp}")
-
- metrics_df = pd.read_csv(metric_fp)
- metrics_df.set_index("cluster_id", inplace=True)
- metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
- metrics_df.columns = metrics_df.columns.str.lower()
- metrics_df.rename(columns=rename_dict, inplace=True)
- metrics_list = [
- dict(metrics_df.loc[unit_key["unit"]], **unit_key)
- for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
- ]
-
- self.insert1(key)
- self.Cluster.insert(metrics_list, ignore_extra_fields=True)
- self.Waveform.insert(metrics_list, ignore_extra_fields=True)
-
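-# Editor's note: metrics.csv is assumed here to come from a separate
-# quality-metrics step of the sorting pipeline. Columns are lower-cased and
-# renamed (isi_viol -> isi_violation, num_viol -> number_violation,
-# contam_rate -> contamination_rate), and any extra columns are dropped via
-# ignore_extra_fields on insert.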
-
-# ---------------- HELPER FUNCTIONS ----------------
-
-
-def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> pathlib.Path:
- """Get the full path to the SpikeGLX .ap.meta file for a given recording."""
- # attempt to retrieve from EphysRecording.EphysFile
- spikeglx_meta_filepath = pathlib.Path(
- (
- EphysRecording.EphysFile
- & ephys_recording_key
- & 'file_path LIKE "%.ap.meta"'
- ).fetch1("file_path")
- )
-
- try:
- spikeglx_meta_filepath = find_full_path(
- get_ephys_root_data_dir(), spikeglx_meta_filepath
- )
- except FileNotFoundError:
- # if not found, search in session_dir again
- if not spikeglx_meta_filepath.exists():
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- inserted_probe_serial_number = (
- ProbeInsertion * probe.Probe & ephys_recording_key
- ).fetch1("probe")
-
- spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")]
- for meta_filepath in spikeglx_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- spikeglx_meta_filepath = meta_filepath
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(
- ephys_recording_key
- )
- )
-
- return spikeglx_meta_filepath
-
-
-def get_openephys_probe_data(ephys_recording_key: dict):
- """Get the Open Ephys probe data object for a given recording."""
- inserted_probe_serial_number = (
- ProbeInsertion * probe.Probe & ephys_recording_key
- ).fetch1("probe")
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- loaded_oe = openephys.OpenEphys(session_dir)
- probe_data = loaded_oe.probes[inserted_probe_serial_number]
-
- # explicitly garbage collect "loaded_oe"
- # as these may have large memory footprint and may not be cleared fast enough
- del loaded_oe
- gc.collect()
-
- return probe_data
-
-
-def get_neuropixels_channel2electrode_map(
- ephys_recording_key: dict, acq_software: str
-) -> dict:
- """Get the channel map for neuropixels probe."""
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath)
- electrode_config_key = (
- EphysRecording * probe.ElectrodeConfig & ephys_recording_key
- ).fetch1("KEY")
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
- & electrode_config_key
- )
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- channel2electrode_map = {
- recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
- for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
- spikeglx_meta.shankmap["data"]
- )
- }
- elif acq_software == "Open Ephys":
- probe_dataset = get_openephys_probe_data(ephys_recording_key)
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording
- & ephys_recording_key
- )
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- channel2electrode_map = {
- channel_idx: probe_electrodes[channel_idx]
- for channel_idx in probe_dataset.ap_meta["channels_indices"]
- }
-
- return channel2electrode_map
-
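-# Editor's note: the returned mapping is keyed by recorded channel index;
-# e.g. channel2electrode_map[10] is the probe.ElectrodeConfig.Electrode key
-# (a dict) for the electrode acquired on channel 10.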
-
-def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict:
- """Generate and insert new ElectrodeConfig
-
- Args:
- probe_type (str): probe type (e.g. neuropixels 2.0 - SS)
- electrode_keys (list): list of keys of the probe.ProbeType.Electrode table
-
- Returns:
- dict: representing a key of the probe.ElectrodeConfig table
- """
- # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode)
- electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
-
- electrode_list = sorted([k["electrode"] for k in electrode_keys])
- electrode_gaps = (
- [-1]
- + np.where(np.diff(electrode_list) > 1)[0].tolist()
- + [len(electrode_list) - 1]
- )
- electrode_config_name = "; ".join(
- [
- f"{electrode_list[start + 1]}-{electrode_list[end]}"
- for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])
- ]
- )
-
- electrode_config_key = {"electrode_config_hash": electrode_config_hash}
-
- # ---- make new ElectrodeConfig if needed ----
- if not probe.ElectrodeConfig & electrode_config_key:
- probe.ElectrodeConfig.insert1(
- {
- **electrode_config_key,
- "probe_type": probe_type,
- "electrode_config_name": electrode_config_name,
- }
- )
- probe.ElectrodeConfig.Electrode.insert(
- {**electrode_config_key, **electrode} for electrode in electrode_keys
- )
-
- return electrode_config_key
-
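-# Editor's note on the naming scheme above: contiguous electrode ranges are
-# collapsed into "start-end" segments, e.g. electrode_keys covering
-# electrodes [0, 1, 2, 10, 11] yield electrode_config_name "0-2; 10-11".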
-
-def get_recording_channels_details(ephys_recording_key: dict) -> dict:
- """Get details of the recording channels for a given recording."""
- channels_details = {}
-
- acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1(
- "acq_software", "sampling_rate"
- )
-
- probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1(
- "probe_type"
- )
- channels_details["probe_type"] = {
- "neuropixels 1.0 - 3A": "3A",
- "neuropixels 1.0 - 3B": "NP1",
- "neuropixels UHD": "NP1100",
- "neuropixels 2.0 - SS": "NP21",
- "neuropixels 2.0 - MS": "NP24",
- }[probe_type]
-
- electrode_config_key = (
- probe.ElectrodeConfig * EphysRecording & ephys_recording_key
- ).fetch1("KEY")
- (
- channels_details["channel_ind"],
- channels_details["x_coords"],
- channels_details["y_coords"],
- channels_details["shank_ind"],
- ) = (
- probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode
- & electrode_config_key
- ).fetch(
- "electrode", "x_coord", "y_coord", "shank"
- )
- channels_details["sample_rate"] = sample_rate
- channels_details["num_channels"] = len(channels_details["channel_ind"])
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- channels_details["uVPerBit"] = spikeglx_recording.get_channel_bit_volts("ap")[0]
- channels_details["connected"] = np.array(
- [v for *_, v in spikeglx_recording.apmeta.shankmap["data"]]
- )
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(ephys_recording_key)
- channels_details["uVPerBit"] = oe_probe.ap_meta["channels_gains"][0]
- channels_details["connected"] = np.array(
- [
- int(v == 1)
- for c, v in oe_probe.channels_connected.items()
- if c in channels_details["channel_ind"]
- ]
- )
-
- return channels_details
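-
-# Editor's sketch: how the details dict feeds the spike-sorting trigger
-# above (`key` and `params` are hypothetical):
-#
-#     params = {**params, **get_recording_channels_details(key)}
-#     params["fs"] = params["sample_rate"]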
diff --git a/element_array_ephys/ephys_precluster.py b/element_array_ephys/ephys_precluster.py
deleted file mode 100644
index 4d52c610..00000000
--- a/element_array_ephys/ephys_precluster.py
+++ /dev/null
@@ -1,1435 +0,0 @@
-import importlib
-import inspect
-import re
-
-import datajoint as dj
-import numpy as np
-import pandas as pd
-from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
-
-from . import ephys_report, probe
-from .readers import kilosort, openephys, spikeglx
-
-schema = dj.schema()
-
-_linking_module = None
-
-
-def activate(
- ephys_schema_name: str,
- probe_schema_name: str = None,
- *,
- create_schema: bool = True,
- create_tables: bool = True,
- linking_module: str = None,
-):
- """Activates the `ephys` and `probe` schemas.
-
- Args:
- ephys_schema_name (str): A string containing the name of the ephys schema.
- probe_schema_name (str): A string containing the name of the probe schema.
- create_schema (bool): If True, schema will be created in the database.
- create_tables (bool): If True, tables related to the schema will be created in the database.
- linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
-
- Dependencies:
- Upstream tables:
- Session: A parent table to ProbeInsertion
- Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported.
-
- Functions:
- get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
- get_session_directory(session_key: dict): Returns the path to the electrophysiology data for a particular session as a string.
- """
-
- if isinstance(linking_module, str):
- linking_module = importlib.import_module(linking_module)
- assert inspect.ismodule(
- linking_module
- ), "The argument 'linking_module' must be a module name or a module"
-
- global _linking_module
- _linking_module = linking_module
-
- probe.activate(
- probe_schema_name, create_schema=create_schema, create_tables=create_tables
- )
- schema.activate(
- ephys_schema_name,
- create_schema=create_schema,
- create_tables=create_tables,
- add_objects=_linking_module.__dict__,
- )
- ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
-
-
-# -------------- Functions required by element-array-ephys ---------------
-
-
-def get_ephys_root_data_dir() -> list:
- """Fetches absolute data path to ephys data directories.
-
- The absolute path here is used as a reference for all downstream relative paths used in DataJoint.
-
- Returns:
- A list of the absolute path(s) to ephys data directories.
- """
- return _linking_module.get_ephys_root_data_dir()
-
-
-def get_session_directory(session_key: dict) -> str:
- """Retrieve the session directory with Neuropixels for the given session.
-
- Args:
- session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database.
-
- Returns:
- A string for the path to the session directory.
- """
- return _linking_module.get_session_directory(session_key)
-
-
-# ----------------------------- Table declarations ----------------------
-
-
-@schema
-class AcquisitionSoftware(dj.Lookup):
- """Name of software used for recording electrophysiological data.
-
- Attributes:
- acq_software ( varchar(24) ): Acquisition software, e.g., SpikeGLX, Open Ephys
- """
-
- definition = """ # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys
- acq_software: varchar(24)
- """
- contents = zip(["SpikeGLX", "Open Ephys"])
-
-
-@schema
-class ProbeInsertion(dj.Manual):
- """Information about probe insertion across subjects and sessions.
-
- Attributes:
- Session (foreign key): Session primary key.
- insertion_number (foreign key, int): Unique insertion number for each probe insertion for a given session.
- probe.Probe (str): probe.Probe primary key.
- """
-
- definition = """
- # Probe insertion implanted into an animal for a given session.
- -> Session
- insertion_number: tinyint unsigned
- ---
- -> probe.Probe
- """
-
-
-@schema
-class InsertionLocation(dj.Manual):
- """Stereotaxic location information for each probe insertion.
-
- Attributes:
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- SkullReference (dict): SkullReference primary key.
- ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive.
- ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive.
- depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative.
- theta (decimal (5, 2) ): Elevation - rotation about the ml-axis in degrees relative to the positive z-axis.
- phi (decimal (5, 2) ): Azimuth - rotation about the dv-axis in degrees relative to the positive x-axis.
- beta (decimal (5, 2) ): Optional. Rotation about the shank of the probe in degrees; zero is the probe front facing anterior, clockwise is increasing.
- """
-
- definition = """
- # Brain Location of a given probe insertion.
- -> ProbeInsertion
- ---
- -> SkullReference
- ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive
- ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive
- depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative
- theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis
- phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis
- beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior
- """
-
-
-@schema
-class EphysRecording(dj.Imported):
- """Automated table with electrophysiology recording information for each probe inserted during an experimental session.
-
- Attributes:
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key.
- AcquisitionSoftware (dict): AcquisitionSoftware primary key.
- sampling_rate (float): sampling rate of the recording in Hertz (Hz).
- recording_datetime (datetime): datetime of the recording from this probe.
- recording_duration (float): duration of the entire recording from this probe in seconds.
- """
-
- definition = """
- # Ephys recording from a probe insertion for a given session.
- -> ProbeInsertion
- ---
- -> probe.ElectrodeConfig
- -> AcquisitionSoftware
- sampling_rate: float # (Hz)
- recording_datetime: datetime # datetime of the recording from this probe
- recording_duration: float # (seconds) duration of the recording from this probe
- """
-
- class EphysFile(dj.Part):
- """Paths of electrophysiology recording files for each insertion.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- file_path (varchar(255) ): relative file path for electrophysiology recording.
- """
-
- definition = """
- # Paths of files of a given EphysRecording round.
- -> master
- file_path: varchar(255) # filepath relative to root data directory
- """
-
- def make(self, key):
- """Populates table with electrophysiology recording information."""
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
-
- inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1(
- "probe"
- )
-
- # search session dir and determine acquisition software
- for ephys_pattern, ephys_acq_type in (
- ("*.ap.meta", "SpikeGLX"),
- ("*.oebin", "Open Ephys"),
- ):
- ephys_meta_filepaths = [fp for fp in session_dir.rglob(ephys_pattern)]
- if ephys_meta_filepaths:
- acq_software = ephys_acq_type
- break
- else:
- raise FileNotFoundError(
- f"Ephys recording data not found!"
- f" Neither SpikeGLX nor Open Ephys recording files found"
- f" in {session_dir}"
- )
-
- if acq_software == "SpikeGLX":
- for meta_filepath in ephys_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(key)
- )
-
- if re.search("(1.0|2.0)", spikeglx_meta.probe_model):
- probe_type = spikeglx_meta.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- electrode_group_members = [
- probe_electrodes[(shank, shank_col, shank_row)]
- for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels probe model"
- " {} not yet implemented".format(spikeglx_meta.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": spikeglx_meta.meta["imSampRate"],
- "recording_datetime": spikeglx_meta.recording_time,
- "recording_duration": (
- spikeglx_meta.recording_duration
- or spikeglx.retrieve_recording_duration(meta_filepath)
- ),
- }
- )
-
- root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath)
- self.EphysFile.insert1(
- {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()}
- )
- elif acq_software == "Open Ephys":
- dataset = openephys.OpenEphys(session_dir)
- for serial_number, probe_data in dataset.probes.items():
- if str(serial_number) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- "No Open Ephys data found for probe insertion: {}".format(key)
- )
-
- if re.search("(1.0|2.0)", probe_data.probe_model):
- probe_type = probe_data.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_group_members = [
- probe_electrodes[channel_idx]
- for channel_idx in probe_data.ap_meta["channels_ids"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels"
- " probe model {} not yet implemented".format(probe_data.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": probe_data.ap_meta["sample_rate"],
- "recording_datetime": probe_data.recording_info[
- "recording_datetimes"
- ][0],
- "recording_duration": np.sum(
- probe_data.recording_info["recording_durations"]
- ),
- }
- )
-
- root_dir = find_root_directory(
- get_ephys_root_data_dir(),
- probe_data.recording_info["recording_files"][0],
- )
- self.EphysFile.insert(
- [
- {**key, "file_path": fp.relative_to(root_dir).as_posix()}
- for fp in probe_data.recording_info["recording_files"]
- ]
- )
- else:
- raise NotImplementedError(
- f"Processing ephys files from"
- f" acquisition software of type {acq_software} is"
- f" not yet implemented"
- )
-
-
-@schema
-class PreClusterMethod(dj.Lookup):
- """Pre-clustering method
-
- Attributes:
- precluster_method (foreign key, varchar(16) ): Pre-clustering method for the dataset.
- precluster_method_desc (varchar(1000) ): Pre-clustering method description.
- """
-
- definition = """
- # Method for pre-clustering
- precluster_method: varchar(16)
- ---
- precluster_method_desc: varchar(1000)
- """
-
- contents = [("catgt", "Time shift, Common average referencing, Zeroing")]
-
-
-@schema
-class PreClusterParamSet(dj.Lookup):
- """Parameters for the pre-clustering method.
-
- Attributes:
- paramset_idx (foreign key): Unique parameter set ID.
- PreClusterMethod (dict): PreClusterMethod primary key.
- paramset_desc (varchar(128) ): Description for the pre-clustering parameter set.
- param_set_hash (uuid): Unique hash for parameter set.
- params (longblob): All parameters for the pre-clustering method.
- """
-
- definition = """
- # Parameter set to be used in a clustering procedure
- paramset_idx: smallint
- ---
- -> PreClusterMethod
- paramset_desc: varchar(128)
- param_set_hash: uuid
- unique index (param_set_hash)
- params: longblob # dictionary of all applicable parameters
- """
-
- @classmethod
- def insert_new_params(
- cls, precluster_method: str, paramset_idx: int, paramset_desc: str, params: dict
- ):
- param_dict = {
- "precluster_method": precluster_method,
- "paramset_idx": paramset_idx,
- "paramset_desc": paramset_desc,
- "params": params,
- "param_set_hash": dict_to_uuid(params),
- }
- param_query = cls & {"param_set_hash": param_dict["param_set_hash"]}
-
- if param_query: # If the specified param-set already exists
- existing_paramset_idx = param_query.fetch1("paramset_idx")
- if (
- existing_paramset_idx == paramset_idx
- ): # If the existing set has the same paramset_idx: job done
- return
- else: # Same params already exist under a different paramset_idx: human error, duplicate param-set
- raise dj.DataJointError(
- "The specified param-set"
- " already exists - paramset_idx: {}".format(existing_paramset_idx)
- )
- else:
- cls.insert1(param_dict)
-
-
-@schema
-class PreClusterParamSteps(dj.Manual):
- """Ordered list of parameter sets that will be run.
-
- Attributes:
- precluster_param_steps_id (foreign key): Unique ID for the pre-clustering parameter sets to be run.
- precluster_param_steps_name (varchar(32) ): User-friendly name for the parameter steps.
- precluster_param_steps_desc (varchar(128) ): Description of the parameter steps.
- """
-
- definition = """
- # Ordered list of paramset_idx that are to be run
- # When pre-clustering is not performed, do not create an entry in `Step` Part table
- precluster_param_steps_id: smallint
- ---
- precluster_param_steps_name: varchar(32)
- precluster_param_steps_desc: varchar(128)
- """
-
- class Step(dj.Part):
- """Define the order of operations for parameter sets.
-
- Attributes:
- PreClusterParamSteps (foreign key): PreClusterParamSteps primary key.
- step_number (foreign key, smallint): Order of operations.
- PreClusterParamSet (dict): PreClusterParamSet to be used in pre-clustering.
- """
-
- definition = """
- -> master
- step_number: smallint # Order of operations
- ---
- -> PreClusterParamSet
- """
-
-
-@schema
-class PreClusterTask(dj.Manual):
- """Defines a pre-clustering task ready to be run.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- PreClusterParamSteps (foreign key): PreClusterParamSteps primary key.
- precluster_output_dir (varchar(255) ): Relative path to the directory storing pre-clustering results.
- task_mode (enum): `none` (no pre-clustering), `load` results from file, or `trigger` automated pre-clustering.
- """
-
- definition = """
- # Manual table for defining a clustering task ready to be run
- -> EphysRecording
- -> PreClusterParamSteps
- ---
- precluster_output_dir='': varchar(255) # pre-clustering output directory relative to the root data directory
- task_mode='none': enum('none','load', 'trigger') # 'none': no pre-clustering analysis
- # 'load': load analysis results
- # 'trigger': trigger computation
- """
-
-
-@schema
-class PreCluster(dj.Imported):
- """
- A processing table to handle each PreClusterTask.
-
- Attributes:
- PreClusterTask (foreign key): PreClusterTask primary key.
- precluster_time (datetime): Time of generation of this set of pre-clustering results.
- package_version (varchar(16) ): Package version used for performing pre-clustering.
- """
-
- definition = """
- -> PreClusterTask
- ---
- precluster_time: datetime # time of generation of this set of pre-clustering results
- package_version='': varchar(16)
- """
-
- def make(self, key):
- """Populate pre-clustering tables."""
- task_mode, output_dir = (PreClusterTask & key).fetch1(
- "task_mode", "precluster_output_dir"
- )
- precluster_output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- if task_mode == "none":
- if len((PreClusterParamSteps.Step & key).fetch()) > 0:
- raise ValueError(
- "There are entries in the PreClusterParamSteps.Step "
- "table and task_mode=none"
- )
- creation_time = (EphysRecording & key).fetch1("recording_datetime")
- elif task_mode == "load":
- acq_software = (EphysRecording & key).fetch1("acq_software")
- inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1(
- "probe"
- )
-
- if acq_software == "SpikeGLX":
- for meta_filepath in precluster_output_dir.rglob("*.ap.meta"):
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
-
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- creation_time = spikeglx_meta.recording_time
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(key)
- )
- else:
- raise NotImplementedError(
- f"Pre-clustering analysis of {acq_software}" "is not yet supported."
- )
- elif task_mode == "trigger":
- raise NotImplementedError(
- "Automatic triggering of"
- " pre-clustering analysis is not yet supported."
- )
- else:
- raise ValueError(f"Unknown task mode: {task_mode}")
-
- self.insert1({**key, "precluster_time": creation_time, "package_version": ""})
-
-
-@schema
-class LFP(dj.Imported):
- """Extracts local field potentials (LFP) from an electrophysiology recording.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- lfp_sampling_rate (float): Sampling rate for LFPs in Hz.
- lfp_time_stamps (longblob): Time stamps with respect to the start of the recording.
- lfp_mean (longblob): Overall mean LFP across electrodes.
- """
-
- definition = """
- # Acquired local field potential (LFP) from a given Ephys recording.
- -> PreCluster
- ---
- lfp_sampling_rate: float # (Hz)
- lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp)
- lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,)
- """
-
- class Electrode(dj.Part):
- """Saves local field potential data for each electrode.
-
- Attributes:
- LFP (foreign key): LFP primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- lfp (longblob): LFP recording at this electrode in microvolts.
- """
-
- definition = """
- -> master
- -> probe.ElectrodeConfig.Electrode
- ---
- lfp: longblob # (uV) recorded lfp at this electrode
- """
-
- # Only store LFP for every 9th channel; due to the high channel density,
- # close-by channels exhibit highly similar LFP signals
- _skip_channel_counts = 9
-
- def make(self, key):
- """Populates the LFP tables."""
- acq_software, probe_sn = (EphysRecording * ProbeInsertion & key).fetch1(
- "acq_software", "probe"
- )
-
- electrode_keys, lfp = [], []
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
-
- lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[
- -1 :: -self._skip_channel_counts
- ]
-
- # Extract LFP data at specified channels and convert to uV
- lfp = spikeglx_recording.lf_timeseries[
- :, lfp_channel_ind
- ] # (sample x channel)
- lfp = (
- lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind]
- ).T # (channel x sample)
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"],
- lfp_time_stamps=(
- np.arange(lfp.shape[1])
- / spikeglx_recording.lfmeta.meta["imSampRate"]
- ),
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- for recorded_site in lfp_channel_ind:
- shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[
- "data"
- ][recorded_site]
- electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)])
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
-
- loaded_oe = openephys.OpenEphys(session_dir)
- oe_probe = loaded_oe.probes[probe_sn]
-
- lfp_channel_ind = np.arange(len(oe_probe.lfp_meta["channels_ids"]))[
- -1 :: -self._skip_channel_counts
- ]
-
- lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] # (sample x channel)
- lfp = (
- lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind]
- ).T # (channel x sample)
- lfp_timestamps = oe_probe.lfp_timestamps
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"],
- lfp_time_stamps=lfp_timestamps,
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- for channel_idx in np.array(oe_probe.lfp_meta["channels_ids"])[
- lfp_channel_ind
- ]:
- electrode_keys.append(probe_electrodes[channel_idx])
- else:
- raise NotImplementedError(
- f"LFP extraction from acquisition software"
- f" of type {acq_software} is not yet implemented"
- )
-
- # single insert in loop to mitigate potential memory issue
- for electrode_key, lfp_trace in zip(electrode_keys, lfp):
- self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace})
-
-
-# ------------ Clustering --------------
-
-
-@schema
-class ClusteringMethod(dj.Lookup):
- """Kilosort clustering method.
-
- Attributes:
- clustering_method (foreign key, varchar(16) ): Kilosort clustering method.
- clustering_method_desc (varchar(1000) ): Additional description of the clustering method.
- """
-
- definition = """
- # Method for clustering
- clustering_method: varchar(16)
- ---
- clustering_method_desc: varchar(1000)
- """
-
- contents = [
- ("kilosort", "kilosort clustering method"),
- ("kilosort2", "kilosort2 clustering method"),
- ]
-
-
-@schema
-class ClusteringParamSet(dj.Lookup):
- """Parameters to be used in clustering procedure for spike sorting.
-
- Attributes:
- paramset_idx (foreign key): Unique ID for the clustering parameter set.
- ClusteringMethod (dict): ClusteringMethod primary key.
- paramset_desc (varchar(128) ): Description of the clustering parameter set.
- param_set_hash (uuid): UUID hash for the parameter set.
- params (longblob): Paramset, dictionary of all applicable parameters.
- """
-
- definition = """
- # Parameter set to be used in a clustering procedure
- paramset_idx: smallint
- ---
- -> ClusteringMethod
- paramset_desc: varchar(128)
- param_set_hash: uuid
- unique index (param_set_hash)
- params: longblob # dictionary of all applicable parameters
- """
-
- @classmethod
- def insert_new_params(
- cls, processing_method: str, paramset_idx: int, paramset_desc: str, params: dict
- ):
- """Inserts new parameters into the ClusteringParamSet table.
-
- Args:
- processing_method (str): Name of the clustering method.
- paramset_idx (int): Unique parameter set ID.
- paramset_desc (str): Description of the parameter set.
- params (dict): Clustering parameters.
- """
- param_dict = {
- "clustering_method": processing_method,
- "paramset_idx": paramset_idx,
- "paramset_desc": paramset_desc,
- "params": params,
- "param_set_hash": dict_to_uuid(params),
- }
- param_query = cls & {"param_set_hash": param_dict["param_set_hash"]}
-
- if param_query: # If the specified param-set already exists
- existing_paramset_idx = param_query.fetch1("paramset_idx")
- if (
- existing_paramset_idx == paramset_idx
- ): # If the existing set has the same paramset_idx: job done
- return
- else: # Same params already exist under a different paramset_idx: human error, duplicate param-set
- raise dj.DataJointError(
- "The specified param-set"
- " already exists - paramset_idx: {}".format(existing_paramset_idx)
- )
- else:
- cls.insert1(param_dict)
-
-
-@schema
-class ClusterQualityLabel(dj.Lookup):
- """Quality label for each spike sorted cluster.
-
- Attributes:
- cluster_quality_label (foreign key, varchar(100) ): Cluster quality type.
- cluster_quality_description (varchar(4000) ): Description of the cluster quality type.
- """
-
- definition = """
- # Quality
- cluster_quality_label: varchar(100)
- ---
- cluster_quality_description: varchar(4000)
- """
- contents = [
- ("good", "single unit"),
- ("ok", "probably a single unit, but could be contaminated"),
- ("mua", "multi-unit activity"),
- ("noise", "bad unit"),
- ]
-
-
-@schema
-class ClusteringTask(dj.Manual):
- """A clustering task to spike sort electrophysiology datasets.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- ClusteringParamSet (foreign key): ClusteringParamSet primary key.
- clustering_output_dir (varchar(255) ): Relative path to the clustering output directory.
- task_mode (enum): `trigger` computes clustering; `load` imports existing results.
- """
-
- definition = """
- # Manual table for defining a clustering task ready to be run
- -> PreCluster
- -> ClusteringParamSet
- ---
- clustering_output_dir: varchar(255) # clustering output directory relative to the clustering root data directory
- task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation
- """
-
-
-@schema
-class Clustering(dj.Imported):
- """A processing table to handle each clustering task.
-
- Attributes:
- ClusteringTask (foreign key): ClusteringTask primary key.
- clustering_time (datetime): Time when clustering results are generated.
- package_version (varchar(16) ): Package version used for a clustering analysis.
- """
-
- definition = """
- # Clustering Procedure
- -> ClusteringTask
- ---
- clustering_time: datetime # time of generation of this set of clustering results
- package_version='': varchar(16)
- """
-
- def make(self, key):
- """Triggers or imports clustering analysis."""
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- if task_mode == "load":
- _ = kilosort.Kilosort(
- kilosort_dir
- ) # check if the directory is a valid Kilosort output
- creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir)
- elif task_mode == "trigger":
- raise NotImplementedError(
- "Automatic triggering of" " clustering analysis is not yet supported"
- )
- else:
- raise ValueError(f"Unknown task mode: {task_mode}")
-
- self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
- """Curation procedure table.
-
- Attributes:
- Clustering (foreign key): Clustering primary key.
- curation_id (foreign key, int): Unique curation ID.
- curation_time (datetime): Time when curation results are generated.
- curation_output_dir (varchar(255) ): Output directory of the curated results.
- quality_control (bool): If True, this clustering result has undergone quality control.
- manual_curation (bool): If True, manual curation has been performed on this clustering result.
- curation_note (varchar(2000) ): Notes about the curation task.
- """
-
- definition = """
- # Manual curation procedure
- -> Clustering
- curation_id: int
- ---
- curation_time: datetime # time of generation of this set of curated clustering results
- curation_output_dir: varchar(255) # output directory of the curated results, relative to root data directory
- quality_control: bool # has this clustering result undergone quality control?
- manual_curation: bool # has manual curation been performed on this clustering result?
- curation_note='': varchar(2000)
- """
-
- def create1_from_clustering_task(self, key, curation_note: str = ""):
- """
- A function to create a new corresponding "Curation" for a particular
- "ClusteringTask"
- """
- if key not in Clustering():
- raise ValueError(
- f"No corresponding entry in Clustering available"
- f" for: {key}; do `Clustering.populate(key)`"
- )
-
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
- kilosort_dir
- )
- # Synthesize curation_id
- curation_id = (
- dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n")
- )
- self.insert1(
- {
- **key,
- "curation_id": curation_id,
- "curation_time": creation_time,
- "curation_output_dir": output_dir,
- "quality_control": is_qc,
- "manual_curation": is_curated,
- "curation_note": curation_note,
- }
- )
-
-
-@schema
-class CuratedClustering(dj.Imported):
- """Clustering results after curation.
-
- Attributes:
- Curation (foreign key): Curation primary key.
- """
-
- definition = """
- # Clustering results of a curation.
- -> Curation
- """
-
- class Unit(dj.Part):
- """Single unit properties after clustering and curation.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- unit (foreign key, int): Unique integer identifying a single unit.
- probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key.
- ClusterQualityLabel (dict): ClusterQualityLabel primary key.
- spike_count (int): Number of spikes in this recording for this unit.
- spike_times (longblob): Spike times of this unit, relative to the start time of EphysRecording.
- spike_sites (longblob): Array of the electrode associated with each spike.
- spike_depths (longblob): Array of depths associated with each spike, relative to the (0, 0) of the probe.
- """
-
- definition = """
- # Properties of a given unit from a round of clustering (and curation)
- -> master
- unit: int
- ---
- -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit
- -> ClusterQualityLabel
- spike_count: int # how many spikes in this recording for this unit
- spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording
- spike_sites : longblob # array of electrode associated with each spike
- spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
- """
-
- def make(self, key):
- """Automated population of Unit information."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
- acq_software = (EphysRecording & key).fetch1("acq_software")
-
- # ---------- Unit ----------
- # -- Remove 0-spike units
- withspike_idx = [
- i
- for i, u in enumerate(kilosort_dataset.data["cluster_ids"])
- if (kilosort_dataset.data["spike_clusters"] == u).any()
- ]
- valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx]
- valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx]
- # -- Get channel and electrode-site mapping
- channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software)
-
- # -- Spike-times --
- # spike_times_sec_adj > spike_times_sec > spike_times
- spike_time_key = (
- "spike_times_sec_adj"
- if "spike_times_sec_adj" in kilosort_dataset.data
- else (
- "spike_times_sec"
- if "spike_times_sec" in kilosort_dataset.data
- else "spike_times"
- )
- )
- spike_times = kilosort_dataset.data[spike_time_key]
- kilosort_dataset.extract_spike_depths()
-
- # -- Spike-sites and Spike-depths --
- spike_sites = np.array(
- [
- channel2electrodes[s]["electrode"]
- for s in kilosort_dataset.data["spike_sites"]
- ]
- )
- spike_depths = kilosort_dataset.data["spike_depths"]
-
- # -- Insert unit, label, peak-chn
- units = []
- for unit, unit_lbl in zip(valid_units, valid_unit_labels):
- if (kilosort_dataset.data["spike_clusters"] == unit).any():
- unit_channel, _ = kilosort_dataset.get_best_channel(unit)
- unit_spike_times = (
- spike_times[kilosort_dataset.data["spike_clusters"] == unit]
- / kilosort_dataset.data["params"]["sample_rate"]
- )
- spike_count = len(unit_spike_times)
-
- units.append(
- {
- "unit": unit,
- "cluster_quality_label": unit_lbl,
- **channel2electrodes[unit_channel],
- "spike_times": unit_spike_times,
- "spike_count": spike_count,
- "spike_sites": spike_sites[
- kilosort_dataset.data["spike_clusters"] == unit
- ],
- "spike_depths": (
- spike_depths[
- kilosort_dataset.data["spike_clusters"] == unit
- ]
- if spike_depths is not None
- else None
- ),
- }
- )
-
- self.insert1(key)
- self.Unit.insert([{**key, **u} for u in units])
-
-
-@schema
-class WaveformSet(dj.Imported):
- """A set of spike waveforms for units out of a given CuratedClustering.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # A set of spike waveforms for units out of a given CuratedClustering
- -> CuratedClustering
- """
-
- class PeakWaveform(dj.Part):
- """Mean waveform across spikes for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode.
- """
-
- definition = """
- # Mean waveform across spikes for a given unit at its representative electrode
- -> master
- -> CuratedClustering.Unit
- ---
- peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode
- """
-
- class Waveform(dj.Part):
- """Spike waveforms for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- waveform_mean (longblob): mean waveform across spikes of the unit in microvolts.
- waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit.
- """
-
- definition = """
- # Spike waveforms and their mean across spikes for the given unit
- -> master
- -> CuratedClustering.Unit
- -> probe.ElectrodeConfig.Electrode
- ---
- waveform_mean: longblob # (uV) mean waveform across spikes of the given unit
- waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit
- """
-
- def make(self, key):
- """Populates waveform tables."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
-
- acq_software, probe_serial_number = (
- EphysRecording * ProbeInsertion & key
- ).fetch1("acq_software", "probe")
-
- # -- Get channel and electrode-site mapping
- recording_key = (EphysRecording & key).fetch1("KEY")
- channel2electrodes = get_neuropixels_channel2electrode_map(
- recording_key, acq_software
- )
-
- is_qc = (Curation & key).fetch1("quality_control")
-
- # Get all units
- units = {
- u["unit"]: u
- for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit")
- }
-
- if is_qc:
- unit_waveforms = np.load(
- kilosort_dir / "mean_waveforms.npy"
- ) # unit x channel x sample
-
- def yield_unit_waveforms():
- for unit_no, unit_waveform in zip(
- kilosort_dataset.data["cluster_ids"], unit_waveforms
- ):
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
- if unit_no in units:
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], unit_waveform
- ):
- unit_electrode_waveforms.append(
- {
- **units[unit_no],
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == units[unit_no]["electrode"]
- ):
- unit_peak_waveform = {
- **units[unit_no],
- "peak_electrode_waveform": channel_waveform,
- }
- yield unit_peak_waveform, unit_electrode_waveforms
-
- else:
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
- openephys_dataset = openephys.OpenEphys(session_dir)
- neuropixels_recording = openephys_dataset.probes[probe_serial_number]
-
- def yield_unit_waveforms():
- for unit_dict in units.values():
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
-
- spikes = unit_dict["spike_times"]
- waveforms = neuropixels_recording.extract_spike_waveforms(
- spikes, kilosort_dataset.data["channel_map"]
- ) # (sample x channel x spike)
- waveforms = waveforms.transpose(
- (1, 2, 0)
- ) # (channel x spike x sample)
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], waveforms
- ):
- unit_electrode_waveforms.append(
- {
- **unit_dict,
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform.mean(axis=0),
- "waveforms": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == unit_dict["electrode"]
- ):
- unit_peak_waveform = {
- **unit_dict,
- "peak_electrode_waveform": channel_waveform.mean(
- axis=0
- ),
- }
-
- yield unit_peak_waveform, unit_electrode_waveforms
-
- # insert waveform on a per-unit basis to mitigate potential memory issue
- self.insert1(key)
- for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms():
- self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True)
- self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
-
-
-@schema
-class QualityMetrics(dj.Imported):
- """Clustering and waveform quality metrics.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # Clusters and waveforms metrics
- -> CuratedClustering
- """
-
- class Cluster(dj.Part):
- """Cluster metrics for a unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- firing_rate (float): Firing rate of the unit.
- snr (float): Signal-to-noise ratio for a unit.
- presence_ratio (float): Fraction of time where spikes are present.
- isi_violation (float): Rate of ISI violation as a fraction of overall rate.
- number_violation (int): Total number of ISI violations.
- amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram.
- isolation_distance (float): Distance to nearest cluster in Mahalanobis space.
- l_ratio (float): Amount of empty space between a cluster and other spikes in the dataset.
- d_prime (float): Classification accuracy based on LDA.
- nn_hit_rate (float): Fraction of neighbors for target cluster that are also in the target cluster.
- nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster.
- silhouette_score (float): Standard metric for cluster overlap.
- max_drift (float): Maximum change in spike depth throughout recording.
- cumulative_drift (float): Cumulative change in spike depth throughout recording.
- contamination_rate (float): Frequency of spikes in the refractory period.
- """
-
- definition = """
- # Cluster metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- firing_rate=null: float # (Hz) firing rate for a unit
- snr=null: float # signal-to-noise ratio for a unit
- presence_ratio=null: float # fraction of time in which spikes are present
- isi_violation=null: float # rate of ISI violation as a fraction of overall rate
- number_violation=null: int # total number of ISI violations
- amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram
- isolation_distance=null: float # distance to nearest cluster in Mahalanobis space
- l_ratio=null: float # amount of empty space between a cluster and other spikes in the dataset
- d_prime=null: float # Classification accuracy based on LDA
- nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster
- nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster
- silhouette_score=null: float # Standard metric for cluster overlap
- max_drift=null: float # Maximum change in spike depth throughout recording
- cumulative_drift=null: float # Cumulative change in spike depth throughout recording
- contamination_rate=null: float # frequency of spikes in the refractory period
- """
-
- class Waveform(dj.Part):
- """Waveform metrics for a particular unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- amplitude (float): Absolute difference between waveform peak and trough in microvolts.
- duration (float): Time between waveform peak and trough in milliseconds.
- halfwidth (float): Spike width at half max amplitude.
- pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0.
- repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak.
- recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail.
- spread (float): The range (um) with amplitude above 12-percent of the maximum amplitude along the probe.
- velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe.
- velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe.
- """
-
- definition = """
- # Waveform metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- amplitude: float # (uV) absolute difference between waveform peak and trough
- duration: float # (ms) time between waveform peak and trough
- halfwidth=null: float # (ms) spike width at half max amplitude
- pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0
- repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak
- recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail
- spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe
- velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe
- velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe
- """
-
- def make(self, key):
- """Populates tables with quality metrics data."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- metric_fp = kilosort_dir / "metrics.csv"
- rename_dict = {
- "isi_viol": "isi_violation",
- "num_viol": "number_violation",
- "contam_rate": "contamination_rate",
- }
-
- if not metric_fp.exists():
- raise FileNotFoundError(f"QC metrics file not found: {metric_fp}")
-
- metrics_df = pd.read_csv(metric_fp)
- metrics_df.set_index("cluster_id", inplace=True)
- metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
- metrics_df.columns = metrics_df.columns.str.lower()
- metrics_df.rename(columns=rename_dict, inplace=True)
- metrics_list = [
- dict(metrics_df.loc[unit_key["unit"]], **unit_key)
- for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
- ]
-
- self.insert1(key)
- self.Cluster.insert(metrics_list, ignore_extra_fields=True)
- self.Waveform.insert(metrics_list, ignore_extra_fields=True)
-
-
-# ---------------- HELPER FUNCTIONS ----------------
-
-
-def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str:
- """Get spikeGLX data filepath."""
- # attempt to retrieve from EphysRecording.EphysFile
- spikeglx_meta_filepath = (
- EphysRecording.EphysFile & ephys_recording_key & 'file_path LIKE "%.ap.meta"'
- ).fetch1("file_path")
-
- try:
- spikeglx_meta_filepath = find_full_path(
- get_ephys_root_data_dir(), spikeglx_meta_filepath
- )
- except FileNotFoundError:
- # if not found, search the session directory for a meta file matching the inserted probe
- # (the fetched file_path is a relative-path string at this point, so there is no file to re-check)
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- inserted_probe_serial_number = (
- ProbeInsertion * probe.Probe & ephys_recording_key
- ).fetch1("probe")
-
- spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")]
- for meta_filepath in spikeglx_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- spikeglx_meta_filepath = meta_filepath
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(
- ephys_recording_key
- )
- )
-
- return spikeglx_meta_filepath
-
-
-def get_neuropixels_channel2electrode_map(
- ephys_recording_key: dict, acq_software: str
-) -> dict:
- """Get the channel map for neuropixels probe."""
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath)
- electrode_config_key = (
- EphysRecording * probe.ElectrodeConfig & ephys_recording_key
- ).fetch1("KEY")
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
- & electrode_config_key
- )
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- channel2electrode_map = {
- recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
- for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
- spikeglx_meta.shankmap["data"]
- )
- }
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- openephys_dataset = openephys.OpenEphys(session_dir)
- probe_serial_number = (ProbeInsertion & ephys_recording_key).fetch1("probe")
- probe_dataset = openephys_dataset.probes[probe_serial_number]
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording
- & ephys_recording_key
- )
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- channel2electrode_map = {
- channel_idx: probe_electrodes[channel_idx]
- for channel_idx in probe_dataset.ap_meta["channels_ids"]
- }
-
- return channel2electrode_map
-
-
-def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict:
- """Generate and insert new ElectrodeConfig
-
- Args:
- probe_type (str): probe type (e.g. neuropixels 2.0 - SS)
- electrode_keys (list): list of keys of the probe.ProbeType.Electrode table
-
- Returns:
- dict: representing a key of the probe.ElectrodeConfig table
- """
- # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode)
- electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
-
- electrode_list = sorted([k["electrode"] for k in electrode_keys])
- electrode_gaps = (
- [-1]
- + np.where(np.diff(electrode_list) > 1)[0].tolist()
- + [len(electrode_list) - 1]
- )
- electrode_config_name = "; ".join(
- [
- f"{electrode_list[start + 1]}-{electrode_list[end]}"
- for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])
- ]
- )
-
- electrode_config_key = {"electrode_config_hash": electrode_config_hash}
-
- # ---- make new ElectrodeConfig if needed ----
- if not probe.ElectrodeConfig & electrode_config_key:
- probe.ElectrodeConfig.insert1(
- {
- **electrode_config_key,
- "probe_type": probe_type,
- "electrode_config_name": electrode_config_name,
- }
- )
- probe.ElectrodeConfig.Electrode.insert(
- {**electrode_config_key, **electrode} for electrode in electrode_keys
- )
-
- return electrode_config_key
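The param-set bookkeeping above (duplicated in `PreClusterParamSet` and `ClusteringParamSet`) hinges on hashing the parameter dictionary with `dict_to_uuid` and enforcing a unique index on `param_set_hash`, so identical parameters can never be stored twice. A minimal sketch of the pattern from the caller's side (the Kilosort parameter values are hypothetical):

```python
from element_interface.utils import dict_to_uuid

# Hypothetical Kilosort2 parameters
params = {"fs": 30000, "fshigh": 150, "minfr_goodchannels": 0.1}

param_dict = {
    "clustering_method": "kilosort2",
    "paramset_idx": 0,
    "paramset_desc": "default kilosort2 parameters",
    "params": params,
    "param_set_hash": dict_to_uuid(params),  # deterministic UUID of the params dict
}
# insert_new_params first queries `cls & {"param_set_hash": ...}`:
# - same hash, same paramset_idx -> no-op (already inserted)
# - same hash, different paramset_idx -> DataJointError (duplicate param-set)
# - no match -> insert1(param_dict)
```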
diff --git a/element_array_ephys/ephys_report.py b/element_array_ephys/ephys_report.py
index 48bcf613..0c6836a0 100644
--- a/element_array_ephys/ephys_report.py
+++ b/element_array_ephys/ephys_report.py
@@ -2,31 +2,30 @@
import datetime
import pathlib
+import tempfile
from uuid import UUID
import datajoint as dj
from element_interface.utils import dict_to_uuid
-from . import probe
+from . import probe, ephys
schema = dj.schema()
-ephys = None
-
-def activate(schema_name, ephys_schema_name, *, create_schema=True, create_tables=True):
+def activate(schema_name, *, create_schema=True, create_tables=True):
"""Activate the current schema.
Args:
schema_name (str): schema name on the database server to activate the `ephys_report` schema.
- ephys_schema_name (str): schema name of the activated ephys element for which
- this ephys_report schema will be downstream from.
create_schema (bool, optional): If True (default), create schema in the database if it does not yet exist.
create_tables (bool, optional): If True (default), create tables in the database if they do not yet exist.
"""
+ if not probe.schema.is_activated():
+ raise RuntimeError("Please activate the `probe` schema first.")
+ if not ephys.schema.is_activated():
+ raise RuntimeError("Please activate the `ephys` schema first.")
- global ephys
- ephys = dj.create_virtual_module("ephys", ephys_schema_name)
schema.activate(
schema_name,
create_schema=create_schema,
@@ -55,7 +54,7 @@ class ProbeLevelReport(dj.Computed):
def make(self, key):
from .plotting.probe_level import plot_driftmap
- save_dir = _make_save_dir()
+ save_dir = tempfile.TemporaryDirectory()
units = ephys.CuratedClustering.Unit & key & "cluster_quality_label='good'"
@@ -90,13 +89,15 @@ def make(self, key):
fig_dict = _save_figs(
figs=(fig,),
fig_names=("drift_map_plot",),
- save_dir=save_dir,
+ save_dir=save_dir.name,
fig_prefix=fig_prefix,
extension=".png",
)
self.insert1({**key, **fig_dict, "shank": shank_no})
+ save_dir.cleanup()
+
@schema
class UnitLevelReport(dj.Computed):
@@ -268,17 +269,10 @@ def make(self, key):
)
-def _make_save_dir(root_dir: pathlib.Path = None) -> pathlib.Path:
- if root_dir is None:
- root_dir = pathlib.Path().absolute()
- save_dir = root_dir / "temp_ephys_figures"
- save_dir.mkdir(parents=True, exist_ok=True)
- return save_dir
-
-
def _save_figs(
figs, fig_names, save_dir, fig_prefix, extension=".png"
) -> dict[str, pathlib.Path]:
+ save_dir = pathlib.Path(save_dir)
fig_dict = {}
for fig, fig_name in zip(figs, fig_names):
fig_filepath = save_dir / (fig_prefix + "_" + fig_name + extension)
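Replacing `_make_save_dir` with `tempfile.TemporaryDirectory` removes the leftover `temp_ephys_figures` folder: figures now live in OS-managed scratch space that is deleted explicitly after insertion. A sketch of the pattern (not the exact report code):

```python
import pathlib
import tempfile

import matplotlib.pyplot as plt

save_dir = tempfile.TemporaryDirectory()  # scratch directory, removed by cleanup()

fig, ax = plt.subplots()
fig_filepath = pathlib.Path(save_dir.name) / "probe_drift_map_plot.png"
fig.savefig(fig_filepath)

# ... open fig_filepath and insert its contents into the report table ...

save_dir.cleanup()  # delete the scratch directory and everything in it
```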
diff --git a/element_array_ephys/export/nwb/nwb.py b/element_array_ephys/export/nwb/nwb.py
index a45eb754..8d7da8f5 100644
--- a/element_array_ephys/export/nwb/nwb.py
+++ b/element_array_ephys/export/nwb/nwb.py
@@ -17,14 +17,7 @@
from spikeinterface import extractors
from tqdm import tqdm
-from ... import ephys_no_curation as ephys
-from ... import probe
-
-ephys_mode = os.getenv("EPHYS_MODE", dj.config["custom"].get("ephys_mode", "acute"))
-if ephys_mode != "no-curation":
- raise NotImplementedError(
- "This export function is designed for the no_curation " + "schema"
- )
+from ... import probe, ephys
class DecimalEncoder(json.JSONEncoder):
diff --git a/element_array_ephys/readers/kilosort.py b/element_array_ephys/readers/kilosort.py
index 4b50619d..4f8530d8 100644
--- a/element_array_ephys/readers/kilosort.py
+++ b/element_array_ephys/readers/kilosort.py
@@ -1,12 +1,10 @@
-import logging
-import pathlib
-import re
-from datetime import datetime
from os import path
-
-import numpy as np
+from datetime import datetime
+import pathlib
import pandas as pd
-
+import numpy as np
+import re
+import logging
from .utils import convert_to_number
log = logging.getLogger(__name__)
@@ -117,7 +115,8 @@ def _load(self):
# Read the Cluster Groups
for cluster_pattern, cluster_col_name in zip(
- ["cluster_group.*", "cluster_KSLabel.*"], ["group", "KSLabel"]
+ ["cluster_group.*", "cluster_KSLabel.*", "cluster_group.*"],
+ ["group", "KSLabel", "KSLabel"],
):
try:
cluster_file = next(self._kilosort_dir.glob(cluster_pattern))
@@ -126,22 +125,26 @@ def _load(self):
else:
cluster_file_suffix = cluster_file.suffix
assert cluster_file_suffix in (".tsv", ".xlsx")
- break
+
+ if cluster_file_suffix == ".tsv":
+ df = pd.read_csv(cluster_file, sep="\t", header=0)
+ elif cluster_file_suffix == ".xlsx":
+ df = pd.read_excel(cluster_file, engine="openpyxl")
+ else:
+ df = pd.read_csv(cluster_file, delimiter="\t")
+
+ try:
+ self._data["cluster_groups"] = np.array(df[cluster_col_name].values)
+ self._data["cluster_ids"] = np.array(df["cluster_id"].values)
+ except KeyError:
+ continue
+ else:
+ break
else:
raise FileNotFoundError(
'Neither "cluster_groups" nor "cluster_KSLabel" file found!'
)
- if cluster_file_suffix == ".tsv":
- df = pd.read_csv(cluster_file, sep="\t", header=0)
- elif cluster_file_suffix == ".xlsx":
- df = pd.read_excel(cluster_file, engine="openpyxl")
- else:
- df = pd.read_csv(cluster_file, delimiter="\t")
-
- self._data["cluster_groups"] = np.array(df[cluster_col_name].values)
- self._data["cluster_ids"] = np.array(df["cluster_id"].values)
-
def get_best_channel(self, unit):
template_idx = self.data["spike_templates"][
np.where(self.data["spike_clusters"] == unit)[0][0]
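The restructured loader relies on Python's `for`/`else` and `try`/`except`/`else` semantics: each `(glob pattern, column name)` pair is tried in turn, a pair is skipped if either the file or the column is missing, and the trailing `else` on the `for` loop fires only when no pair succeeded. A stripped-down sketch of the control flow (`kilosort_dir` is a hypothetical output directory):

```python
import pathlib

import pandas as pd

kilosort_dir = pathlib.Path("kilosort_output")  # hypothetical

# `cluster_group.*` is retried with the KSLabel column because some Kilosort
# versions write the labels under that name.
candidates = [
    ("cluster_group.*", "group"),
    ("cluster_KSLabel.*", "KSLabel"),
    ("cluster_group.*", "KSLabel"),
]

for pattern, col_name in candidates:
    try:
        cluster_file = next(kilosort_dir.glob(pattern))
    except StopIteration:
        continue  # no file matching this pattern; try the next pair
    df = pd.read_csv(cluster_file, sep="\t", header=0)
    try:
        cluster_groups = df[col_name].values
    except KeyError:
        continue  # file exists but lacks this column; try the next pair
    else:
        break  # success: stop searching
else:  # reached only if the loop never hit `break`
    raise FileNotFoundError('Neither "cluster_groups" nor "cluster_KSLabel" file found!')
```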
diff --git a/element_array_ephys/readers/probe_geometry.py b/element_array_ephys/readers/probe_geometry.py
index b6fbc09e..f0d50a1c 100644
--- a/element_array_ephys/readers/probe_geometry.py
+++ b/element_array_ephys/readers/probe_geometry.py
@@ -140,8 +140,8 @@ def build_npx_probe(
return elec_pos_df
-def to_probeinterface(electrodes_df):
- from probeinterface import Probe
+def to_probeinterface(electrodes_df, **kwargs):
+ import probeinterface as pi
probe_df = electrodes_df.copy()
probe_df.rename(
@@ -153,10 +153,22 @@ def to_probeinterface(electrodes_df):
},
inplace=True,
)
- probe_df["contact_shapes"] = "square"
- probe_df["width"] = 12
-
- return Probe.from_dataframe(probe_df)
+ # Get the contact shape. By default, circular contacts with a radius of 10 are used.
+ contact_shapes = kwargs.get("contact_shapes", "circle")
+ assert (
+ contact_shapes in pi.probe._possible_contact_shapes
+ ), f"contacts shape should be in {pi.probe._possible_contact_shapes}"
+
+ probe_df["contact_shapes"] = contact_shapes
+ if contact_shapes == "circle":
+ probe_df["radius"] = kwargs.get("radius", 10)
+ elif contact_shapes == "square":
+ probe_df["width"] = kwargs.get("width", 10)
+ elif contact_shapes == "rect":
+ probe_df["width"] = kwargs.get("width")
+ probe_df["height"] = kwargs.get("height")
+
+ return pi.Probe.from_dataframe(probe_df)
def build_electrode_layouts(
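The new `**kwargs` make the contact geometry configurable instead of hard-coding square 12 um contacts. A hedged usage sketch (the column names mirror the dataframe assembled in `PreProcessing.make` below; the values are hypothetical):

```python
import pandas as pd

from element_array_ephys.readers import probe_geometry

# Hypothetical two-contact layout; real dataframes come from the probe schema.
electrodes_df = pd.DataFrame(
    {"electrode": [0, 1], "x_coord": [0.0, 32.0], "y_coord": [0.0, 0.0], "shank": [0, 0]}
)

# Default: circular contacts with radius 10
si_probe = probe_geometry.to_probeinterface(electrodes_df)

# Reproduce the previous hard-coded behavior: square contacts, width 12
si_probe_sq = probe_geometry.to_probeinterface(
    electrodes_df, contact_shapes="square", width=12
)
```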
diff --git a/element_array_ephys/spike_sorting/__init__.py b/element_array_ephys/spike_sorting/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/element_array_ephys/readers/kilosort_triggering.py b/element_array_ephys/spike_sorting/kilosort_triggering.py
similarity index 100%
rename from element_array_ephys/readers/kilosort_triggering.py
rename to element_array_ephys/spike_sorting/kilosort_triggering.py
diff --git a/element_array_ephys/spike_sorting/si_preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py
new file mode 100644
index 00000000..22adbdca
--- /dev/null
+++ b/element_array_ephys/spike_sorting/si_preprocessing.py
@@ -0,0 +1,36 @@
+import spikeinterface as si
+from spikeinterface import preprocessing
+
+
+def CatGT(recording):
+ recording = si.preprocessing.phase_shift(recording)
+ recording = si.preprocessing.common_reference(
+ recording, operator="median", reference="global"
+ )
+ return recording
+
+
+def IBLdestriping(recording):
+ # From International Brain Laboratory, "Spike sorting pipeline for the International Brain Laboratory" (published 4 May 2022, accessed 9 Jun 2022)
+ recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
+ bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
+ # For IBL destriping, interpolate the detected bad channels
+ recording = si.preprocessing.interpolate_bad_channels(recording, bad_channel_ids)
+ recording = si.preprocessing.phase_shift(recording)
+ # For IBL destriping, highpass_spatial_filter is used instead of common_reference
+ # (highpass_spatial_filter does not accept the common_reference operator/reference kwargs)
+ recording = si.preprocessing.highpass_spatial_filter(recording)
+ return recording
+
+
+def IBLdestriping_modified(recording):
+ # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html)
+ recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
+ bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
+ # For IBL destriping interpolate bad channels
+ recording = recording.remove_channels(bad_channel_ids)
+ recording = si.preprocessing.phase_shift(recording)
+ recording = si.preprocessing.common_reference(
+ recording, operator="median", reference="global"
+ )
+ return recording
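These preprocessing functions are selected by name at runtime: `PreProcessing.make` below reads `SI_PREPROCESSING_METHOD` from the clustering parameter set and resolves it with `getattr` against this module, so adding a new recipe only requires defining another function here. A sketch of the dispatch:

```python
from element_array_ephys.spike_sorting import si_preprocessing

# One of the function names defined above: CatGT, IBLdestriping, IBLdestriping_modified
params = {"SI_PREPROCESSING_METHOD": "CatGT"}

si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"])
# si_recording = si_preproc_func(si_recording)  # spikeinterface applies the chain lazily
```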
diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
new file mode 100644
index 00000000..e3e797b6
--- /dev/null
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -0,0 +1,415 @@
+"""
+The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the "spikeinterface" pipeline.
+Spikeinterface was developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface)
+If you use this pipeline, please cite SpikeInterface and the relevant sorter(s) used in your publication (see https://github.com/SpikeInterface for additional citation details).
+"""
+
+from datetime import datetime
+
+import datajoint as dj
+import pandas as pd
+import spikeinterface as si
+from element_array_ephys import probe, ephys, readers
+from element_interface.utils import find_full_path, memoized_result
+from spikeinterface import exporters, extractors, sorters
+
+from . import si_preprocessing
+
+log = dj.logger
+
+schema = dj.schema()
+
+
+def activate(
+ schema_name,
+ *,
+ create_schema=True,
+ create_tables=True,
+):
+ """Activate the current schema.
+
+ Args:
+ schema_name (str): schema name on the database server to activate the `si_spike_sorting` schema.
+ create_schema (bool, optional): If True (default), create schema in the database if it does not yet exist.
+ create_tables (bool, optional): If True (default), create tables in the database if they do not yet exist.
+ """
+ if not probe.schema.is_activated():
+ raise RuntimeError("Please activate the `probe` schema first.")
+ if not ephys.schema.is_activated():
+ raise RuntimeError("Please activate the `ephys` schema first.")
+
+ schema.activate(
+ schema_name,
+ create_schema=create_schema,
+ create_tables=create_tables,
+ add_objects=ephys.__dict__,
+ )
+ ephys.Clustering.key_source -= PreProcessing.key_source.proj()
+
+
+SI_SORTERS = [s.replace("_", ".") for s in si.sorters.sorter_dict.keys()]
+
+
+@schema
+class PreProcessing(dj.Imported):
+ """A table to handle preprocessing of each clustering task. The output will be serialized and stored as a si_recording.pkl in the output directory."""
+
+ definition = """
+ -> ephys.ClusteringTask
+ ---
+ execution_time: datetime # datetime of the start of this step
+ execution_duration: float # execution duration in hours
+ """
+
+ @property
+ def key_source(self):
+ return (
+ ephys.ClusteringTask * ephys.ClusteringParamSet
+ & {"task_mode": "trigger"}
+ & f"clustering_method in {tuple(SI_SORTERS)}"
+ ) - ephys.Clustering
+
+ def make(self, key):
+ """Triggers or imports clustering analysis."""
+ execution_time = datetime.utcnow()
+
+ # Set the output directory
+ clustering_method, acq_software, output_dir, params = (
+ ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key
+ ).fetch1("clustering_method", "acq_software", "clustering_output_dir", "params")
+
+ # Get sorter method and create output directory.
+ sorter_name = clustering_method.replace(".", "_")
+
+ for required_key in (
+ "SI_PREPROCESSING_METHOD",
+ "SI_SORTING_PARAMS",
+ "SI_POSTPROCESSING_PARAMS",
+ ):
+ if required_key not in params:
+ raise ValueError(
+ f"{required_key} must be defined in ClusteringParamSet for SpikeInterface execution"
+ )
+
+ # Set directory to store recording file.
+ if not output_dir:
+ output_dir = ephys.ClusteringTask.infer_output_dir(
+ key, relative=True, mkdir=True
+ )
+ # update clustering_output_dir
+ ephys.ClusteringTask.update1(
+ {**key, "clustering_output_dir": output_dir.as_posix()}
+ )
+ output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+ recording_dir = output_dir / sorter_name / "recording"
+ recording_dir.mkdir(parents=True, exist_ok=True)
+ recording_file = recording_dir / "si_recording.pkl"
+
+ # Create SI recording extractor object
+ if acq_software == "SpikeGLX":
+ spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key)
+ spikeglx_recording = readers.spikeglx.SpikeGLX(
+ spikeglx_meta_filepath.parent
+ )
+ spikeglx_recording.validate_file("ap")
+ data_dir = spikeglx_meta_filepath.parent
+
+ si_extractor = (
+ si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor
+ )
+ stream_names, stream_ids = si.extractors.get_neo_streams(
+ "spikeglx", folder_path=data_dir
+ )
+ si_recording: si.BaseRecording = si_extractor(
+ folder_path=data_dir, stream_name=stream_names[0]
+ )
+ elif acq_software == "Open Ephys":
+ oe_probe = ephys.get_openephys_probe_data(key)
+ assert len(oe_probe.recording_info["recording_files"]) == 1
+ data_dir = oe_probe.recording_info["recording_files"][0]
+ si_extractor = (
+ si.extractors.neoextractors.openephys.OpenEphysBinaryRecordingExtractor
+ )
+
+ stream_names, stream_ids = si.extractors.get_neo_streams(
+ "openephysbinary", folder_path=data_dir
+ )
+ si_recording: si.BaseRecording = si_extractor(
+ folder_path=data_dir, stream_name=stream_names[0]
+ )
+ else:
+ raise NotImplementedError(
+ f"SpikeInterface processing for {acq_software} not yet implemented."
+ )
+
+ # Add probe information to recording object
+ electrodes_df = (
+ (
+ ephys.EphysRecording.Channel
+ * probe.ElectrodeConfig.Electrode
+ * probe.ProbeType.Electrode
+ & key
+ )
+ .fetch(format="frame")
+ .reset_index()
+ )
+
+ # Create SI probe object
+ si_probe = readers.probe_geometry.to_probeinterface(
+ electrodes_df[["electrode", "x_coord", "y_coord", "shank"]]
+ )
+ si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values)
+ si_recording.set_probe(probe=si_probe, in_place=True)
+
+ # Run preprocessing and save results to output folder
+ si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"])
+ si_recording = si_preproc_func(si_recording)
+ si_recording.dump_to_pickle(file_path=recording_file, relative_to=output_dir)
+
+ self.insert1(
+ {
+ **key,
+ "execution_time": execution_time,
+ "execution_duration": (
+ datetime.utcnow() - execution_time
+ ).total_seconds()
+ / 3600,
+ }
+ )
+
+
+@schema
+class SIClustering(dj.Imported):
+ """A processing table to handle each clustering task."""
+
+ definition = """
+ -> PreProcessing
+ ---
+ execution_time: datetime # datetime of the start of this step
+ execution_duration: float # execution duration in hours
+ """
+
+ def make(self, key):
+ execution_time = datetime.utcnow()
+
+ # Load recording object.
+ clustering_method, output_dir, params = (
+ ephys.ClusteringTask * ephys.ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir", "params")
+ output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+ sorter_name = clustering_method.replace(".", "_")
+ recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
+ si_recording: si.BaseRecording = si.load_extractor(
+ recording_file, base_folder=output_dir
+ )
+
+ sorting_params = params["SI_SORTING_PARAMS"]
+ sorting_output_dir = output_dir / sorter_name / "spike_sorting"
+
+ # Run sorting
+ @memoized_result(
+ uniqueness_dict=sorting_params,
+ output_directory=sorting_output_dir,
+ )
+ def _run_sorter():
+ # Sorting runs in a dedicated Docker environment if the sorter is not installed locally.
+ si_sorting: si.BaseSorting = si.sorters.run_sorter(
+ sorter_name=sorter_name,
+ recording=si_recording,
+ folder=sorting_output_dir,
+ remove_existing_folder=True,
+ verbose=True,
+ docker_image=sorter_name not in si.sorters.installed_sorters(),
+ **sorting_params,
+ )
+
+ # Save sorting object
+ sorting_save_path = sorting_output_dir / "si_sorting.pkl"
+ si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
+
+ _run_sorter()
+
+ self.insert1(
+ {
+ **key,
+ "execution_time": execution_time,
+ "execution_duration": (
+ datetime.utcnow() - execution_time
+ ).total_seconds()
+ / 3600,
+ }
+ )
+
+
+@schema
+class PostProcessing(dj.Imported):
+ """A processing table to handle each clustering task."""
+
+ definition = """
+ -> SIClustering
+ ---
+ execution_time: datetime # datetime of the start of this step
+ execution_duration: float # execution duration in hours
+ do_si_export=0: bool # whether to export to phy and/or generate the SpikeInterface report
+ """
+
+ def make(self, key):
+ execution_time = datetime.utcnow()
+
+ # Load recording & sorting object.
+ clustering_method, output_dir, params = (
+ ephys.ClusteringTask * ephys.ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir", "params")
+ output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+ sorter_name = clustering_method.replace(".", "_")
+
+ recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
+ sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
+
+ si_recording: si.BaseRecording = si.load_extractor(
+ recording_file, base_folder=output_dir
+ )
+ si_sorting: si.BaseSorting = si.load_extractor(
+ sorting_file, base_folder=output_dir
+ )
+
+ postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]
+
+ job_kwargs = postprocessing_params.get(
+ "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
+ )
+
+ analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
+
+ has_units = si_sorting.unit_ids.size > 0
+
+ @memoized_result(
+ uniqueness_dict=postprocessing_params,
+ output_directory=analyzer_output_dir,
+ )
+ def _sorting_analyzer_compute():
+ if not has_units:
+ log.info("No units found in sorting object. Skipping sorting analyzer.")
+ analyzer_output_dir.mkdir(
+ parents=True, exist_ok=True
+ ) # create empty directory anyway, for consistency
+ return
+
+ # Sorting Analyzer
+ sorting_analyzer = si.create_sorting_analyzer(
+ sorting=si_sorting,
+ recording=si_recording,
+ format="binary_folder",
+ folder=analyzer_output_dir,
+ sparse=True,
+ overwrite=True,
+ )
+
+ # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
+ # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
+ extensions_params = postprocessing_params.get("extensions", {})
+ extensions_to_compute = {
+ ext_name: extensions_params[ext_name]
+ for ext_name in sorting_analyzer.get_computable_extensions()
+ if ext_name in extensions_params
+ }
+
+ sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
+
+ _sorting_analyzer_compute()
+
+ do_si_export = postprocessing_params.get(
+ "export_to_phy", False
+ ) or postprocessing_params.get("export_report", False)
+
+ self.insert1(
+ {
+ **key,
+ "execution_time": execution_time,
+ "execution_duration": (
+ datetime.utcnow() - execution_time
+ ).total_seconds()
+ / 3600,
+ "do_si_export": do_si_export and has_units,
+ }
+ )
+
+ # Once finished, insert this `key` into ephys.Clustering
+ ephys.Clustering.insert1(
+ {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True
+ )
+
+
+@schema
+class SIExport(dj.Computed):
+ """A SpikeInterface export report and to Phy"""
+
+ definition = """
+ -> PostProcessing
+ ---
+ execution_time: datetime
+ execution_duration: float
+ """
+
+ @property
+ def key_source(self):
+ return PostProcessing & "do_si_export = 1"
+
+ def make(self, key):
+ execution_time = datetime.utcnow()
+
+ clustering_method, output_dir, params = (
+ ephys.ClusteringTask * ephys.ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir", "params")
+ output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+ sorter_name = clustering_method.replace(".", "_")
+
+ postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]
+
+ job_kwargs = postprocessing_params.get(
+ "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
+ )
+
+ analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
+ sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir)
+
+ @memoized_result(
+ uniqueness_dict=postprocessing_params,
+ output_directory=analyzer_output_dir / "phy",
+ )
+ def _export_to_phy():
+ # Save to phy format
+ si.exporters.export_to_phy(
+ sorting_analyzer=sorting_analyzer,
+ output_folder=analyzer_output_dir / "phy",
+ use_relative_path=True,
+ **job_kwargs,
+ )
+
+ @memoized_result(
+ uniqueness_dict=postprocessing_params,
+ output_directory=analyzer_output_dir / "spikeinterface_report",
+ )
+ def _export_report():
+ # Generate spike interface report
+ si.exporters.export_report(
+ sorting_analyzer=sorting_analyzer,
+ output_folder=analyzer_output_dir / "spikeinterface_report",
+ **job_kwargs,
+ )
+
+ if postprocessing_params.get("export_report", False):
+ _export_report()
+ if postprocessing_params.get("export_to_phy", False):
+ _export_to_phy()
+
+ self.insert1(
+ {
+ **key,
+ "execution_time": execution_time,
+ "execution_duration": (
+ datetime.utcnow() - execution_time
+ ).total_seconds()
+ / 3600,
+ }
+ )
diff --git a/element_array_ephys/version.py b/element_array_ephys/version.py
index 39ba565b..2e6de55a 100644
--- a/element_array_ephys/version.py
+++ b/element_array_ephys/version.py
@@ -1,3 +1,3 @@
"""Package metadata."""
-__version__ = "0.3.8"
+__version__ = "0.4.0"
diff --git a/env.yml b/env.yml
new file mode 100644
index 00000000..e9b3ce13
--- /dev/null
+++ b/env.yml
@@ -0,0 +1,7 @@
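+# Minimal conda environment for this Element; a typical setup would be
+# `conda env create -f env.yml` followed by `pip install -e .` in the repo root.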
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - pip
+ - python>=3.7,<3.11
+name: element_array_ephys
diff --git a/images/attached_array_ephys_element_no_curation.svg b/images/attached_array_ephys_element.svg
similarity index 100%
rename from images/attached_array_ephys_element_no_curation.svg
rename to images/attached_array_ephys_element.svg
diff --git a/images/attached_array_ephys_element_acute.svg b/images/attached_array_ephys_element_acute.svg
deleted file mode 100644
index 5b2bc265..00000000
--- a/images/attached_array_ephys_element_acute.svg
+++ /dev/null
@@ -1,451 +0,0 @@
-
\ No newline at end of file
diff --git a/images/attached_array_ephys_element_chronic.svg b/images/attached_array_ephys_element_chronic.svg
deleted file mode 100644
index 808a2f17..00000000
--- a/images/attached_array_ephys_element_chronic.svg
+++ /dev/null
@@ -1,456 +0,0 @@
-
\ No newline at end of file
diff --git a/images/attached_array_ephys_element_precluster.svg b/images/attached_array_ephys_element_precluster.svg
deleted file mode 100644
index 7d854d2e..00000000
--- a/images/attached_array_ephys_element_precluster.svg
+++ /dev/null
@@ -1,535 +0,0 @@
-
\ No newline at end of file
diff --git a/notebooks/demo_prepare.ipynb b/notebooks/demo_prepare.ipynb
deleted file mode 100644
index 74057ba4..00000000
--- a/notebooks/demo_prepare.ipynb
+++ /dev/null
@@ -1,225 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Demo Preparation Notebook\n",
- "\n",
- "**Please Note**: This notebook (`demo_prepare.ipynb`) and `demo_run.ipynb` are **NOT** intended to be used as learning materials. To gain\n",
- "a thorough understanding of the DataJoint workflow for extracellular electrophysiology, please\n",
- "see the [`tutorial`](./tutorial.ipynb) notebook."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Runs in about 45s\n",
- "import datajoint as dj\n",
- "import datetime\n",
- "from tutorial_pipeline import subject, session, probe, ephys\n",
- "from element_array_ephys import ephys_report"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "subject.Subject.insert1(\n",
- " dict(subject=\"subject5\", subject_birth_date=\"2023-01-01\", sex=\"U\")\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "session_key = dict(subject=\"subject5\", session_datetime=\"2023-01-01 00:00:00\")\n",
- "\n",
- "session.Session.insert1(session_key)\n",
- "\n",
- "session.SessionDirectory.insert1(dict(session_key, session_dir=\"raw/subject5/session1\"))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "probe.Probe.insert1(dict(probe=\"714000838\", probe_type=\"neuropixels 1.0 - 3B\"))\n",
- "\n",
- "ephys.ProbeInsertion.insert1(\n",
- " dict(\n",
- " session_key,\n",
- " insertion_number=1,\n",
- " probe=\"714000838\",\n",
- " )\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "populate_settings = {\"display_progress\": True}\n",
- "\n",
- "ephys.EphysRecording.populate(**populate_settings)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "kilosort_params = {\n",
- " \"fs\": 30000,\n",
- " \"fshigh\": 150,\n",
- " \"minfr_goodchannels\": 0.1,\n",
- " \"Th\": [10, 4],\n",
- " \"lam\": 10,\n",
- " \"AUCsplit\": 0.9,\n",
- " \"minFR\": 0.02,\n",
- " \"momentum\": [20, 400],\n",
- " \"sigmaMask\": 30,\n",
- " \"ThPr\": 8,\n",
- " \"spkTh\": -6,\n",
- " \"reorder\": 1,\n",
- " \"nskip\": 25,\n",
- " \"GPU\": 1,\n",
- " \"Nfilt\": 1024,\n",
- " \"nfilt_factor\": 4,\n",
- " \"ntbuff\": 64,\n",
- " \"whiteningRange\": 32,\n",
- " \"nSkipCov\": 25,\n",
- " \"scaleproc\": 200,\n",
- " \"nPCs\": 3,\n",
- " \"useRAM\": 0,\n",
- "}\n",
- "\n",
- "ephys.ClusteringParamSet.insert_new_params(\n",
- " clustering_method=\"kilosort2\",\n",
- " paramset_idx=1,\n",
- " params=kilosort_params,\n",
- " paramset_desc=\"Spike sorting using Kilosort2\",\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "ephys.ClusteringTask.insert1(\n",
- " dict(\n",
- " session_key,\n",
- " insertion_number=1,\n",
- " paramset_idx=1,\n",
- " task_mode=\"load\", # load or trigger\n",
- " clustering_output_dir=\"processed/subject5/session1/probe_1/kilosort2-5_1\",\n",
- " )\n",
- ")\n",
- "\n",
- "ephys.Clustering.populate(**populate_settings)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "clustering_key = (ephys.ClusteringTask & session_key).fetch1(\"KEY\")\n",
- "ephys.Curation().create1_from_clustering_task(clustering_key)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Runs in about 12m\n",
- "ephys.CuratedClustering.populate(**populate_settings)\n",
- "ephys.WaveformSet.populate(**populate_settings)\n",
- "ephys_report.ProbeLevelReport.populate(**populate_settings)\n",
- "ephys_report.UnitLevelReport.populate(**populate_settings)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Drop schemas\n",
- "- Schemas are not typically dropped in a production workflow with real data in it.\n",
- "- At the developmental phase, it might be required for the table redesign.\n",
- "- When dropping all schemas is needed, the following is the dependency order."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def drop_databases(databases):\n",
- " import pymysql.err\n",
- "\n",
- " conn = dj.conn()\n",
- "\n",
- " with dj.config(safemode=False):\n",
- " for database in databases:\n",
- " schema = dj.Schema(f'{dj.config[\"custom\"][\"database.prefix\"]}{database}')\n",
- " while schema.list_tables():\n",
- " for table in schema.list_tables():\n",
- " try:\n",
- " conn.query(f\"DROP TABLE `{schema.database}`.`{table}`\")\n",
- " except pymysql.err.OperationalError:\n",
- " print(f\"Can't drop `{schema.database}`.`{table}`. Retrying...\")\n",
- " schema.drop()\n",
- "\n",
- "\n",
- "# drop_databases(databases=['analysis', 'trial', 'event', 'ephys_report', 'ephys', 'probe', 'session', 'subject', 'project', 'lab'])"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.17"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/notebooks/demo_run.ipynb b/notebooks/demo_run.ipynb
deleted file mode 100644
index 348a3c43..00000000
--- a/notebooks/demo_run.ipynb
+++ /dev/null
@@ -1,108 +0,0 @@
-{
- "cells": [
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# DataJoint Workflow for Neuropixels Analysis\n",
- "\n",
- "+ This notebook demonstrates using the open-source DataJoint Element to build a workflow for extracellular electrophysiology.\n",
- "+ For a detailed tutorial, please see the [tutorial notebook](./tutorial.ipynb)."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Import dependencies"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import datajoint as dj\n",
- "from tutorial_pipeline import subject, session, probe, ephys\n",
- "from element_array_ephys.plotting.widget import main"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### View workflow"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "(\n",
- " dj.Diagram(subject.Subject)\n",
- " + dj.Diagram(session.Session)\n",
- " + dj.Diagram(probe)\n",
- " + dj.Diagram(ephys)\n",
- ")"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Visualize processed data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "main(ephys)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For an in-depth tutorial please see the [tutorial notebook](./tutorial.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "python3p10",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.17"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "ff52d424e56dd643d8b2ec122f40a2e279e94970100b4e6430cb9025a65ba4cf"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/setup.py b/setup.py
index 19a4a5ae..38b08c29 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
from os import path
-from setuptools import find_packages, setup
+from setuptools import find_packages, setup
pkg_name = "element_array_ephys"
here = path.abspath(path.dirname(__file__))
@@ -16,6 +16,7 @@
setup(
name=pkg_name.replace("_", "-"),
+ python_requires=">=3.7, <3.11",
version=__version__, # noqa F821
description="Extracellular Array Electrophysiology DataJoint Element",
long_description=long_description,
@@ -34,20 +35,22 @@
"openpyxl",
"plotly",
"seaborn",
- "spikeinterface",
+ "spikeinterface @ git+https://github.com/SpikeInterface/spikeinterface.git",
"scikit-image>=0.20",
"nbformat>=4.2.0",
"pyopenephys>=1.1.6",
+ "element-interface @ git+https://github.com/datajoint/element-interface.git",
+ "numba",
],
extras_require={
"elements": [
"element-animal @ git+https://github.com/datajoint/element-animal.git",
"element-event @ git+https://github.com/datajoint/element-event.git",
- "element-interface @ git+https://github.com/datajoint/element-interface.git",
"element-lab @ git+https://github.com/datajoint/element-lab.git",
"element-session @ git+https://github.com/datajoint/element-session.git",
],
"nwb": ["dandi", "neuroconv[ecephys]", "pynwb"],
"tests": ["pre-commit", "pytest", "pytest-cov"],
+ "spikingcircus": ["hdbscan"],
},
)
diff --git a/tests/tutorial_pipeline.py b/tests/tutorial_pipeline.py
index 74b27ddc..1b27027d 100644
--- a/tests/tutorial_pipeline.py
+++ b/tests/tutorial_pipeline.py
@@ -3,7 +3,7 @@
import datajoint as dj
from element_animal import subject
from element_animal.subject import Subject
-from element_array_ephys import probe, ephys_no_curation as ephys, ephys_report
+from element_array_ephys import probe, ephys, ephys_report
from element_lab import lab
from element_lab.lab import Lab, Location, Project, Protocol, Source, User
from element_lab.lab import Device as Equipment
@@ -62,7 +62,9 @@ def get_session_directory(session_key):
return pathlib.Path(session_directory)
-ephys.activate(db_prefix + "ephys", db_prefix + "probe", linking_module=__name__)
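+# With the simplified activation mechanism, each schema is activated explicitly,
+# in dependency order: probe before ephys, ephys before ephys_report.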
+probe.activate(db_prefix + "probe")
+ephys.activate(db_prefix + "ephys", linking_module=__name__)
+ephys_report.activate(db_prefix + "ephys_report")
probe.create_neuropixels_probe_types()