Merge branch 'master' of https://github.com/LorenFrankLab/spyglass into dst
CBroz1 committed Sep 4, 2024
2 parents 869bcd8 + d4dbc23 commit 10b96de
Showing 11 changed files with 270 additions and 101 deletions.
30 changes: 17 additions & 13 deletions CHANGELOG.md
@@ -1,20 +1,22 @@
 # Change Log
 
-## [0.5.3] (Unreleased)
+## [0.5.4] (Unreleased)
 
-## Release Notes
+### Release Notes
 
 <!-- Running draft to be removed immediately prior to release. -->
 
-```python
-import datajoint as dj
-from spyglass.common.common_behav import PositionIntervalMap
-from spyglass.decoding.v1.core import PositionGroup
-
-dj.schema("common_ripple").drop()
-PositionIntervalMap.alter()
-PositionGroup.alter()
-```
+### Infrastructure
+
+- Disable populate transaction protection for long-populating tables #1066
+- Add docstrings to all public methods #1076
+
+### Pipelines
+
+- Decoding
+  - Fix edge case errors in spike time loading #1083
+
+## [0.5.3] (August 27, 2024)

### Infrastructure

@@ -39,14 +41,15 @@
 - Allow `ModuleNotFoundError` or `ImportError` for optional dependencies #1023
 - Ensure integrity of group tables #1026
 - Convert list of LFP artifact removed interval list to array #1046
-- Merge duplicate functions in decoding and spikesorting #1050, #1053, #1058,
+- Merge duplicate functions in decoding and spikesorting #1050, #1053, #1062,
   #1066, #1069
-- Revise docs organization.
+- Reivise docs organization.
   - Misc -> Features/ForDevelopers. #1029
   - Installation instructions -> Setup notebook. #1029
 - Migrate SQL export tools to `utils` to support exporting `DandiPath` #1048
 - Add tool for checking threads for metadata locks on a table #1063
-- Add docstrings to all public methods #1076
+- Use peripheral tables as fallback in `TableChains` #1035
+- Ignore non-Spyglass tables during descendant check for `part_masters` #1035

### Pipelines

@@ -333,3 +336,4 @@
 [0.5.1]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.1
 [0.5.2]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.2
 [0.5.3]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.3
+[0.5.4]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.4
30 changes: 30 additions & 0 deletions docs/src/Features/Mixin.md
@@ -188,3 +188,33 @@
(nwbfile & "nwb_file_name LIKE 'Name%'").ddp(dry_run=False)
(nwbfile & "nwb_file_name LIKE 'Other%'").ddp(dry_run=False)
```

## Populate Calls

The mixin also overrides the default `populate` function to support non-daemon
process pools and to allow disabling transaction protection.

### Non-Daemon Process Pools

To allow the `make` function to spawn a new process pool, the mixin overrides
the default `populate` function for tables with `_parallel_make` set to `True`.
See [issue #1000](https://github.com/LorenFrankLab/spyglass/issues/1000) and
[PR #1001](https://github.com/LorenFrankLab/spyglass/pull/1001) for more
information.
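
As a minimal sketch (assuming the usual `SpyglassMixin` import; the schema,
table, and upstream `ExampleSelection` names are hypothetical), a table opts in
by setting the class attribute:

```python
import datajoint as dj

from spyglass.utils import SpyglassMixin

schema = dj.schema("example_schema")  # hypothetical schema name


@schema
class ExampleParallel(SpyglassMixin, dj.Computed):
    definition = """
    -> ExampleSelection
    ---
    result: longblob
    """

    # Opt in: populate() uses a non-daemon pool, so make() may spawn
    # its own worker processes.
    _parallel_make = True

    def make(self, key):
        ...  # may itself create a process pool
```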

### Disable Transaction Protection

By default, DataJoint wraps the `populate` function in a transaction to ensure
data integrity (see
[Transactions](https://docs.datajoint.io/python/definition/05-Transactions.html)).

This can cause issues when populating large tables if another user attempts to
declare/modify a table while the transaction is open (see
[issue #1030](https://github.com/LorenFrankLab/spyglass/issues/1030) and
[DataJoint issue #1170](https://github.com/datajoint/datajoint-python/issues/1170)).

Tables with `_use_transaction` set to `False` are not wrapped in a transaction
when calling `populate`. Instead, transaction protection is replaced by a hash
of upstream data, which is checked to ensure nothing changed during the
unprotected populate. Hashing adds some time, but for these already
long-running populates it is a worthwhile trade-off that avoids blocking other
users.
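
A similar hedged sketch for opting out of transaction protection, mirroring the
`_use_transaction, _allow_insert = False, True` attributes added to several
tables in this commit (again with hypothetical names):

```python
@schema
class ExampleLongPopulate(SpyglassMixin, dj.Computed):
    definition = """
    -> ExampleSelection
    ---
    output: longblob
    """

    # Skip DataJoint's transaction wrapper; per the description above, the
    # mixin instead hashes upstream data to detect changes mid-populate.
    _use_transaction, _allow_insert = False, True
```

Calling `ExampleLongPopulate().populate()` then runs `make` without holding a
transaction open, trading hash-computation time for reduced locking.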
11 changes: 8 additions & 3 deletions src/spyglass/decoding/v1/waveform_features.py
@@ -154,9 +154,12 @@ def make(self, key):
             sorter,
         )
 
-        spike_times = SpikeSortingOutput().fetch_nwb(merge_key)[0][
-            analysis_nwb_key
-        ]["spike_times"]
+        nwb = SpikeSortingOutput().fetch_nwb(merge_key)[0]
+        spike_times = (
+            nwb[analysis_nwb_key]["spike_times"]
+            if analysis_nwb_key in nwb
+            else pd.DataFrame()
+        )
 
         (
             key["analysis_file_name"],
@@ -351,6 +354,8 @@ def _write_waveform_features_to_nwb(
             metric_dict[unit_id] if unit_id in metric_dict else []
             for unit_id in unit_ids
         ]
+        if not metric_values:
+            metric_values = np.array([]).astype(np.float32)
         nwbf.add_unit_column(
             name=metric,
             description=metric,
2 changes: 2 additions & 0 deletions src/spyglass/position/v1/position_dlc_training.py
@@ -104,7 +104,9 @@ class DLCModelTraining(SpyglassMixin, dj.Computed):
     latest_snapshot: int unsigned # latest exact snapshot index (i.e., never -1)
     config_template: longblob # stored full config file
     """
 
     log_path = None
+    _use_transaction, _allow_insert = False, True
+
     # To continue from previous training snapshot,
     # devs suggest editing pose_cfg.yml
6 changes: 3 additions & 3 deletions src/spyglass/spikesorting/spikesorting_merge.py
@@ -4,9 +4,9 @@
 from ripple_detection import get_multiunit_population_firing_rate
 
 from spyglass.spikesorting.imported import ImportedSpikeSorting  # noqa: F401
-from spyglass.spikesorting.v0.spikesorting_curation import (
+from spyglass.spikesorting.v0.spikesorting_curation import (  # noqa: F401
     CuratedSpikeSorting,
-)  # noqa: F401
+)
 from spyglass.spikesorting.v1 import ArtifactDetectionSelection  # noqa: F401
 from spyglass.spikesorting.v1 import (
     CurationV1,
@@ -211,7 +211,7 @@ def get_spike_indicator(cls, key, time):
         """
         time = np.asarray(time)
         min_time, max_time = time[[0, -1]]
-        spike_times = cls.fetch_spike_data(key)  # CB: This is undefined.
+        spike_times = (cls & key).get_spike_times(key)
         spike_indicator = np.zeros((len(time), len(spike_times)))
 
         for ind, times in enumerate(spike_times):
2 changes: 2 additions & 0 deletions src/spyglass/spikesorting/v1/figurl_curation.py
@@ -117,6 +117,8 @@ class FigURLCuration(SpyglassMixin, dj.Computed):
     url: varchar(1000)
     """
 
+    _use_transaction, _allow_insert = False, True
+
     def make(self, key: dict):
         """Generate a FigURL for manual curation of a spike sorting."""
         # FETCH
2 changes: 2 additions & 0 deletions src/spyglass/spikesorting/v1/metric_curation.py
@@ -207,6 +207,8 @@ class MetricCuration(SpyglassMixin, dj.Computed):
     object_id: varchar(40) # Object ID for the metrics in NWB file
     """
 
+    _use_transaction, _allow_insert = False, True
+
     def make(self, key):
         """Populate MetricCuration table.
2 changes: 2 additions & 0 deletions src/spyglass/spikesorting/v1/sorting.py
@@ -145,6 +145,8 @@ class SpikeSorting(SpyglassMixin, dj.Computed):
     time_of_sort: int # in Unix time, to the nearest second
     """
 
+    _use_transaction, _allow_insert = False, True
+
     def make(self, key: dict):
         """Runs spike sorting on the data and parameters specified by the
         SpikeSortingSelection table and inserts a new entry to SpikeSorting table.
52 changes: 43 additions & 9 deletions src/spyglass/utils/dj_graph.py
@@ -7,6 +7,7 @@
 from copy import deepcopy
 from enum import Enum
 from functools import cached_property
+from hashlib import md5 as hash_md5
 from itertools import chain as iter_chain
 from typing import Any, Dict, Iterable, List, Set, Tuple, Union

@@ -248,7 +249,7 @@ def _get_ft(self, table, with_restr=False, warn=True):
 
         return ft & restr
 
-    def _is_out(self, table, warn=True):
+    def _is_out(self, table, warn=True, keep_alias=False):
         """Check if table is outside of spyglass."""
         table = ensure_names(table)
         if self.graph.nodes.get(table):
@@ -595,6 +596,14 @@ def leaf_ft(self):
         """Get restricted FreeTables from graph leaves."""
         return [self._get_ft(table, with_restr=True) for table in self.leaves]
 
+    @property
+    def hash(self):
+        """Return hash of all visited nodes."""
+        initial = hash_md5(b"")
+        for table in self.all_ft:
+            initial.update(table.fetch())
+        return initial.hexdigest()
+
     # ------------------------------- Add Nodes -------------------------------

def add_leaf(
@@ -805,7 +814,8 @@ class TableChain(RestrGraph):
         Returns path OrderedDict of full table names in chain. If directed is
         True, uses directed graph. If False, uses undirected graph. Undirected
         excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain
-        valid joins.
+        valid joins by default. If no path is found, another search is attempted
+        with PERIPHERAL_TABLES included.
     cascade(restriction: str = None, direction: str = "up")
         Given a restriction at the beginning, return a restricted FreeTable
         object at the end of the chain. If direction is 'up', start at the child
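
To make the documented search behavior concrete, a hypothetical usage sketch
(the table names and exact constructor arguments are assumptions, not taken
from this diff):

```python
# Hypothetical usage; see the docstring above for the documented behavior.
chain = TableChain(parent=ParentTable(), child=ChildTable())

# Ordered table names in the chain; if no directed or undirected path is
# found, the search is retried with PERIPHERAL_TABLES included.
names = chain.path

# Apply a restriction at the child and cascade 'up', returning a restricted
# FreeTable at the other end of the chain.
restricted = chain.cascade(restriction="subject_id = 'rat1'", direction="up")
```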
@@ -835,8 +845,12 @@ def __init__(
         super().__init__(seed_table=seed_table, verbose=verbose)
 
         self._ignore_peripheral(except_tables=[self.parent, self.child])
+        self._ignore_outside_spy(except_tables=[self.parent, self.child])
+
         self.no_visit.update(ensure_names(banned_tables) or [])
+
+        self.no_visit.difference_update(set([self.parent, self.child]))
 
         self.searched_tables = set()
         self.found_restr = False
         self.link_type = None
@@ -872,7 +886,19 @@ def _ignore_peripheral(self, except_tables: List[str] = None):
         except_tables = ensure_names(except_tables)
         ignore_tables = set(PERIPHERAL_TABLES) - set(except_tables or [])
         self.no_visit.update(ignore_tables)
-        self.undirect_graph.remove_nodes_from(ignore_tables)
+
+    def _ignore_outside_spy(self, except_tables: List[str] = None):
+        """Ignore tables not shared on shared prefixes."""
+        except_tables = ensure_names(except_tables)
+        ignore_tables = set(  # Ignore tables not in shared modules
+            [
+                t
+                for t in self.undirect_graph.nodes
+                if t not in except_tables
+                and self._is_out(t, warn=False, keep_alias=True)
+            ]
+        )
+        self.no_visit.update(ignore_tables)

# --------------------------- Dunder Properties ---------------------------

@@ -1069,9 +1095,9 @@ def find_path(self, directed=True) -> List[str]:
             List of names in the path.
         """
         source, target = self.parent, self.child
-        search_graph = self.graph if directed else self.undirect_graph
-
-        search_graph.remove_nodes_from(self.no_visit)
+        search_graph = (  # Copy to ensure orig not modified by no_visit
+            self.graph.copy() if directed else self.undirect_graph.copy()
+        )
 
         try:
             path = shortest_path(search_graph, source, target)
@@ -1099,6 +1125,12 @@ def path(self) -> list:
             self.link_type = "directed"
         elif path := self.find_path(directed=False):
             self.link_type = "undirected"
+        else:  # Search with peripheral
+            self.no_visit.difference_update(PERIPHERAL_TABLES)
+            if path := self.find_path(directed=True):
+                self.link_type = "directed with peripheral"
+            elif path := self.find_path(directed=False):
+                self.link_type = "undirected with peripheral"
         self.searched_path = True
 
         return path
@@ -1130,9 +1162,11 @@ def cascade(
         # Cascade will stop if any restriction is empty, so set rest to None
         # This would cause issues if we want a table partway through the chain
         # but that's not a typical use case, were the start and end are desired
-        non_numeric = [t for t in self.path if not t.isnumeric()]
-        if any(self._get_restr(t) is None for t in non_numeric):
-            for table in non_numeric:
+        safe_tbls = [
+            t for t in self.path if not t.isnumeric() and not self._is_out(t)
+        ]
+        if any(self._get_restr(t) is None for t in safe_tbls):
+            for table in safe_tbls:
                 if table is not start:
                     self._set_restr(table, False, replace=True)
3 changes: 2 additions & 1 deletion src/spyglass/utils/dj_helper_fn.py
@@ -517,7 +517,8 @@ def make_file_obj_id_unique(nwb_path: str):
 def populate_pass_function(value):
     """Pass function for parallel populate.
 
-    Note: To avoid pickling errors, the table must be passed by class, NOT by instance.
+    Note: To avoid pickling errors, the table must be passed by class,
+    NOT by instance.
 
     Note: This function must be defined in the global namespace.
 
     Parameters
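
The note above matters because process-pool workers receive their arguments via
pickle, and a bound table instance carries an unpicklable database connection.
A hypothetical sketch of the pattern (the exact layout of `value` is an
assumption):

```python
from multiprocessing import get_context

# Pass the class itself, NOT MyComputedTable(); names here are hypothetical.
jobs = [(MyComputedTable, key) for key in pending_keys]

with get_context("spawn").Pool(processes=4) as pool:
    pool.map(populate_pass_function, jobs)
```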