Skip to content

Commit

Permalink
resolve case of string restrictions
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Sep 16, 2024
1 parent 138371f commit 39080b9
Showing 1 changed file with 39 additions and 6 deletions.
45 changes: 39 additions & 6 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pprint import pprint
from re import sub as re_sub
from time import time
from typing import Union
from typing import Union, List

import datajoint as dj
from datajoint.condition import make_condition
Expand Down Expand Up @@ -532,16 +532,17 @@ def fetch_nwb(
if isinstance(self, dict):
raise ValueError("Try replacing Merge.method with Merge().method")
restriction = restriction or self.restriction or True
sources = set((self & restriction).fetch(self._reserved_sk))
merge_restriction = extract_merge_id(restriction)
sources = set((self & merge_restriction).fetch(self._reserved_sk))
nwb_list = []
merge_ids = []
for source in sources:
source_restr = (
self & {self._reserved_sk: source} & restriction
self & {self._reserved_sk: source} & merge_restriction
).fetch("KEY")
nwb_list.extend(
(self & source_restr)
.merge_restrict_class(restriction, permit_multiple_rows=True)
.merge_restrict_class(restriction, permit_multiple_rows=True, add_invalid_restrict=False)
.fetch_nwb()
)
if return_merge_ids:
Expand Down Expand Up @@ -737,10 +738,10 @@ def merge_get_parent_class(self, source: str) -> dj.Table:
return ret

def merge_restrict_class(
self, key: dict, permit_multiple_rows: bool = False
self, key: dict, permit_multiple_rows: bool = False, add_invalid_restrict=True
) -> dj.Table:
"""Returns native parent class, restricted with key."""
parent = self.merge_get_parent(key)
parent = self.merge_get_parent(key, add_invalid_restrict=add_invalid_restrict)
parent_key = parent.fetch("KEY", as_dict=True)

if not permit_multiple_rows and len(parent_key) > 1:
Expand Down Expand Up @@ -860,3 +861,35 @@ def delete_downstream_merge(
table = table if isinstance(table, dj.Table) else table()

return table.delete_downstream_parts(**kwargs)

def extract_merge_id(restriction):
"""Utility function to extract merge_id from a restriction
Parameters
----------
restriction : str, dict, or dj.condition.AndList
A datajoint restriction
Returns
-------
restriction
A restriction containing only the merge_id key
"""
if restriction is None:
return None
if isinstance(restriction, dict):
if merge_id := restriction.get("merge_id"):
return {"merge_id": merge_id}
else:
return {}
merge_restr = []
if isinstance(restriction, dj.condition.AndList) or isinstance(restriction, List):
merge_id_list = [extract_merge_id(r) for r in restriction]
merge_restr = [x for x in merge_id_list if x is not None]
elif isinstance(restriction, str):
parsed = [x.split(")")[0] for x in restriction.split("(") if x]
merge_restr = dj.condition.AndList([x for x in parsed if "merge_id" in x])

if len(merge_restr) == 0:
return True
return merge_restr

0 comments on commit 39080b9

Please sign in to comment.