Skip to content

Commit

Permalink
Parallelizing _scan_file_dimensions()
Browse files Browse the repository at this point in the history
Using Schwimmbad and multiprocessing to parallelize extracting
the dimensions of files in HSCDataSet to effect speedup in 10M+
file datasets.

Not currently tuned to hyak, no speedup yet measured.
  • Loading branch information
mtauraso committed Nov 15, 2024
1 parent b1acf8c commit 3284a22
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 19 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"toml", # Used to load configuration files as dictionaries
"torch", # Used for CNN model and in train.py
"torchvision", # Used in hsc data loader, example autoencoder, and CNN model data set
"schwimmbad", # Used to speedup hsc data loader file scans
]

[project.scripts]
Expand Down
31 changes: 18 additions & 13 deletions src/fibad/data_sets/hsc_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
from astropy.io import fits
from astropy.table import Table
from schwimmbad import MultiPool
from torch.utils.data import Dataset
from torchvision.transforms.v2 import CenterCrop, Compose, Lambda

Expand Down Expand Up @@ -487,14 +488,22 @@ def _scan_file_dimensions(self) -> dim_dict:
logger.info("Scanning for dimensions...")

retval = {}
for index, object_id in enumerate(self.ids()):
retval[object_id] = [self._fits_file_dims(filepath) for filepath in self._object_files(object_id)]
if index != 0 and index % 100_000 == 0:
logger.info(f"Scanned {index} objects for dimensions")
else:
logger.info(f"Scanned {index+1} objects for dimensions")

with MultiPool() as pool:
args = ((object_id, list(self._object_files(object_id))) for object_id in self.ids())
retval = dict(pool.map(self._scan_file_dimension, args))
return retval

@staticmethod
def _scan_file_dimension(processing_unit: tuple[str, list[str]]) -> list[tuple[int, int]]:
object_id, filenames = processing_unit
return (object_id, [HSCDataSetContainer._fits_file_dims(filepath) for filepath in filenames])

@staticmethod
def _fits_file_dims(filepath):
with fits.open(filepath) as hdul:
return hdul[1].shape

def _prune_objects(self, filters_ref: list[str]):
"""Class initialization helper. Prunes objects from the list of objects.
Expand Down Expand Up @@ -563,10 +572,6 @@ def _prune_object(self, object_id, reason: str):
del self.dims[object_id]
self.prune_count += 1

def _fits_file_dims(self, filepath):
with fits.open(filepath) as hdul:
return hdul[1].shape

def _check_file_dimensions(self) -> tuple[int, int]:
"""Class initialization helper. Find the maximal pixel size that all images can support
Expand All @@ -589,10 +594,10 @@ def _check_file_dimensions(self) -> tuple[int, int]:

# Find the maximal cutout size that all images can support
all_widths = [shape[0] for shape_list in self.dims.values() for shape in shape_list]
cutout_width = np.min(all_widths)

all_heights = [shape[1] for shape_list in self.dims.values() for shape in shape_list]
cutout_height = np.min(all_heights)
all_dimensions = all_widths + all_heights
cutout_height = np.min(all_dimensions)
cutout_width = cutout_height

if (
np.abs(cutout_width - np.mean(all_widths)) > 1
Expand Down
12 changes: 6 additions & 6 deletions tests/fibad/test_hsc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def test_load(caplog):
# 10 objects should load
assert len(a) == 10

# The number of filters, and image dimensions should be correct
assert a.shape() == (5, 262, 263)
# The number of filters, and image dimensions should be correct and square
assert a.shape() == (5, 262, 262)

# No warnings should be printed
assert caplog.text == ""
Expand All @@ -152,8 +152,8 @@ def test_load_duplicate(caplog):
# Only 10 objects should load
assert len(a) == 10

# The number of filters, and image dimensions should be correct
assert a.shape() == (5, 262, 263)
# The number of filters, and image dimensions should be correct and square
assert a.shape() == (5, 262, 262)

# We should get duplicate object errors
assert "Duplicate object ID" in caplog.text
Expand Down Expand Up @@ -327,8 +327,8 @@ def test_partial_filter(caplog):
# 10 objects should load
assert len(a) == 10

# The number of filters, and image dimensions should be correct
assert a.shape() == (2, 262, 263)
# The number of filters, and image dimensions should be correct and square
assert a.shape() == (2, 262, 262)

# No warnings should be printed
assert caplog.text == ""
Expand Down

0 comments on commit 3284a22

Please sign in to comment.