Skip to content

Commit

Permalink
Merge pull request #30 from Kainmueller-Lab/arctique
Browse files Browse the repository at this point in the history
Arctique
  • Loading branch information
ClaudiaWinklmayr authored Jan 20, 2025
2 parents 7b1e6ff + 641d8c8 commit 9fe04cd
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 0 deletions.
130 changes: 130 additions & 0 deletions src/bio_image_datasets/arctique_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import os

import numpy as np
import skimage
import skimage.io
from skimage.measure import label as relabel

from bio_image_datasets.dataset import Dataset


# Cell class IDs and their names, taken from https://zenodo.org/records/14016860.
# Returned unchanged by ArctiqueDataset.get_class_mapping().
mapping_dict = {
    0: "Background",
    1: "Epithelial",
    2: "Plasma Cells",
    3: "Lymphocytes",
    4: "Eosinophils",
    5: "Fibroblasts",
}



class ArctiqueDataset(Dataset):
    """Dataset wrapper for the Arctique data (HE images plus cell masks).

    Expected directory layout below ``local_path``::

        images/img_<ID>.png
        masks/semantic/<ID>.png
        masks/instance/<ID>.png

    Samples are addressed by position in ``sample_IDs``, which is sorted by
    ascending numeric ID so that indexing is reproducible across machines.
    """

    def __init__(self, local_path):
        """
        Initializes the ArctiqueDataset with the given local path.
        The dataset is located on the /fast file system on the MDC cluster under the path
        '/fast/AG_Kainmueller/data/patho_foundation_model_bench_data/arctique_dataset/arctique'.
        Args:
            local_path (str): Path to the directory containing the files.
        """
        super().__init__(local_path)

        print("LOCAL PATH", local_path)

        self.images_folder = os.path.join(local_path, 'images/')
        self.semantic_masks_folder = os.path.join(local_path, 'masks/semantic')
        self.instance_masks_folder = os.path.join(local_path, 'masks/instance')
        # os.listdir returns entries in arbitrary order and may contain
        # unrelated files; keep only "img_<ID>.png" entries and sort the IDs
        # so that a given index always refers to the same sample.
        self.sample_IDs = sorted(
            int(name.split("_")[1].split(".")[0])
            for name in os.listdir(self.images_folder)
            if name.startswith("img_") and name.endswith(".png")
        )

    def _image_path(self, sample_ID):
        """Return the path of the HE image file for ``sample_ID``."""
        return os.path.join(self.images_folder, f"img_{sample_ID}.png")

    def _semantic_mask_path(self, sample_ID):
        """Return the path of the semantic mask file for ``sample_ID``."""
        return os.path.join(self.semantic_masks_folder, f"{sample_ID}.png")

    def _instance_mask_path(self, sample_ID):
        """Return the path of the instance mask file for ``sample_ID``."""
        return os.path.join(self.instance_masks_folder, f"{sample_ID}.png")

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.sample_IDs)

    def __getitem__(self, idx):
        """Return a sample as a dictionary at the given index.
        Args:
            idx (int): Index of the sample.
        Returns:
            dict: A dictionary containing the following keys:
                - "image": Hematoxylin and eosin (HE) image, exactly as read
                  from disk (not transposed, unlike ``get_he``)
                - "semantic_mask": Ground truth semantic mask
                - "instance_mask": Ground truth instance mask
                - "sample_name": Numeric ID of the sample (int)
        Raises:
            IndexError: If ``idx`` is not smaller than the dataset length.
        """
        if idx >= len(self):
            raise IndexError("Index out of bounds.")

        sample_ID = self.sample_IDs[idx]

        data = {
            "image": skimage.io.imread(self._image_path(sample_ID)),
            "semantic_mask": skimage.io.imread(self._semantic_mask_path(sample_ID)),
            "instance_mask": skimage.io.imread(self._instance_mask_path(sample_ID)),
            'sample_name': sample_ID,
        }
        return data

    def get_he(self, idx):
        """
        Load the hematoxylin and eosin (HE) image for the given index.
        Args:
            idx (int): Index of the sample.
        Returns:
            np.ndarray: The HE image with its last axis moved to the front.
        """
        sample_ID = self.sample_IDs[idx]
        img = skimage.io.imread(self._image_path(sample_ID))
        # A bare ``transpose()`` reverses *all* axes and would also swap the
        # two spatial axes; move only the last (channel) axis to the front.
        # For a 2-D array this is identical to the previous ``transpose()``.
        return np.moveaxis(img, -1, 0)

    def get_class_mapping(self):
        """Return the class mapping for the dataset.
        Returns:
            dict: A dictionary mapping class indices to class names.
        """
        return mapping_dict

    def get_instance_mask(self, idx):
        """Return the instance mask at the given index.
        Args:
            idx (int): Index of the sample.
        Returns:
            np.ndarray: The instance mask.
        """
        sample_ID = self.sample_IDs[idx]
        return skimage.io.imread(self._instance_mask_path(sample_ID))

    def get_semantic_mask(self, idx):
        """Return the semantic mask at the given index.
        Args:
            idx (int): Index of the sample.
        Returns:
            np.ndarray: The semantic mask.
        """
        sample_ID = self.sample_IDs[idx]
        return skimage.io.imread(self._semantic_mask_path(sample_ID))

    def get_sample_name(self, idx):
        """Return the sample name for the given index.
        Args:
            idx (int): Index of the sample.
        Returns:
            int: The numeric sample ID parsed from the image file name
                ``img_<ID>.png``.
        """
        return self.sample_IDs[idx]

    def get_sample_names(self):
        """Return the list of all sample names (numeric IDs, ascending)."""
        return self.sample_IDs

    def __repr__(self):
        """Return the string representation of the dataset."""
        # NOTE(review): ``self.local_path`` is not assigned in this class; it
        # is presumably set by ``Dataset.__init__`` -- confirm.
        return f"Arctique Dataset ({self.local_path}, {len(self)} samples)"
109 changes: 109 additions & 0 deletions tests/test_arctique_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from bio_image_datasets.arctique_dataset import ArctiqueDataset, mapping_dict
import os
import numpy as np
import tempfile
from PIL import Image


def prepare_arctique_samples(output_dir, num_samples=5):
    """
    Creates mock files adhering to the Arctique dataset specifications.

    For every sample index ``i`` three PNG files are written below
    ``output_dir``: ``images/img_<i>.png`` (a solid red 512x512 RGB image),
    ``masks/instance/<i>.png`` (random uint16 values in [0, 100)) and
    ``masks/semantic/<i>.png`` (random uint16 values in [0, 6)).

    Args:
        output_dir (str): Path to the directory where the mock files will be created.
        num_samples (int): Number of mock samples to create.
    Returns:
        list: List of paths to the created mock files, in creation order
            (image, instance mask, semantic mask for each sample).
    """
    os.makedirs(output_dir, exist_ok=True)
    for folder_name in ["images", "masks/semantic", "masks/instance"]:
        os.makedirs(os.path.join(output_dir, folder_name), exist_ok=True)

    created_paths = []
    for i in range(num_samples):
        img_path = os.path.join(output_dir, "images", f"img_{i}.png")
        Image.new('RGB', (512, 512), color=(255, 0, 0)).save(img_path)
        created_paths.append(img_path)

        mask_name = f"{i}.png"

        instances = np.random.randint(0, 100, size=(512, 512), dtype=np.uint16)
        instance_path = os.path.join(output_dir, "masks/instance", mask_name)
        Image.fromarray(instances).save(instance_path)
        created_paths.append(instance_path)

        classes = np.random.randint(0, 6, size=(512, 512), dtype=np.uint16)
        semantic_path = os.path.join(output_dir, "masks/semantic", mask_name)
        Image.fromarray(classes).save(semantic_path)
        created_paths.append(semantic_path)

    # The docstring always promised the list of created files; return it.
    return created_paths

def test_len():
    """A dataset built from five mock samples reports a length of five."""
    with tempfile.TemporaryDirectory() as root:
        prepare_arctique_samples(root, num_samples=5)
        arctique = ArctiqueDataset(local_path=root)
        assert len(arctique) == 5


def test_getitem():
    """Indexing the dataset yields a dict with all four expected keys."""
    with tempfile.TemporaryDirectory() as root:
        prepare_arctique_samples(root)
        first = ArctiqueDataset(local_path=root)[0]
        for key in ("image", "semantic_mask", "instance_mask", "sample_name"):
            assert key in first


def test_get_he():
    """get_he returns the mock RGB image with the channel axis first."""
    with tempfile.TemporaryDirectory() as root:
        prepare_arctique_samples(root)
        arctique = ArctiqueDataset(local_path=root)
        expected_shape = (3, 512, 512)
        assert arctique.get_he(0).shape == expected_shape


def test_get_class_mapping():
    """get_class_mapping returns the module-level mapping_dict."""
    with tempfile.TemporaryDirectory() as root:
        prepare_arctique_samples(root)
        arctique = ArctiqueDataset(local_path=root)
        assert arctique.get_class_mapping() == mapping_dict


def test_get_instance_mask():
    """get_instance_mask returns a 2-D array matching the mock mask size."""
    with tempfile.TemporaryDirectory() as root:
        prepare_arctique_samples(root)
        arctique = ArctiqueDataset(local_path=root)
        instance_mask = arctique.get_instance_mask(0)
        assert instance_mask.shape == (512, 512)


def test_get_semantic_mask():
    """get_semantic_mask returns a 2-D array matching the mock mask size."""
    with tempfile.TemporaryDirectory() as root:
        prepare_arctique_samples(root)
        arctique = ArctiqueDataset(local_path=root)
        mask = arctique.get_semantic_mask(0)
        assert mask.shape == (512, 512)


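# NOTE: this test is disabled. Its original assertion compared a type object
# with the string "int", which is never true; isinstance() is the intended check.
# def test_get_sample_name():
#     with tempfile.TemporaryDirectory() as tmp_dir:
#         local_path = os.path.join(tmp_dir)
#         prepare_arctique_samples(local_path)
#         dataset = ArctiqueDataset(local_path=local_path)
#         sample_name = dataset.get_sample_name(0)
#         assert isinstance(sample_name, int)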


def test_get_sample_names():
    """get_sample_names lists one entry per mock sample (default is five)."""
    with tempfile.TemporaryDirectory() as root:
        prepare_arctique_samples(root)
        names = ArctiqueDataset(local_path=root).get_sample_names()
        assert len(names) == 5

0 comments on commit 9fe04cd

Please sign in to comment.