Skip to content

Commit

Permalink
Merge pull request #27 from Kainmueller-Lab/pannuke
Browse files Browse the repository at this point in the history
change get name in pannuke
  • Loading branch information
JLrumberger authored Dec 9, 2024
2 parents a60d7f7 + be94eb6 commit 65789c6
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 3 deletions.
179 changes: 179 additions & 0 deletions src/bio_image_datasets/pannuke_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import os
import numpy as np
from bio_image_datasets.dataset import Dataset
from skimage.measure import label as relabel

mapping_dict = {
0: "Background",
1: "Epithelial",
2: "Dead Cells",
3: "Connective/Soft tissue cells",
4: "Inflammatory",
5: "Neoplastic cells",
}


class PanNukeDataset(Dataset):
def __init__(self, local_path):
"""
Initializes the PanNukeDataset with the given local path and optional transform.
The three PanNuke datasets are located on the /fast file system on the MDC cluster under the paths
'/fast/AG_Kainmueller/data/pannuke/fold1', '/fast/AG_Kainmueller/data/pannuke/fold2', and '/fast/AG_Kainmueller/data/pannuke/fold3'.
Args:
local_path (str): Path to the directory containing the files.
"""
# get fold number from local_path
if 'fold1' in local_path:
self.fold = 1
elif 'fold2' in local_path:
self.fold = 2
elif 'fold3' in local_path:
self.fold = 3
else:
raise ValueError("Invalid local path.")

super().__init__(local_path)
self.images_file = os.path.join(local_path, f'images/fold{self.fold}/images.npy')
self.types_file = os.path.join(local_path, f'images/fold{self.fold}/types.npy')
self.masks_file = os.path.join(local_path, f'masks/fold{self.fold}/masks.npy')

if not self.images_file:
raise ValueError("No images file found in the specified directory.")

if not self.types_file:
raise ValueError("No types file found in the specified directory.")

if not self.masks_file:
raise ValueError("No masks file found in the specified directory.")

# put last channel in images to first
self.images = np.load(self.images_file)
self.images = np.moveaxis(self.images, -1, 1)
self.types = np.load(self.types_file)
self.masks = np.load(self.masks_file)

self.semantic_masks = self.prepare_semantic_masks(self.masks)
self.instance_masks = self.prepare_instance_masks(self.masks)


def __len__(self):
"""Return the number of samples in the dataset."""
return np.shape(self.types)[0]

def __getitem__(self, idx):
"""Return a sample as a dictionary at the given index.
Args:
idx (int): Index of the sample.
Returns:
dict: A dictionary containing the following keys:
- "image": Hematoxylin and eosin (HE) image
- "type": Tissue type where the sample comes from
- "semantic_mask": Ground truth semantic mask
- "instance_mask": Ground truth instance mask
- "sample_name": Index of the sample as string
"""
if idx >= len(self):
raise IndexError("Index out of bounds.")

data = {
"image": self.images[idx],
"type": self.types[idx],
"semantic_mask": self.semantic_masks[idx],
"instance_mask": self.instance_masks[idx],
'sample_name': self.get_sample_name(idx)
}
return data

def get_he(self, idx):
"""
Load the hematoxylin and eosin (HE) image for the given index.
Args:
idx (int): Index of the sample.
Returns:
np.ndarray: The HE image.
"""
return self.images[idx]

def get_tissue_type(self, idx):
"""Load tissue type.
Args:
idx (int): Index of the sample.
Returns:
tissue type (string): The tissue type.
"""
return self.types[idx]

def get_class_mapping(self):
"""Return the class mapping for the dataset.
Returns:
dict: A dictionary mapping class indices to class names.
"""
return mapping_dict

def get_instance_mask(self, idx):
"""Return the instance mask at the given index.
Args:
idx (int): Index of the sample.
Returns:
np.ndarray: The instance mask.
"""
return self.instance_masks[idx]

def get_semantic_mask(self, idx):
"""Return the semantic mask at the given index.
Args:
idx (int): Index of the sample.
Returns:
np.ndarray: The semantic mask.
"""
return self.semantic_masks[idx]


def get_sample_name(self, idx):
"""Return the sample name for the given index.
Args:
idx (int): Index of the sample.
Returns:
str: The sample name.
"""
return f"fold{self.fold}_{str(idx)}"

def get_sample_names(self):
"""Return the list of all sample names."""
return [f"fold{self.fold}_{str(i)}" for i in range(int(len(self)))]

def __repr__(self):
"""Return the string representation of the dataset."""
return f"PanNuke Dataset ({self.local_path}, {len(self)} samples)"

def prepare_semantic_masks(self, masks):
""" Prepare the semantic segmentation mask based on the pannuke mask which contains each semantic class
in an individual channel which need to be collapsed via argmax
Args:
masks (np.array): B x H x W x C
"""
# reverse order of last dim of masks to adhere to class mapping specified in mapping_dict
masks = np.flip(masks, axis=-1)
semantic_masks = np.argmax(masks, -1)
return semantic_masks

def prepare_instance_masks(self, masks):
""" Prepare the instance segmentation mask based on the pannuke mask which contains each semantic class
in an individual channel which need to be collapsed via max
Args:
masks (np.array): B x H x W x C
"""
# reverse order of last dim of masks to adhere to class mapping specified in mapping_dict
instance_masks = np.max(masks[..., :-1], -1)
# iterate over first dimension and use relabel on each individual tile
for i in range(instance_masks.shape[0]):
instance_masks[i] = relabel(instance_masks[i], background=0)
return instance_masks
6 changes: 3 additions & 3 deletions tests/test_pannuke_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from bio_image_datasets.pannuke import PanNukeDataset, mapping_dict
from bio_image_datasets.pannuke_dataset import PanNukeDataset, mapping_dict
import os
import numpy as np
import tempfile
Expand Down Expand Up @@ -112,12 +112,12 @@ def test_get_sample_name():
file_paths = prepare_pannuke_samples(local_path, num_samples=5)
dataset = PanNukeDataset(local_path=local_path)
sample_name = dataset.get_sample_name(0)
assert sample_name == "0"
assert sample_name == "fold1_0"

def test_get_sample_names():
with tempfile.TemporaryDirectory() as tmp_dir:
local_path = os.path.join(tmp_dir, 'fold1')
file_paths = prepare_pannuke_samples(local_path, num_samples=5)
dataset = PanNukeDataset(local_path=local_path)
sample_names = dataset.get_sample_names()
assert sample_names == ["0", "1", "2", "3", "4"]
assert sample_names == ["fold1_0", "fold1_1", "fold1_2", "fold1_3", "fold1_4"]

0 comments on commit 65789c6

Please sign in to comment.