Skip to content

Commit

Permalink
Merge pull request #164 from /issues/160/more-transforms
Browse files Browse the repository at this point in the history
Adding the Capability for any numpy function to be passed as a transform variable
  • Loading branch information
aritraghsh09 authored Jan 16, 2025
2 parents c86ab5e + 98663ab commit 11a235a
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 8 deletions.
37 changes: 29 additions & 8 deletions src/fibad/data_sets/hsc_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import resource
from copy import copy, deepcopy
from pathlib import Path
from typing import Any, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -318,17 +318,17 @@ def __getitem__(self, idx: int) -> torch.Tensor:

class HSCDataSetContainer(Dataset):
def __init__(self, config):
# TODO: What will be a reasonable set of tranformations?
# For now tanh all the values so they end up in [-1,1]
# Another option might be sinh, but we'd need to mess with the example autoencoder module
# Because it goes from unbounded NN output space -> [-1,1] with tanh in its decode step.
transform = Lambda(lambd=np.tanh)

crop_to = config["data_set"]["crop_to"]
filters = config["data_set"]["filters"]

transform_str = config["data_set"]["transform"]
self.use_cache = config["data_set"]["use_cache"]

if transform_str:
transform_func = self._get_np_function(transform_str)
transform = Lambda(lambd=transform_func)
else:
transform = None

if config["data_set"]["filter_catalog"]:
filter_catalog = Path(config["data_set"]["filter_catalog"])
elif not config.get("rebuild_manifest", False):
Expand All @@ -348,6 +348,27 @@ def __init__(self, config):
filter_catalog=Path(filter_catalog) if filter_catalog else None,
)

def _get_np_function(self, transform_str: str) -> Callable[..., Any]:
"""
_get_np_function. Returns the numpy mathematical function that the
supplied string maps to; or raises an error if the supplied string
cannot be mapped to a function.
Parameters
----------
transform_str: str
The string to me mapped to a numpy function
"""

try:
func: Callable[..., Any] = getattr(np, transform_str)
if callable(func):
return func
except AttributeError as err:
msg = f"{transform_str} is not a valid numpy function.\n"
msg += "The string passed to the transform variable needs to be a numpy function"
raise RuntimeError(msg) from err

def _init_from_path(
self,
path: Union[Path, str],
Expand Down
5 changes: 5 additions & 0 deletions src/fibad/fibad_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ filters = false
# [general].data_dir. Implementation is data_set class dependent. Use `false` for no filtering.
filter_catalog = false

# The transformation to be applied to images before being passed on to the model
# This must be a valid Numpy function. Passing false will result in no transformations
# (other than cropping) be applied to the images.
transform = "tanh"

# train_size, validation_size, and test_size use these conventions:
# * A `float` between `0.0` and `1.0` is the proportion of the dataset to include in the split.
# * An `int`, represents the absolute number of samples in the particular split.
Expand Down
55 changes: 55 additions & 0 deletions tests/fibad/test_hsc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest
from fibad.data_sets.hsc_data_set import HSCDataSet, HSCDataSetSplit
from torchvision.transforms.v2 import CenterCrop, Lambda

test_dir = Path(__file__).parent / "test_data" / "dataloader"

Expand Down Expand Up @@ -64,6 +65,7 @@ def mkconfig(
seed=False,
filter_catalog=False,
use_cache=False,
transform="tanh",
):
"""Makes a configuration that points at nonexistent path so HSCDataSet.__init__ will create an object,
and our FakeFitsFS shim can be called.
Expand All @@ -79,6 +81,7 @@ def mkconfig(
"test_size": test_size,
"validate_size": validate_size,
"use_cache": use_cache,
"transform": transform,
},
}

Expand Down Expand Up @@ -630,3 +633,55 @@ def test_split_and_conflicting_datasets():

with pytest.raises(RuntimeError):
a.current_split.logical_and(b.current_split)


def test_valid_transform_string(caplog):
"""Test to ensure that a valid string passed to transform
will map to a numpy function"""

caplog.set_level(logging.ERROR)
test_files = generate_files(num_objects=10, num_filters=5, shape=(262, 263))

with FakeFitsFS(test_files):
a = HSCDataSet(mkconfig(transform="arcsinh"), split=None)

# transform always has CenterCrop in the beginning followed by the user
# defined transform
lambda_transform = [t for t in a.container.transform.transforms if isinstance(t, Lambda)][0]
assert lambda_transform.lambd == np.arcsinh

with FakeFitsFS(test_files):
a = HSCDataSet(mkconfig(transform="tanh"), split=None)

# transform always has CenterCrop in the beginning followed by the user
# defined transform
lambda_transform = [t for t in a.container.transform.transforms if isinstance(t, Lambda)][0]
assert lambda_transform.lambd == np.tanh


def test_invalid_transform_string(caplog):
"""Test to ensure that an invalid string passed to transform will raise an error"""

caplog.set_level(logging.ERROR)
test_files = generate_files(num_objects=10, num_filters=5, shape=(262, 263))

with FakeFitsFS(test_files):
with pytest.raises(RuntimeError):
HSCDataSet(mkconfig(transform="invalid_function"), split=None)


def test_false_transform(caplog):
"""Test to ensure that false passed to transform behaves as expected"""

caplog.set_level(logging.ERROR)
test_files = generate_files(num_objects=10, num_filters=5, shape=(262, 263))

with FakeFitsFS(test_files):
a = HSCDataSet(mkconfig(transform=False), split=None)

# When transform is False; only a CenterCrop should be applied
# automatically with a size conforming to test_files above
expected_transform = CenterCrop(size=(np.int64(262), np.int64(262)))
actual_transform = a.container.transform
assert isinstance(actual_transform, CenterCrop)
assert actual_transform.size == expected_transform.size

0 comments on commit 11a235a

Please sign in to comment.