diff --git a/src/fibad/data_sets/hsc_data_set.py b/src/fibad/data_sets/hsc_data_set.py
index e4310ba..3c2c80f 100644
--- a/src/fibad/data_sets/hsc_data_set.py
+++ b/src/fibad/data_sets/hsc_data_set.py
@@ -5,6 +5,7 @@
 import os
 import re
 import resource
+from collections.abc import Generator
 from copy import copy, deepcopy
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
@@ -157,6 +158,9 @@ def _set_split(self, split: Union[str, None] = None):
     def shape(self) -> tuple[int, int, int]:
         return self.container.shape()
 
+    def ids(self) -> Generator[str, None, None]:
+        return self.container.ids()
+
     def __getitem__(self, idx: int) -> torch.Tensor:
         # return self.current_split[idx]
         return self.container[idx]
@@ -921,7 +925,7 @@ def _get_file(self, index: int) -> Path:
         filter = filter_names[index % self.num_filters]
         return self._file_to_path(filters[filter])
 
-    def ids(self, log_every=None):
+    def ids(self, log_every=None) -> Generator[str, None, None]:
         """Public read-only iterator over all object_ids that enforces a strict total order
         across objects. Will not work prior to self.files initialization in __init__
diff --git a/src/fibad/infer.py b/src/fibad/infer.py
index ca0411f..18f4295 100644
--- a/src/fibad/infer.py
+++ b/src/fibad/infer.py
@@ -56,7 +56,7 @@ def run(config: ConfigDict):
     write_index = 0
     batch_index = 0
 
-    object_ids: list[int] = list(int(data_set.ids()) if hasattr(data_set, "ids") else range(len(data_set)))  # type: ignore[arg-type]
+    object_ids = list(data_set.ids() if hasattr(data_set, "ids") else range(len(data_set)))  # type: ignore[arg-type]
 
     def _save_batch(batch_results: Tensor):
         """Receive and write results tensors to results_dir immediately