diff --git a/pyns/models/analysis.py b/pyns/models/analysis.py index 4f81324..65bd93a 100644 --- a/pyns/models/analysis.py +++ b/pyns/models/analysis.py @@ -385,15 +385,40 @@ def _ts_first(paths): return req - def get_uploads(self, id): + def get_uploads(self, id, select='latest', **kwargs): """ Get NeuroVault uploads associated with this analysis :param str id: Analysis hash_id. + :param str select: How to select from multiple collections. + Options: "latest", "oldest" or None. If None, returns all results. + :param dict kwargs: Attributes to filter collections on. + If any attributes are not found, they are ignored. :return: client response object """ - return self.get(id=id, sub_route='upload') + uploads = self.get(id=id, sub_route='upload') + + # Sort by date + uploads = sorted(uploads, key=lambda x: datetime.datetime.strptime( + x['uploaded_at'], '%Y-%m-%dT%H:%M'), + reverse=(select == 'latest')) + + # Select collections based on filters + uploads = [ + u for u in uploads + if all([u.get(k) == v for k, v in kwargs.items() if k in u]) + ] + + if not uploads: + return None + + # Select first item unless all are requested + if select is not None: + uploads = [uploads[0]] + + return uploads def load_uploads(self, id, select='latest', - download_dir=None, **kwargs): + download_dir=None, collection_filters={}, + image_filters={}): """ Load collection upload as NiBabel images and associated meta-data You can filter which images are loaded based on either collection level attributes or statmap image level attributes. These correspond @@ -405,7 +430,8 @@ def load_uploads(self, id, select='latest', :param str select: How to select from multiple collections. Options: "latest", "oldest" or None. If None, returns all results. :param str download_dir: Path to download images. If None, tempdir. - :param dict kwargs: Attributes to filter images on. + :param dict collection_filters: Attributes to filter collections on. + :param dict image_filters: Attributes to filter images on. If any attributes are not found, they are ignored. :return list list of tuples of format (Nifti1Image, kwargs). """ @@ -415,30 +441,15 @@ def load_uploads(self, id, select='latest', download_dir = Path(download_dir) # Sort uploads for upload date - uploads = self.get_uploads(id) - for u in uploads: - u['uploaded_at'] = datetime.datetime.strptime( - u['uploaded_at'], '%Y-%m-%dT%H:%M') - uploads = sorted(uploads, key=lambda x: x['uploaded_at'], - reverse=(select == 'latest')) - - # Select collections based on filters - uploads = [ - u for u in uploads - if all([u.get(k) == v for k, v in kwargs.items() if k in u]) - ] - + uploads = self.get_uploads(id, **collection_filters) + if not uploads: return None - # Select first item unless all are requested - if select is not None: - uploads = [uploads[0]] - # Extract entities from file path def _get_entities(path): di = {} - for t in ['task', 'contrast', 'stat']: + for t in ['task', 'contrast', 'stat', 'space']: matches = re.findall(f"{t}-(.*?)_", path) if matches: di[t] = matches[0] @@ -452,7 +463,7 @@ def _get_entities(path): # If file matches kwargs and is in NV if f.pop('status') == 'OK' and all( - [f.get(k, None) == v for k, v in kwargs.items() if k in f]): + [f.get(k, None) == v if k in f else False for k, v in image_filters.items()]): # Download and open img_url = "https://neurovault.org/media/images/" \ f"{u['collection_id']}/{f['basename']}" @@ -482,7 +493,7 @@ def plot_uploads(self, id, plot_args={}, **kwargs): plots = [] for niimg, _ in images: plots.append( - nilearn.plotting.plot_stat_map(niimg, **plot_args)) + niplt.plot_stat_map(niimg, **plot_args)) return plots else: diff --git a/pyns/models/utils.py b/pyns/models/utils.py index c55a281..f7f3861 100644 --- a/pyns/models/utils.py +++ b/pyns/models/utils.py @@ -79,3 +79,8 @@ def build_model(name, variables, tasks, subjects, runs=None, session=None, model['Input']['Session'] = session return model + + +def snake_to_camel(string): + words = string.split('_') + return words[0] + ''.join(word.title() for word in words[1:]) \ No newline at end of file diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 234c2d8..65040fa 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -254,7 +254,8 @@ def test_load_analysis(recorder, neuroscout): m = neuroscout.analyses.get_analysis('gbp6i') assert len(m.load_uploads()) == 3 assert len(m.load_uploads(select=None)) == 3 - nistats = m.load_uploads(estimator='nistats') + nistats = m.load_uploads( + collection_filters=dict(estimator='nistats')) latest_up = nistats[0][1]['uploaded_at'] assert len(nistats) == 3 assert nistats[0][1]['estimator'] == 'nistats'