diff --git a/hartufo/query.py b/hartufo/query.py index 1afe47f..8810955 100644 --- a/hartufo/query.py +++ b/hartufo/query.py @@ -2,7 +2,7 @@ import csv from pathlib import Path import warnings -import random +import re from abc import abstractmethod from numbers import Number from typing import Dict, Union @@ -18,6 +18,9 @@ from torchvision.datasets.utils import download_url, download_and_extract_archive, check_integrity +_SUBJECT_RE = re.compile('(first|last|random)(\d*)') + + _CIPIC_ANTHROPOMETRY_NAMES = { 'weight': ('weight',), 'age': ('age',), @@ -183,12 +186,23 @@ def specification_based_ids(self, specification, include_subjects=None, exclude_ + sorted([int_id for int_id in selected_ids if isinstance(int_id[0], int)])) if include_subjects is None: return ids - if len(ids) > 0 and include_subjects == 'first': - return [ids[0]] - if len(ids) > 0 and include_subjects == 'last': - return [ids[-1]] - if len(ids) > 0 and include_subjects == 'random': - return [random.choice(ids)] + if len(ids) > 0 and isinstance(include_subjects, str): + subj_match = _SUBJECT_RE.match(include_subjects) + position = subj_match.group(1) + try: + num_ears = int(subj_match.group(2)) + except ValueError: + num_ears = 1 + if side.startswith('any') or side.startswith('both'): + num_ears *= 2 + if position == 'first': + return ids[:num_ears] + elif position == 'last': + return ids[-num_ears:] + elif position == 'random': + return np.random.choice(ids, num_ears) + else: + raise ValueError(f'Unknown subject selector "{include_subjects}".') return [(i, s) for i, s in ids if i in include_subjects]