Skip to content

Commit

Permalink
move extract_merge_id to class method
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Sep 16, 2024
1 parent 39080b9 commit f685dad
Showing 1 changed file with 34 additions and 33 deletions.
67 changes: 34 additions & 33 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def fetch_nwb(
if isinstance(self, dict):
raise ValueError("Try replacing Merge.method with Merge().method")
restriction = restriction or self.restriction or True
merge_restriction = extract_merge_id(restriction)
merge_restriction = self.extract_merge_id(restriction)
sources = set((self & merge_restriction).fetch(self._reserved_sk))
nwb_list = []
merge_ids = []
Expand Down Expand Up @@ -834,6 +834,39 @@ def super_delete(self, warn=True, *args, **kwargs):
self._log_delete(start=time(), super_delete=True)
super().delete(*args, **kwargs)

@classmethod
def extract_merge_id(cls, restriction):
"""Utility function to extract merge_id from a restriction
Parameters
----------
restriction : str, dict, or dj.condition.AndList
A datajoint restriction
Returns
-------
restriction
A restriction containing only the merge_id key
"""
if restriction is None:
return None
if isinstance(restriction, dict):
if merge_id := restriction.get("merge_id"):
return {"merge_id": merge_id}
else:
return {}
merge_restr = []
if isinstance(restriction, dj.condition.AndList) or isinstance(restriction, List):
merge_id_list = [cls.extract_merge_id(r) for r in restriction]
merge_restr = [x for x in merge_id_list if x is not None]
elif isinstance(restriction, str):
parsed = [x.split(")")[0] for x in restriction.split("(") if x]
merge_restr = dj.condition.AndList([x for x in parsed if "merge_id" in x])

if len(merge_restr) == 0:
return True
return merge_restr


_Merge = Merge

Expand Down Expand Up @@ -861,35 +894,3 @@ def delete_downstream_merge(
table = table if isinstance(table, dj.Table) else table()

return table.delete_downstream_parts(**kwargs)

def extract_merge_id(restriction):
"""Utility function to extract merge_id from a restriction
Parameters
----------
restriction : str, dict, or dj.condition.AndList
A datajoint restriction
Returns
-------
restriction
A restriction containing only the merge_id key
"""
if restriction is None:
return None
if isinstance(restriction, dict):
if merge_id := restriction.get("merge_id"):
return {"merge_id": merge_id}
else:
return {}
merge_restr = []
if isinstance(restriction, dj.condition.AndList) or isinstance(restriction, List):
merge_id_list = [extract_merge_id(r) for r in restriction]
merge_restr = [x for x in merge_id_list if x is not None]
elif isinstance(restriction, str):
parsed = [x.split(")")[0] for x in restriction.split("(") if x]
merge_restr = dj.condition.AndList([x for x in parsed if "merge_id" in x])

if len(merge_restr) == 0:
return True
return merge_restr

0 comments on commit f685dad

Please sign in to comment.