Fix MultiDimImageDataset metadata handling #458
@@ -1,50 +1,49 @@
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Sequence, Union

import numpy as np
import pandas as pd
import torch
import tqdm
from bioio import BioImage
from monai.data import DataLoader, Dataset, MetaTensor
from monai.transforms import Compose, ToTensor, apply_transform
from omegaconf import ListConfig
from monai.data import CacheDataset
from omegaconf import OmegaConf


class MultiDimImageDataset(Dataset):
    """Dataset converting a `.csv` file listing multi dimensional (timelapse or multi-scene) files
    and some metadata into batches of single-scene, single-timepoint, single-channel images."""
class MultiDimImageDataset(CacheDataset):
    """Dataset converting a `.csv` file or dictionary listing multi dimensional (timelapse or
    multi-scene) files and some metadata into batches of metadata intended for the
    BioIOImageLoaderd class."""

    def __init__(
        self,
        csv_path: Union[Path, str],
        img_path_column: str,
        channel_column: str,
        out_key: str,
        csv_path: Optional[Union[Path, str]] = None,
        img_path_column: str = "path",
        channel_column: str = "channel",
        spatial_dims: int = 3,
        scene_column: str = "scene",
        resolution_column: str = "resolution",
        time_start_column: str = "start",
        time_stop_column: str = "stop",
        time_step_column: str = "step",
        dict_meta: Optional[Dict] = None,
        transform: Optional[Callable] = None,
        dask_load: bool = True,
        transform: Optional[Union[Callable, Sequence[Callable]]] = [],
        **cache_kwargs,
    ):
        """
        Parameters
        ----------
        csv_path: Union[Path, str]
            path to csv
        img_path_column: str
            column in `csv_path` that contains path to multi dimensional (timelapse or multi-scene) file
        channel_column:str
            Column in `csv_path` that contains which channel to extract from multi dimensional (timelapse or multi-scene) file. Should be an integer.
        out_key:str
            Key where single-scene/timepoint/channel is saved in output dictionary
        spatial_dims:int=3
            Spatial dimension of output image. Must be 2 for YX or 3 for ZYX
            Spatial dimension of output image. Must be 2 for YX or 3 for ZYX. Spatial dimensions are used to specify the dimension order of the output image, which will be in the format `CZYX` or `CYX` to ensure compatibility with dictionary-based MONAI-style transforms.
        scene_column:str="scene"
            Column in `csv_path` that contains scenes to extract from multi-scene file. If not specified, all scenes will
            be extracted. If multiple scenes are specified, they should be separated by a comma (e.g. `scene1,scene2`)
        resolution_column:str="resolution"
            Column in `csv_path` that contains resolution to extract from multi-resolution file. If not specified, resolution is assumed to be 0.
        time_start_column:str="start"
            Column in `csv_path` specifying which timepoint in timelapse image to start extracting. If any of `start_column`, `stop_column`, or `step_column`
            are not specified, all timepoints are extracted.
@@ -56,27 +55,29 @@ def __init__(
            If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted.
        dict_meta: Optional[Dict]
            Dictionary version of CSV file. If not provided, CSV file is read from `csv_path`.
        transform: Optional[Callable] = None
            Callable that accepts a numpy array. For example, image normalization functions could be passed here.
        dask_load: bool = True
            Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints.
        transform: Optional[Callable] = []
            List (or Compose object) of MONAI dictionary-style transforms to apply to the image metadata. Typically, the first transform should be BioIOImageLoaderd.
        cache_kwargs:
            Additional keyword arguments to pass to `CacheDataset`. To skip the caching mechanism, set `cache_num` to 0.
        """
        super().__init__(None, transform)
        df = pd.read_csv(csv_path) if csv_path is not None else pd.DataFrame([dict_meta])

        df = (
            pd.read_csv(csv_path)
            if csv_path is not None
            else pd.DataFrame(OmegaConf.to_container(dict_meta))
        )
        self.img_path_column = img_path_column
        self.channel_column = channel_column
        self.scene_column = scene_column
        self.resolution_column = resolution_column
        self.time_start_column = time_start_column
        self.time_stop_column = time_stop_column
        self.time_step_column = time_step_column
        self.out_key = out_key
        if spatial_dims not in (2, 3):
            raise ValueError(f"`spatial_dims` must be 2 or 3, got {spatial_dims}")
        self.spatial_dims = spatial_dims
        self.dask_load = dask_load
        data = self.get_per_file_args(df)

        self.img_data = self.get_per_file_args(df)
        super().__init__(data, transform, **cache_kwargs)

    def _get_scenes(self, row, img):
        scenes = row.get(self.scene_column, -1)
@@ -100,145 +101,24 @@ def _get_timepoints(self, row, img):

    def get_per_file_args(self, df):
        img_data = []
        for row in df.itertuples():
        for row in tqdm.tqdm(df.itertuples()):
            row_data = []
            row = row._asdict()
            img = BioImage(row[self.img_path_column])
            scenes = self._get_scenes(row, img)
            timepoints = self._get_timepoints(row, img)
            for scene in scenes:
                img.set_scene(scene)
                timepoints = self._get_timepoints(row, img)
                for timepoint in timepoints:
                    img_data.append(
                    row_data.append(
                        {
                            "img": img,
                            "dimension_order_out": "ZYX"[-self.spatial_dims :],
                            "dimension_order_out": "C" + "ZYX"[-self.spatial_dims :],
                            "C": row[self.channel_column],
                            "scene": scene,
                            "T": timepoint,
                            "original_path": row[self.img_path_column],
                            "resolution": row.get(self.resolution_column, 0),
                        }
                    )
        img_data.reverse()
            img_data.extend(row_data)
        return img_data

    def _metadata_to_str(self, metadata):
        return "_".join([] + [f"{k}={v}" for k, v in metadata.items()])

    def _ensure_channel_first(self, img):
        while len(img.shape) < self.spatial_dims + 1:
            img = np.expand_dims(img, 0)
        return img

    def create_metatensor(self, img, meta):

Review comment: why are we not doing all this anymore?

Reply: This part was just baking in some transforms in an easily-broken way. I think it's better to leave all of this to downstream transforms (for example, all of this is handled by the bioio image loader)

        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img.astype(float))
        if isinstance(img, MetaTensor):
            img.meta.update(meta)
            return img
        elif isinstance(img, torch.Tensor):
            return MetaTensor(
                img,
                meta=meta,
            )
        raise ValueError(f"Expected img to be MetaTensor or torch.Tensor, got {type(img)}")

    def is_batch(self, x):
        return isinstance(x, list) or len(x.shape) == self.spatial_dims + 2

    def _transform(self, index: int):
        img_data = self.img_data.pop()
        img = img_data.pop("img")
        original_path = img_data.pop("original_path")
        scene = img_data.pop("scene")
        img.set_scene(scene)

        if self.dask_load:
            data_i = img.get_image_dask_data(**img_data).compute()
        else:
            data_i = img.get_image_data(**img_data)
        # add scene and path information back to metadata
        img_data["scene"] = scene
        img_data["original_path"] = original_path
        data_i = self._ensure_channel_first(data_i)
        data_i = self.create_metatensor(data_i, img_data)

        output_img = (
            apply_transform(self.transform, data_i) if self.transform is not None else data_i
        )
        # some monai transforms return a batch. When collated, the batch dimension gets moved to the channel dimension
        if self.is_batch(output_img):
            return [{self.out_key: img} for img in output_img]
        return {self.out_key: output_img}

    def __len__(self):
        return len(self.img_data)

def make_multidim_image_dataloader(
    csv_path: Optional[Union[Path, str]] = None,
    img_path_column: str = "path",
    channel_column: str = "channel",
    out_key: str = "image",
    spatial_dims: int = 3,
    scene_column: str = "scene",
    time_start_column: str = "start",
    time_stop_column: str = "stop",
    time_step_column: str = "step",
    dict_meta: Optional[Dict] = None,
    transforms: Optional[Union[List[Callable], Tuple[Callable], ListConfig]] = None,
    **dataloader_kwargs,
) -> DataLoader:
    """Function to create a MultiDimImage DataLoader. Currently, this dataset is only useful during
    prediction and cannot be used for training or testing.

    Parameters
    ----------
    csv_path: Optional[Union[Path, str]]
        path to csv
    img_path_column: str
        column in `csv_path` that contains path to multi dimensional (timelapse or multi-scene) file
    channel_column: str
        Column in `csv_path` that contains which channel to extract from multi dim image file. Should be an integer.
    out_key: str
        Key where single-scene/timepoint/channel is saved in output dictionary
    spatial_dims: int
        Spatial dimension of output image. Must be 2 for YX or 3 for ZYX
    scene_column: str
        Column in `csv_path` that contains scenes to extract from multiscene file. If not specified, all scenes will
        be extracted. If multiple scenes are specified, they should be separated by a comma (e.g. `scene1,scene2`)
    time_start_column: str
        Column in `csv_path` specifying which timepoint in timelapse image to start extracting. If any of `start_column`, `stop_column`, or `step_column`
        are not specified, all timepoints are extracted.
    time_stop_column: str
        Column in `csv_path` specifying which timepoint in timelapse image to stop extracting. If any of `start_column`, `stop_column`, or `step_column`
        are not specified, all timepoints are extracted.
    time_step_column: str
        Column in `csv_path` specifying step between timepoints. For example, values in this column should be `2` if every other timepoint should be run.
        If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted.
    dict_meta: Optional[Dict]
        Dictionary version of CSV file. If not provided, CSV file is read from `csv_path`.
    transforms: Optional[Union[List[Callable], Tuple[Callable], ListConfig]]
        Callable or list of callables that accept a numpy array. For example, image normalization functions could be passed here. Dataloading is already handled by the dataset.

    Returns
    -------
    DataLoader
        The DataLoader object for the MultiDimImage dataset.
    """
    if isinstance(transforms, (list, tuple, ListConfig)):
        transforms = Compose(transforms)
    dataset = MultiDimImageDataset(
        csv_path,
        img_path_column,
        channel_column,
        out_key,
        spatial_dims,
        scene_column=scene_column,
        time_start_column=time_start_column,
        time_stop_column=time_stop_column,
        time_step_column=time_step_column,
        dict_meta=dict_meta,
        transform=transforms,
    )
    # currently only supports 0 or 1 workers
    num_workers = min(dataloader_kwargs.pop("num_workers"), 1)
    return DataLoader(dataset, num_workers=num_workers, **dataloader_kwargs)
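
A minimal, hypothetical sketch of how the reworked dataset is meant to be driven after this change: the dataset now only enumerates per-scene/per-timepoint metadata, and the first transform (BioIOImageLoaderd) does the actual image reading. The module paths, file path, and column values below are illustrative assumptions, not something this diff defines.

# Sketch only -- module paths and the image path are assumed, not taken from the PR.
from omegaconf import OmegaConf

from cyto_dl.datamodules.multidim_image import MultiDimImageDataset  # module path assumed
from cyto_dl.image.io import BioIOImageLoaderd  # module path assumed

# dict_meta mirrors one CSV row: a multi-scene/timelapse file plus the channel,
# scenes, and timepoint range to enumerate (column names match the defaults above).
meta = OmegaConf.create(
    {
        "path": ["/data/example_timelapse.ome.tiff"],  # hypothetical file
        "channel": [0],
        "scene": ["scene1,scene2"],  # comma-separated, per the docstring
        "start": [0],
        "stop": [10],
        "step": [2],
    }
)

dataset = MultiDimImageDataset(
    dict_meta=meta,
    spatial_dims=3,
    # the loader consumes the metadata dicts emitted by get_per_file_args
    transform=[BioIOImageLoaderd(path_key="original_path", out_key="raw")],
    cache_num=0,  # skip the CacheDataset caching mechanism
)
item = dataset[0]  # e.g. {"raw": MetaTensor in CZYX order, plus the metadata keys}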
@@ -1,28 +1,33 @@
import re
from typing import List

import numpy as np
from bioio import BioImage
from monai.data import MetaTensor
from monai.transforms import Transform

from cyto_dl.utils.arg_checking import get_dtype


class AICSImageLoaderd(Transform):


class BioIOImageLoaderd(Transform):
    """Enumerates scenes and timepoints for dictionary with format.

    {path_key: path, channel_key: channel, scene_key: scene, timepoint_key: timepoint}. Differs
    from monai_bio_reader in that reading kwargs are passed in the dictionary, instead of fixed at
    initialization.
    {path_key: path, channel_key: channel, scene_key: scene, timepoint_key: timepoint}.
    Differs from monai_bio_reader in that reading kwargs are passed in the dictionary, instead of fixed at
    initialization. The filepath will be saved in the dictionary as 'filename_or_obj' (with or without metadata depending on `include_meta_in_filename`).
    """

    def __init__(
        self,
        path_key: str = "path",
        scene_key: str = "scene",
        kwargs_keys: List = ["dimension_order_out", "C", "T"],
        resolution_key: str = "resolution",
        kwargs_keys: List[str] = ["dimension_order_out", "C", "T"],
        out_key: str = "raw",
        allow_missing_keys=False,
        dtype: np.dtype = np.float16,
        dask_load: bool = True,
        include_meta_in_filename: bool = False,
    ):
        """
        Parameters
@@ -37,23 +42,36 @@ def __init__(
            Key for the output image
        allow_missing_keys : bool = False
            Whether to allow missing keys in the data dictionary
        dtype : np.dtype = np.float16
            Data type to cast the image to
        dask_load: bool = True
            Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints.
        include_meta_in_filename: bool = False
            Whether to include metadata in the filename. Useful when loading multi-dimensional images with different kwargs.
        """
        super().__init__()
        self.path_key = path_key
        self.kwargs_keys = kwargs_keys
        self.allow_missing_keys = allow_missing_keys
        self.out_key = out_key
        self.resolution_key = resolution_key
        self.scene_key = scene_key
        self.dtype = dtype
        self.dtype = get_dtype(dtype)
        self.dask_load = dask_load
        self.include_meta_in_filename = include_meta_in_filename

    def split_args(self, arg):
        if "," in str(arg):
        if isinstance(arg, str) and "," in arg:
            return list(map(int, arg.split(",")))
        return arg

    def _get_filename(self, path, kwargs):
        if self.include_meta_in_filename:
            path = path.split(".")[0] + "_" + "_".join([f"{k}_{v}" for k, v in kwargs.items()])
        # remove illegal characters from filename
        path = re.sub(r'[<>:"|?*]', "", path)
        return path

    def __call__(self, data):
        # copying prevents the dataset from being modified inplace - important when using partially cached datasets so that the memory use doesn't increase over time
        data = data.copy()

@@ -63,12 +81,16 @@ def __call__(self, data):
        img = BioImage(path)
        if self.scene_key in data:
            img.set_scene(data[self.scene_key])
        if self.resolution_key in data:
            img.set_resolution_level(data[self.resolution_key])
        kwargs = {k: self.split_args(data[k]) for k in self.kwargs_keys if k in data}
        if self.dask_load:
            img = img.get_image_dask_data(**kwargs).compute()
        else:
            img = img.get_image_data(**kwargs)
        img = img.astype(self.dtype)
        data[self.out_key] = MetaTensor(img, meta={"filename_or_obj": path, "kwargs": kwargs})

        kwargs.update({"filename_or_obj": self._get_filename(path, kwargs)})

Review comment: filename_or_obj is a monai thing right?

Reply: yes - we also use it for image saving

        if self.scene_key in data:
            kwargs["scene"] = data[self.scene_key]
        data[self.out_key] = MetaTensor(img, meta=kwargs)
        return data
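
A hedged sketch of the dictionary BioIOImageLoaderd consumes and what it attaches to the output MetaTensor, based only on the keys visible in this diff; the module path and the image path are assumptions.

# Sketch only -- module path and file path are assumed.
import numpy as np

from cyto_dl.image.io import BioIOImageLoaderd  # module path assumed

loader = BioIOImageLoaderd(
    path_key="original_path",
    out_key="raw",
    dtype=np.float16,  # a string such as "numpy.float16" also works, resolved via get_dtype
    include_meta_in_filename=True,
)

data = {
    "original_path": "/data/example.ome.tiff",  # hypothetical multi-scene file
    "scene": "Image:0",
    "resolution": 0,
    "dimension_order_out": "CZYX",
    "C": 0,
    "T": 3,
}
out = loader(data)
# out["raw"] is a MetaTensor cast to float16; its meta carries the reading kwargs,
# the scene, and a "filename_or_obj" like "/data/example_dimension_order_out_CZYX_C_0_T_3"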
@@ -0,0 +1,13 @@
from hydra.utils import get_class
from numpy.typing import DTypeLike


def get_dtype(dtype: DTypeLike) -> DTypeLike:
    if isinstance(dtype, str):
        return get_class(dtype)
    elif dtype is None:
        return dtype
    elif isinstance(dtype, type):
        return dtype
    else:
        raise ValueError(f"Expected dtype to be DtypeLike, string, or None, got {type(dtype)}")
Can you explain how this is different from DataframeDatamodule?
It would be really useful to have an example function call or config for each dataset/datamodule
Agreed on the config. This is standardizing the creation of a dataframe from a multi-scene/multi-timepoint image
So you're saying DataframeDatamodule does not work for multi-scene/multi-timepoint?
It does if you enumerate all the scenes and timepoints for each multidim image in their own rows. Here, each row is just the multidim image path and the scenes/channels/timepoints you want to use
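
To make that last point concrete, here is a hypothetical comparison of the two table shapes: one MultiDimImageDataset row per multi-dimensional file versus the fully enumerated per-scene/per-timepoint rows a DataframeDatamodule-style table would need. Paths, scene names, and the half-open start/stop convention are illustrative assumptions.

# Sketch only -- values are made up for illustration.
import pandas as pd

# MultiDimImageDataset: one row per multi-dimensional file; scenes and timepoints
# are expanded internally by get_per_file_args.
multidim_rows = pd.DataFrame(
    {
        "path": ["/data/movie.ome.tiff"],
        "channel": [0],
        "scene": ["A,B"],
        "start": [0],
        "stop": [2],
        "step": [1],
    }
)

# DataframeDatamodule-style equivalent: every scene/timepoint combination is
# already enumerated, one image per row (assuming a half-open start/stop range).
enumerated_rows = pd.DataFrame(
    {
        "path": ["/data/movie.ome.tiff"] * 4,
        "channel": [0] * 4,
        "scene": ["A", "A", "B", "B"],
        "T": [0, 1, 0, 1],
    }
)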