Skip to content

Commit

Permalink
Add method to pass along the HSCDataset id when populating the vector…
Browse files Browse the repository at this point in the history
… database.
  • Loading branch information
drewoldag committed Jan 23, 2025
1 parent 839cf98 commit b0e206b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/fibad/data_sets/hsc_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import re
import resource
from collections.abc import Generator
from copy import copy, deepcopy
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
Expand Down Expand Up @@ -157,6 +158,9 @@ def _set_split(self, split: Union[str, None] = None):
def shape(self) -> tuple[int, int, int]:
return self.container.shape()

def ids(self) -> Generator[str, None, None]:
return self.container.ids()

Check warning on line 162 in src/fibad/data_sets/hsc_data_set.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/data_sets/hsc_data_set.py#L162

Added line #L162 was not covered by tests

def __getitem__(self, idx: int) -> torch.Tensor:
# return self.current_split[idx]
return self.container[idx]
Expand Down Expand Up @@ -921,7 +925,7 @@ def _get_file(self, index: int) -> Path:
filter = filter_names[index % self.num_filters]
return self._file_to_path(filters[filter])

def ids(self, log_every=None):
def ids(self, log_every=None) -> Generator[str, None, None]:
"""Public read-only iterator over all object_ids that enforces a strict total order across
objects. Will not work prior to self.files initialization in __init__
Expand Down
2 changes: 1 addition & 1 deletion src/fibad/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def run(config: ConfigDict):

write_index = 0
batch_index = 0
object_ids: list[int] = list(int(data_set.ids()) if hasattr(data_set, "ids") else range(len(data_set))) # type: ignore[arg-type]
object_ids = list(data_set.ids() if hasattr(data_set, "ids") else range(len(data_set))) # type: ignore[arg-type]

Check warning on line 59 in src/fibad/infer.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/infer.py#L59

Added line #L59 was not covered by tests

def _save_batch(batch_results: Tensor):
"""Receive and write results tensors to results_dir immediately
Expand Down

0 comments on commit b0e206b

Please sign in to comment.