From ecf468e2c7ff77652d6b646f7f07d52ed19e84e3 Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Tue, 27 Aug 2024 11:06:59 -0500 Subject: [PATCH 1/3] Periph table fallback on TableChain for experimenter summary (#1035) * Periph table fallback on TableChain * Update Changelog * Rely on search to remove no_visit, not id step * Include generic load_shared_schemas * Update changelog for release * Allow add custom prefix for load schemas * Fix merge error --- CHANGELOG.md | 18 +---- src/spyglass/utils/dj_graph.py | 43 +++++++++--- src/spyglass/utils/dj_mixin.py | 117 ++++++++++++++++----------------- 3 files changed, 95 insertions(+), 83 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a5d467ec1..57fd495d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,20 +1,6 @@ # Change Log -## [0.5.3] (Unreleased) - -## Release Notes - - - -```python -import datajoint as dj -from spyglass.common.common_behav import PositionIntervalMap -from spyglass.decoding.v1.core import PositionGroup - -dj.schema("common_ripple").drop() -PositionIntervalMap.alter() -PositionGroup.alter() -``` +## [0.5.3] (August 27, 2024) ### Infrastructure @@ -46,6 +32,8 @@ PositionGroup.alter() - Installation instructions -> Setup notebook. #1029 - Migrate SQL export tools to `utils` to support exporting `DandiPath` #1048 - Add tool for checking threads for metadata locks on a table #1063 +- Use peripheral tables as fallback in `TableChains` #1035 +- Ignore non-Spyglass tables during descendant check for `part_masters` #1035 ### Pipelines diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 0ab4ab477..6b3928042 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -248,7 +248,7 @@ def _get_ft(self, table, with_restr=False, warn=True): return ft & restr - def _is_out(self, table, warn=True): + def _is_out(self, table, warn=True, keep_alias=False): """Check if table is outside of spyglass.""" table = ensure_names(table) if self.graph.nodes.get(table): @@ -805,7 +805,8 @@ class TableChain(RestrGraph): Returns path OrderedDict of full table names in chain. If directed is True, uses directed graph. If False, uses undirected graph. Undirected excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain - valid joins. + valid joins by default. If no path is found, another search is attempted + with PERIPHERAL_TABLES included. cascade(restriction: str = None, direction: str = "up") Given a restriction at the beginning, return a restricted FreeTable object at the end of the chain. 
If direction is 'up', start at the child @@ -835,8 +836,12 @@ def __init__( super().__init__(seed_table=seed_table, verbose=verbose) self._ignore_peripheral(except_tables=[self.parent, self.child]) + self._ignore_outside_spy(except_tables=[self.parent, self.child]) + self.no_visit.update(ensure_names(banned_tables) or []) + self.no_visit.difference_update(set([self.parent, self.child])) + self.searched_tables = set() self.found_restr = False self.link_type = None @@ -872,7 +877,19 @@ def _ignore_peripheral(self, except_tables: List[str] = None): except_tables = ensure_names(except_tables) ignore_tables = set(PERIPHERAL_TABLES) - set(except_tables or []) self.no_visit.update(ignore_tables) - self.undirect_graph.remove_nodes_from(ignore_tables) + + def _ignore_outside_spy(self, except_tables: List[str] = None): + """Ignore tables not shared on shared prefixes.""" + except_tables = ensure_names(except_tables) + ignore_tables = set( # Ignore tables not in shared modules + [ + t + for t in self.undirect_graph.nodes + if t not in except_tables + and self._is_out(t, warn=False, keep_alias=True) + ] + ) + self.no_visit.update(ignore_tables) # --------------------------- Dunder Properties --------------------------- @@ -1066,9 +1083,9 @@ def find_path(self, directed=True) -> List[str]: List of names in the path. """ source, target = self.parent, self.child - search_graph = self.graph if directed else self.undirect_graph - - search_graph.remove_nodes_from(self.no_visit) + search_graph = ( # Copy to ensure orig not modified by no_visit + self.graph.copy() if directed else self.undirect_graph.copy() + ) try: path = shortest_path(search_graph, source, target) @@ -1096,6 +1113,12 @@ def path(self) -> list: self.link_type = "directed" elif path := self.find_path(directed=False): self.link_type = "undirected" + else: # Search with peripheral + self.no_visit.difference_update(PERIPHERAL_TABLES) + if path := self.find_path(directed=True): + self.link_type = "directed with peripheral" + elif path := self.find_path(directed=False): + self.link_type = "undirected with peripheral" self.searched_path = True return path @@ -1126,9 +1149,11 @@ def cascade( # Cascade will stop if any restriction is empty, so set rest to None # This would cause issues if we want a table partway through the chain # but that's not a typical use case, were the start and end are desired - non_numeric = [t for t in self.path if not t.isnumeric()] - if any(self._get_restr(t) is None for t in non_numeric): - for table in non_numeric: + safe_tbls = [ + t for t in self.path if not t.isnumeric() and not self._is_out(t) + ] + if any(self._get_restr(t) is None for t in safe_tbls): + for table in safe_tbls: if table is not start: self._set_restr(table, False, replace=True) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index ff3922087..04b873740 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -261,52 +261,41 @@ def fetch_pynapple(self, *attrs, **kwargs): # ------------------------ delete_downstream_parts ------------------------ - def _import_part_masters(self): - """Import tables that may constrain a RestrGraph. 
See #1002""" - from spyglass.decoding.decoding_merge import DecodingOutput # noqa F401 - from spyglass.decoding.v0.clusterless import ( - UnitMarksIndicatorSelection, - ) # noqa F401 - from spyglass.decoding.v0.sorted_spikes import ( - SortedSpikesIndicatorSelection, - ) # noqa F401 - from spyglass.decoding.v1.core import PositionGroup # noqa F401 - from spyglass.lfp.analysis.v1 import LFPBandSelection # noqa F401 - from spyglass.lfp.lfp_merge import LFPOutput # noqa F401 - from spyglass.linearization.merge import ( # noqa F401 - LinearizedPositionOutput, - LinearizedPositionV1, - ) - from spyglass.mua.v1.mua import MuaEventsV1 # noqa F401 - from spyglass.position.position_merge import PositionOutput # noqa F401 - from spyglass.ripple.v1.ripple import RippleTimesV1 # noqa F401 - from spyglass.spikesorting.analysis.v1.group import ( - SortedSpikesGroup, - ) # noqa F401 - from spyglass.spikesorting.spikesorting_merge import ( - SpikeSortingOutput, - ) # noqa F401 - from spyglass.spikesorting.v0.figurl_views import ( - SpikeSortingRecordingView, - ) # noqa F401 - - _ = ( - DecodingOutput(), - LFPBandSelection(), - LFPOutput(), - LinearizedPositionOutput(), - LinearizedPositionV1(), - MuaEventsV1(), - PositionGroup(), - PositionOutput(), - RippleTimesV1(), - SortedSpikesGroup(), - SortedSpikesIndicatorSelection(), - SpikeSortingOutput(), - SpikeSortingRecordingView(), - UnitMarksIndicatorSelection(), + def load_shared_schemas(self, additional_prefixes: list = None) -> None: + """Load shared schemas to include in graph traversal. + + Parameters + ---------- + additional_prefixes : list, optional + Additional prefixes to load. Default None. + """ + all_shared = [ + *SHARED_MODULES, + dj.config["database.user"], + "file", + "sharing", + ] + + if additional_prefixes: + all_shared.extend(additional_prefixes) + + # Get a list of all shared schemas in spyglass + schemas = dj.conn().query( + "SELECT DISTINCT table_schema " # Unique schemas + + "FROM information_schema.key_column_usage " + + "WHERE" + + ' table_name not LIKE "~%%"' # Exclude hidden + + " AND constraint_name='PRIMARY'" # Only primary keys + + "AND (" # Only shared schemas + + " OR ".join([f"table_schema LIKE '{s}_%%'" for s in all_shared]) + + ") " + + "ORDER BY table_schema;" ) + # Load the dependencies for all shared schemas + for schema in schemas: + dj.schema(schema[0]).connection.dependencies.load() + @cached_property def _part_masters(self) -> set: """Set of master tables downstream of self. 
@@ -318,23 +307,25 @@ def _part_masters(self) -> set: part_masters = set() def search_descendants(parent): - for desc in parent.descendants(as_objects=True): + for desc_name in parent.descendants(): if ( # Check if has master, is part - not (master := get_master(desc.full_table_name)) - # has other non-master parent - or not set(desc.parents()) - set([master]) + not (master := get_master(desc_name)) or master in part_masters # already in cache + or desc_name.replace("`", "").split("_")[0] + not in SHARED_MODULES ): continue - if master not in part_masters: - part_masters.add(master) - search_descendants(dj.FreeTable(self.connection, master)) + desc = dj.FreeTable(self.connection, desc_name) + if not set(desc.parents()) - set([master]): # no other parent + continue + part_masters.add(master) + search_descendants(dj.FreeTable(self.connection, master)) try: _ = search_descendants(self) except NetworkXError: - try: # Attempt to import missing table - self._import_part_masters() + try: # Attempt to import failing schema + self.load_shared_schemas() _ = search_descendants(self) except NetworkXError as e: table_name = "".join(e.args[0].split("`")[1:4]) @@ -484,7 +475,7 @@ def _delete_deps(self) -> List[Table]: self._member_pk = LabMember.primary_key[0] return [LabMember, LabTeam, Session, schema.external, IntervalList] - def _get_exp_summary(self): + def _get_exp_summary(self) -> Union[QueryExpression, None]: """Get summary of experimenters for session(s), including NULL. Parameters @@ -494,9 +485,12 @@ def _get_exp_summary(self): Returns ------- - str - Summary of experimenters for session(s). + Union[QueryExpression, None] + dj.Union object Summary of experimenters for session(s). If no link + to Session, return None. """ + if not self._session_connection.has_link: + return None Session = self._delete_deps[2] SesExp = Session.Experimenter @@ -521,8 +515,7 @@ def _session_connection(self): """Path from Session table to self. False if no connection found.""" from spyglass.utils.dj_graph import TableChain # noqa F401 - connection = TableChain(parent=self._delete_deps[2], child=self) - return connection if connection.has_link else False + return TableChain(parent=self._delete_deps[2], child=self, verbose=True) @cached_property def _test_mode(self) -> bool: @@ -564,7 +557,13 @@ def _check_delete_permission(self) -> None: ) return - sess_summary = self._get_exp_summary() + if not (sess_summary := self._get_exp_summary()): + logger.warn( + f"Could not find a connection from {self.camel_name} " + + "to Session.\n Be careful not to delete others' data." 
+ ) + return + experimenters = sess_summary.fetch(self._member_pk) if None in experimenters: raise PermissionError( From adfed75f60ffd0a770369d87c27043a7e8c306f8 Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Thu, 29 Aug 2024 11:24:11 -0500 Subject: [PATCH 2/3] Allow disable transaction for select populates (#1067) * Allow disable transaction for select populates * WIP: hash for data integrity * WIP: hash for data integrity 2 * WIP: hash for data integrity 3 * Add docs * Delete on hash mismatch * Incorporate feedback --- CHANGELOG.md | 17 ++- docs/src/Features/Mixin.md | 30 +++++ .../position/v1/position_dlc_training.py | 2 + .../spikesorting/v1/figurl_curation.py | 2 + .../spikesorting/v1/metric_curation.py | 2 + src/spyglass/spikesorting/v1/sorting.py | 2 + src/spyglass/utils/dj_graph.py | 9 ++ src/spyglass/utils/dj_helper_fn.py | 3 +- src/spyglass/utils/dj_mixin.py | 118 +++++++++++++++--- 9 files changed, 166 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57fd495d2..23311bf14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # Change Log +## [0.5.4] (Unreleased) + +### Release Notes + + + +### Infrastructure + +- Disable populate transaction protection for long-populating tables #1066 + ## [0.5.3] (August 27, 2024) ### Infrastructure @@ -25,9 +35,9 @@ - Allow `ModuleNotFoundError` or `ImportError` for optional dependencies #1023 - Ensure integrity of group tables #1026 - Convert list of LFP artifact removed interval list to array #1046 -- Merge duplicate functions in decoding and spikesorting #1050, #1053, #1058, - #1066 -- Revise docs organization. +- Merge duplicate functions in decoding and spikesorting #1050, #1053, #1062, + #1066, #1069 +- Reivise docs organization. - Misc -> Features/ForDevelopers. #1029 - Installation instructions -> Setup notebook. #1029 - Migrate SQL export tools to `utils` to support exporting `DandiPath` #1048 @@ -320,3 +330,4 @@ [0.5.1]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.1 [0.5.2]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.2 [0.5.3]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.3 +[0.5.4]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.4 diff --git a/docs/src/Features/Mixin.md b/docs/src/Features/Mixin.md index ac227a7be..bc02087ce 100644 --- a/docs/src/Features/Mixin.md +++ b/docs/src/Features/Mixin.md @@ -188,3 +188,33 @@ nwbfile = Nwbfile() (nwbfile & "nwb_file_name LIKE 'Name%'").ddp(dry_run=False) (nwbfile & "nwb_file_name LIKE 'Other%'").ddp(dry_run=False) ``` + +## Populate Calls + +The mixin also overrides the default `populate` function to provide additional +functionality for non-daemon process pools and disabling transaction protection. + +### Non-Daemon Process Pools + +To allow the `make` function to spawn a new process pool, the mixin overrides +the default `populate` function for tables with `_parallel_make` set to `True`. +See [issue #1000](https://github.com/LorenFrankLab/spyglass/issues/1000) and +[PR #1001](https://github.com/LorenFrankLab/spyglass/pull/1001) for more +information. + +### Disable Transaction Protection + +By default, DataJoint wraps the `populate` function in a transaction to ensure +data integrity (see +[Transactions](https://docs.datajoint.io/python/definition/05-Transactions.html)). 
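For illustration only (this sketch is not part of the patch or the Spyglass codebase), the protection described above is roughly what DataJoint's connection-level transaction context manager provides around each `make()` call: nothing is committed, and any locks taken along the way are held, until the block exits.

```python
import datajoint as dj

conn = dj.conn()

# Rough sketch of what populate() does per key when transactions are on:
# inserts made inside the block are only committed when the block exits,
# so locks acquired along the way are held for the full duration of make().
with conn.transaction:
    pass  # fetch inputs, compute, and insert results here
```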
+ +This can cause issues when populating large tables if another user attempts to +declare/modify a table while the transaction is open (see +[issue #1030](https://github.com/LorenFrankLab/spyglass/issues/1030) and +[DataJoint issue #1170](https://github.com/datajoint/datajoint-python/issues/1170)). + +Tables with `_use_transaction` set to `False` will not be wrapped in a +transaction when calling `populate`. Transaction protection is replaced by a +hash of upstream data to ensure no changes are made to the table during the +unprotected populate. The additional time required to hash the data is a +trade-off for already time-consuming populates, but avoids blocking other users. diff --git a/src/spyglass/position/v1/position_dlc_training.py b/src/spyglass/position/v1/position_dlc_training.py index 94548c6b1..85e86b1c0 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ b/src/spyglass/position/v1/position_dlc_training.py @@ -102,7 +102,9 @@ class DLCModelTraining(SpyglassMixin, dj.Computed): latest_snapshot: int unsigned # latest exact snapshot index (i.e., never -1) config_template: longblob # stored full config file """ + log_path = None + _use_transaction, _allow_insert = False, True # To continue from previous training snapshot, # devs suggest editing pose_cfg.yml diff --git a/src/spyglass/spikesorting/v1/figurl_curation.py b/src/spyglass/spikesorting/v1/figurl_curation.py index fca4fb26b..03b0313c7 100644 --- a/src/spyglass/spikesorting/v1/figurl_curation.py +++ b/src/spyglass/spikesorting/v1/figurl_curation.py @@ -117,6 +117,8 @@ class FigURLCuration(SpyglassMixin, dj.Computed): url: varchar(1000) """ + _use_transaction, _allow_insert = False, True + def make(self, key: dict): # FETCH query = ( diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index b03c7fa9c..6ef520947 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -203,6 +203,8 @@ class MetricCuration(SpyglassMixin, dj.Computed): object_id: varchar(40) # Object ID for the metrics in NWB file """ + _use_transaction, _allow_insert = False, True + def make(self, key): AnalysisNwbfile()._creation_times["pre_create_time"] = time() # FETCH diff --git a/src/spyglass/spikesorting/v1/sorting.py b/src/spyglass/spikesorting/v1/sorting.py index 9196eb627..47e8b6b68 100644 --- a/src/spyglass/spikesorting/v1/sorting.py +++ b/src/spyglass/spikesorting/v1/sorting.py @@ -144,6 +144,8 @@ class SpikeSorting(SpyglassMixin, dj.Computed): time_of_sort: int # in Unix time, to the nearest second """ + _use_transaction, _allow_insert = False, True + def make(self, key: dict): """Runs spike sorting on the data and parameters specified by the SpikeSortingSelection table and inserts a new entry to SpikeSorting table. 
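As a usage note (not part of this patch), the opt-out pattern shown in the hunks above can be applied to any other long-running computed table. Below is a minimal sketch with hypothetical schema and table names; a real table would also declare a foreign key to an upstream selection table in its definition.

```python
import datajoint as dj

from spyglass.utils import SpyglassMixin

schema = dj.schema("example_slow_pipeline")  # hypothetical schema name


@schema
class ExampleSlowCompute(SpyglassMixin, dj.Computed):  # hypothetical table
    definition = """
    analysis_id: int
    ---
    result_count: int
    """

    # Same flags as SpikeSorting/MetricCuration above: skip the populate
    # transaction (integrity is instead checked via an upstream-data hash)
    # and allow make() to insert outside a populate-managed transaction.
    _use_transaction, _allow_insert = False, True

    def make(self, key):
        # ... hours-long computation would run here ...
        self.insert1(dict(key, result_count=0))
```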
diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 6b3928042..354b492ab 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -7,6 +7,7 @@ from copy import deepcopy from enum import Enum from functools import cached_property +from hashlib import md5 as hash_md5 from itertools import chain as iter_chain from typing import Any, Dict, Iterable, List, Set, Tuple, Union @@ -595,6 +596,14 @@ def leaf_ft(self): """Get restricted FreeTables from graph leaves.""" return [self._get_ft(table, with_restr=True) for table in self.leaves] + @property + def hash(self): + """Return hash of all visited nodes.""" + initial = hash_md5(b"") + for table in self.all_ft: + initial.update(table.fetch()) + return initial.hexdigest() + # ------------------------------- Add Nodes ------------------------------- def add_leaf( diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 0bf61b734..caf6ea57c 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -516,7 +516,8 @@ def make_file_obj_id_unique(nwb_path: str): def populate_pass_function(value): """Pass function for parallel populate. - Note: To avoid pickling errors, the table must be passed by class, NOT by instance. + Note: To avoid pickling errors, the table must be passed by class, + NOT by instance. Note: This function must be defined in the global namespace. Parameters diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 04b873740..6ce94fbf0 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -81,6 +81,8 @@ class SpyglassMixin: _banned_search_tables = set() # Tables to avoid in restrict_by _parallel_make = False # Tables that use parallel processing in make + _use_transaction = True # Use transaction in populate. + def __init__(self, *args, **kwargs): """Initialize SpyglassMixin. @@ -410,7 +412,7 @@ def delete_downstream_parts( **kwargs : Any Passed to datajoint.table.Table.delete. """ - from spyglass.utils.dj_graph import RestrGraph # noqa F401 + RestrGraph = self._graph_deps[1] start = time() @@ -475,7 +477,14 @@ def _delete_deps(self) -> List[Table]: self._member_pk = LabMember.primary_key[0] return [LabMember, LabTeam, Session, schema.external, IntervalList] - def _get_exp_summary(self) -> Union[QueryExpression, None]: + @cached_property + def _graph_deps(self) -> list: + from spyglass.utils.dj_graph import RestrGraph # noqa #F401 + from spyglass.utils.dj_graph import TableChain + + return [TableChain, RestrGraph] + + def _get_exp_summary(self): """Get summary of experimenters for session(s), including NULL. Parameters @@ -513,7 +522,7 @@ def _get_exp_summary(self) -> Union[QueryExpression, None]: @cached_property def _session_connection(self): """Path from Session table to self. False if no connection found.""" - from spyglass.utils.dj_graph import TableChain # noqa F401 + TableChain = self._graph_deps[0] return TableChain(parent=self._delete_deps[2], child=self, verbose=True) @@ -697,25 +706,104 @@ def super_delete(self, warn=True, *args, **kwargs): self._log_delete(start=time(), super_delete=True) super().delete(*args, **kwargs) - # -------------------------- non-daemon populate -------------------------- + # -------------------------------- populate -------------------------------- + + def _hash_upstream(self, keys): + """Hash upstream table keys for no transaction populate. 
+
+        Uses a RestrGraph to capture all upstream tables, restrict them to
+        relevant entries, and hash the results. This is used to check if
+        upstream tables have changed during a no-transaction populate and avoid
+        the following data-integrity error:
+
+        1. User A starts no-transaction populate.
+        2. User B deletes and repopulates an upstream table, changing contents.
+        3. User A finishes populate, inserting data that is now invalid.
+
+        Parameters
+        ----------
+        keys : list
+            List of keys for populating table.
+        """
+        RestrGraph = self._graph_deps[1]
+
+        if not (parents := self.parents(as_objects=True, primary=True)):
+            raise RuntimeError("No upstream tables found for upstream hash.")
+
+        leaves = {  # Restriction on each primary parent
+            p.full_table_name: [
+                {k: v for k, v in key.items() if k in p.heading.names}
+                for key in keys
+            ]
+            for p in parents
+        }
+
+        return RestrGraph(seed_table=self, leaves=leaves, cascade=True).hash
+
     def populate(self, *restrictions, **kwargs):
-        """Populate table in parallel.
+        """Populate table in parallel, with or without transaction protection.
 
         Supersedes datajoint.table.Table.populate for classes with that
-        spawn processes in their make function
+        spawn processes in their make function and always use transactions.
+
+        `_use_transaction` class attribute can be set to False to disable
+        transaction protection for a table. This is not recommended for tables
+        with short processing times. A before-and-after hash check is performed
+        to ensure upstream tables have not changed during populate, which may
+        make the process more time-consuming. To permit `make` to insert
+        without populate, set `_allow_insert` to True.
         """
-
-        # Pass through to super if not parallel in the make function or only a single process
         processes = kwargs.pop("processes", 1)
+
+        # Decide if using transaction protection
+        use_transact = kwargs.pop("use_transaction", None)
+        if use_transact is None:  # if user does not specify, use class default
+            use_transact = self._use_transaction
+            if self._use_transaction is False:  # If class default is off, warn
+                logger.warning(
+                    "Turning off transaction protection for this table by default. "
+                    + "Use use_transaction=True to re-enable.\n"
+                    + "Read more about transactions:\n"
+                    + "https://docs.datajoint.io/python/definition/05-Transactions.html\n"
+                    + "https://github.com/LorenFrankLab/spyglass/issues/1030"
+                )
+        if use_transact is False and processes > 1:
+            raise RuntimeError(
+                "Must use transaction protection with parallel processing.\n"
+                + "Call with use_transaction=True.\n"
+                + f"Table default transaction use: {self._use_transaction}"
+            )
+
+        # Get keys, needed for no-transact or multi-process w/_parallel_make
+        keys = [True]
+        if use_transact is False or (processes > 1 and self._parallel_make):
+            keys = (self._jobs_to_do(restrictions) - self.target).fetch(
+                "KEY", limit=kwargs.get("limit", None)
+            )
+
+        if use_transact is False:
+            upstream_hash = self._hash_upstream(keys)
+            if kwargs:  # Warn of ignoring populate kwargs, because using `make`
+                logger.warning(
+                    "Ignoring kwargs when not using transaction protection."
+ ) + if processes == 1 or not self._parallel_make: - kwargs["processes"] = processes - return super().populate(*restrictions, **kwargs) + if use_transact: # Pass single-process populate to super + kwargs["processes"] = processes + return super().populate(*restrictions, **kwargs) + else: # No transaction protection, use bare make + for key in keys: + self.make(key) + if upstream_hash != self._hash_upstream(keys): + (self & keys).delete(force=True) + logger.error( + "Upstream tables changed during non-transaction " + + "populate. Please try again." + ) + return # If parallel in both make and populate, use non-daemon processes - # Get keys to populate - keys = (self._jobs_to_do(restrictions) - self.target).fetch( - "KEY", limit=kwargs.get("limit", None) - ) # package the call list call_list = [(type(self), key, kwargs) for key in keys] @@ -964,7 +1052,7 @@ def restrict_by( Restricted version of present table or TableChain object. If return_graph, use all_ft attribute to see all tables in cascade. """ - from spyglass.utils.dj_graph import TableChain # noqa: F401 + TableChain = self._graph_deps[0] if restriction is True: return self From d4dbc232dcb474f037493d8bfbc28afa22ba9eaf Mon Sep 17 00:00:00 2001 From: Samuel Bray Date: Thu, 29 Aug 2024 09:29:23 -0700 Subject: [PATCH 3/3] Prevent error from unitless spike group (#1083) * prevent error from unitless spike group * fix 1077 * update changelog --- CHANGELOG.md | 5 +++++ src/spyglass/decoding/v1/waveform_features.py | 11 ++++++++--- src/spyglass/spikesorting/spikesorting_merge.py | 6 +++--- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 23311bf14..e1afbe680 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ - Disable populate transaction protection for long-populating tables #1066 +### Pipelines + +- Decoding + - Fix edge case errors in spike time loading #1083 + ## [0.5.3] (August 27, 2024) ### Infrastructure diff --git a/src/spyglass/decoding/v1/waveform_features.py b/src/spyglass/decoding/v1/waveform_features.py index 1208c53fd..56176484d 100644 --- a/src/spyglass/decoding/v1/waveform_features.py +++ b/src/spyglass/decoding/v1/waveform_features.py @@ -152,9 +152,12 @@ def make(self, key): sorter, ) - spike_times = SpikeSortingOutput().fetch_nwb(merge_key)[0][ - analysis_nwb_key - ]["spike_times"] + nwb = SpikeSortingOutput().fetch_nwb(merge_key)[0] + spike_times = ( + nwb[analysis_nwb_key]["spike_times"] + if analysis_nwb_key in nwb + else pd.DataFrame() + ) ( key["analysis_file_name"], @@ -349,6 +352,8 @@ def _write_waveform_features_to_nwb( metric_dict[unit_id] if unit_id in metric_dict else [] for unit_id in unit_ids ] + if not metric_values: + metric_values = np.array([]).astype(np.float32) nwbf.add_unit_column( name=metric, description=metric, diff --git a/src/spyglass/spikesorting/spikesorting_merge.py b/src/spyglass/spikesorting/spikesorting_merge.py index 4887cb3f3..7d12601e2 100644 --- a/src/spyglass/spikesorting/spikesorting_merge.py +++ b/src/spyglass/spikesorting/spikesorting_merge.py @@ -4,9 +4,9 @@ from ripple_detection import get_multiunit_population_firing_rate from spyglass.spikesorting.imported import ImportedSpikeSorting # noqa: F401 -from spyglass.spikesorting.v0.spikesorting_curation import ( +from spyglass.spikesorting.v0.spikesorting_curation import ( # noqa: F401 CuratedSpikeSorting, -) # noqa: F401 +) from spyglass.spikesorting.v1 import ArtifactDetectionSelection # noqa: F401 from spyglass.spikesorting.v1 import ( CurationV1, @@ -210,7 +210,7 @@ def 
get_spike_indicator(cls, key, time):
         """
         time = np.asarray(time)
         min_time, max_time = time[[0, -1]]
-        spike_times = cls.fetch_spike_data(key)  # CB: This is undefined.
+        spike_times = (cls & key).get_spike_times(key)
 
         spike_indicator = np.zeros((len(time), len(spike_times)))
         for ind, times in enumerate(spike_times):
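The final hunk above ends mid-loop. Purely to illustrate the general technique (this standalone sketch is not the repository's implementation; the function name and binning scheme here are assumptions), a per-unit spike indicator on a regular time grid can be built by counting each unit's spikes per time bin:

```python
import numpy as np


def spike_indicator_sketch(spike_times_by_unit, time):
    """Count spikes from each unit within the bins of a regular time grid."""
    time = np.asarray(time)
    indicator = np.zeros((len(time), len(spike_times_by_unit)))
    for ind, times in enumerate(spike_times_by_unit):
        # Keep only spikes inside the requested window
        in_window = times[(times >= time[0]) & (times <= time[-1])]
        # Assign each spike to a grid bin, then count spikes per bin
        indicator[:, ind] = np.bincount(
            np.digitize(in_window, time[1:-1]), minlength=len(time)
        )
    return indicator


# Example with fabricated spike times: two units on a 2 ms grid
units = [np.array([0.010, 0.021, 0.500]), np.array([0.300, 0.301])]
grid = np.arange(0, 1.0, 0.002)
print(spike_indicator_sketch(units, grid).sum(axis=0))  # -> [3. 2.]
```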