-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #30 from Kainmueller-Lab/arctique
Arctique
- Loading branch information
Showing
2 changed files
with
239 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import os | ||
import numpy as np | ||
from bio_image_datasets.dataset import Dataset | ||
import skimage | ||
from skimage.measure import label as relabel | ||
|
||
|
||
# Cell class IDs as published with the dataset: https://zenodo.org/records/14016860
# The position of each name in the tuple is its integer class ID.
mapping_dict = dict(
    enumerate(
        (
            "Background",
            "Epithelial",
            "Plasma Cells",
            "Lymphocytes",
            "Eosinophils",
            "Fibroblasts",
        )
    )
)
|
||
|
||
|
||
class ArctiqueDataset(Dataset):
    """Access Arctique images together with their semantic and instance masks.

    The loader expects the following layout below ``local_path``::

        images/img_<ID>.png
        masks/semantic/<ID>.png
        masks/instance/<ID>.png

    where ``<ID>`` is the integer sample ID parsed from the image file name.
    """

    def __init__(self, local_path):
        """
        Initializes the ArctiqueDataset with the given local path.
        The dataset is located on the /fast file system on the MDC cluster under the path
        '/fast/AG_Kainmueller/data/patho_foundation_model_bench_data/arctique_dataset/arctique'.
        Args:
            local_path (str): Path to the directory containing the files.
        """
        super().__init__(local_path)

        self.images_folder = os.path.join(local_path, 'images/')
        self.semantic_masks_folder = os.path.join(local_path, 'masks/semantic')
        self.instance_masks_folder = os.path.join(local_path, 'masks/instance')

        # os.listdir() returns entries in arbitrary order and may include
        # unrelated files; keep only "img_<ID>.png" entries and sort the IDs so
        # that the index -> sample mapping is reproducible across runs.
        parsed_IDs = (self._parse_sample_ID(name) for name in os.listdir(self.images_folder))
        self.sample_IDs = sorted(sample_ID for sample_ID in parsed_IDs if sample_ID is not None)

    @staticmethod
    def _parse_sample_ID(file_name):
        """Return the integer ID of an ``img_<ID>.png`` file name, else None."""
        stem, extension = os.path.splitext(file_name)
        prefix, _, number = stem.partition("_")
        if extension != ".png" or prefix != "img" or not number.isdecimal():
            return None
        return int(number)

    @staticmethod
    def _read_png(folder, file_name):
        """Read ``file_name`` from ``folder`` with skimage and return the array."""
        return skimage.io.imread(os.path.join(folder, file_name))

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.sample_IDs)

    def __getitem__(self, idx):
        """Return a sample as a dictionary at the given index.
        Args:
            idx (int): Index of the sample.
        Returns:
            dict: A dictionary containing the following keys:
                - "image": Hematoxylin and eosin (HE) image exactly as read
                  from disk (not transposed, unlike ``get_he``)
                - "semantic_mask": Ground truth semantic mask
                - "instance_mask": Ground truth instance mask
                - "sample_name": Integer sample ID
        Raises:
            IndexError: If ``idx`` is past the end of the dataset.
        """
        if idx >= len(self):
            raise IndexError("Index out of bounds.")

        sample_ID = self.sample_IDs[idx]
        mask_name = f"{sample_ID}.png"

        return {
            "image": self._read_png(self.images_folder, f"img_{sample_ID}.png"),
            "semantic_mask": self._read_png(self.semantic_masks_folder, mask_name),
            "instance_mask": self._read_png(self.instance_masks_folder, mask_name),
            'sample_name': sample_ID,
        }

    def get_he(self, idx):
        """
        Load the hematoxylin and eosin (HE) image for the given index.
        Args:
            idx (int): Index of the sample.
        Returns:
            np.ndarray: The HE image with the channel axis first, i.e.
                (channels, height, width) for a (height, width, channels) file.
        """
        sample_ID = self.sample_IDs[idx]
        img = self._read_png(self.images_folder, f"img_{sample_ID}.png")
        if img.ndim == 3:
            # Move only the channel axis to the front. A bare transpose()
            # reverses all axes and would also swap height and width, leaving
            # the image misaligned with its (height, width) masks.
            return img.transpose(2, 0, 1)
        # An image without a channel axis has nothing to move.
        return img

    def get_class_mapping(self):
        """Return the class mapping for the dataset.
        Returns:
            dict: A dictionary mapping class indices to class names.
        """
        return mapping_dict

    def get_instance_mask(self, idx):
        """Return the instance mask at the given index.
        Args:
            idx (int): Index of the sample.
        Returns:
            np.ndarray: The instance mask.
        """
        sample_ID = self.sample_IDs[idx]
        return self._read_png(self.instance_masks_folder, f"{sample_ID}.png")

    def get_semantic_mask(self, idx):
        """Return the semantic mask at the given index.
        Args:
            idx (int): Index of the sample.
        Returns:
            np.ndarray: The semantic mask.
        """
        sample_ID = self.sample_IDs[idx]
        return self._read_png(self.semantic_masks_folder, f"{sample_ID}.png")

    def get_sample_name(self, idx):
        """Return the sample name for the given index.
        Args:
            idx (int): Index of the sample.
        Returns:
            int: The numeric sample ID parsed from the image file name,
                e.g. 7 for ``img_7.png``.
        """
        return self.sample_IDs[idx]

    def get_sample_names(self):
        """Return the list of all sample names (integer IDs, ascending)."""
        return self.sample_IDs

    def __repr__(self):
        """Return the string representation of the dataset."""
        return f"Arctique Dataset ({self.local_path}, {len(self)} samples)"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from bio_image_datasets.arctique_dataset import ArctiqueDataset, mapping_dict | ||
import os | ||
import numpy as np | ||
import tempfile | ||
from PIL import Image | ||
|
||
|
||
def prepare_arctique_samples(output_dir, num_samples=5):
    """
    Creates mock files adhering to the Arctique dataset specifications.

    For every sample ``i`` three PNG files are written:
    ``images/img_<i>.png``, ``masks/instance/<i>.png`` and
    ``masks/semantic/<i>.png``.

    Args:
        output_dir (str): Path to the directory where the mock files will be created.
        num_samples (int): Number of mock samples to create.
    Returns:
        list: List of paths to the created mock files.
    """
    # makedirs creates missing parents, so output_dir itself is covered too.
    for folder_name in ["images", "masks/semantic", "masks/instance"]:
        os.makedirs(os.path.join(output_dir, folder_name), exist_ok=True)

    created_paths = []
    for i in range(num_samples):
        img_path = os.path.join(output_dir, "images", f"img_{i}.png")
        Image.new('RGB', (512, 512), color=(255, 0, 0)).save(img_path)
        created_paths.append(img_path)

        mask_name = f"{i}.png"

        instances = np.random.randint(0, 100, size=(512, 512), dtype=np.uint16)
        instance_path = os.path.join(output_dir, "masks/instance", mask_name)
        Image.fromarray(instances).save(instance_path)
        created_paths.append(instance_path)

        # Upper bound is exclusive: class IDs 0..5.
        classes = np.random.randint(0, 6, size=(512, 512), dtype=np.uint16)
        semantic_path = os.path.join(output_dir, "masks/semantic", mask_name)
        Image.fromarray(classes).save(semantic_path)
        created_paths.append(semantic_path)

    return created_paths
|
||
def test_len():
    """len(dataset) equals the number of mock samples written to disk."""
    expected_count = 5
    with tempfile.TemporaryDirectory() as root_dir:
        prepare_arctique_samples(root_dir, num_samples=expected_count)
        arctique = ArctiqueDataset(local_path=root_dir)
        assert len(arctique) == expected_count
|
||
|
||
def test_getitem():
    """Indexing the dataset yields a dict exposing all four expected keys."""
    with tempfile.TemporaryDirectory() as root_dir:
        prepare_arctique_samples(root_dir)
        arctique = ArctiqueDataset(local_path=root_dir)
        first_sample = arctique[0]
        for expected_key in ("image", "semantic_mask", "instance_mask", "sample_name"):
            assert expected_key in first_sample
|
||
|
||
def test_get_he():
    """get_he returns the 512x512 RGB mock image with the channel axis first."""
    with tempfile.TemporaryDirectory() as root_dir:
        prepare_arctique_samples(root_dir)
        arctique = ArctiqueDataset(local_path=root_dir)
        assert arctique.get_he(0).shape == (3, 512, 512)
|
||
|
||
def test_get_class_mapping():
    """get_class_mapping returns the module-level mapping_dict."""
    with tempfile.TemporaryDirectory() as root_dir:
        prepare_arctique_samples(root_dir)
        arctique = ArctiqueDataset(local_path=root_dir)
        assert arctique.get_class_mapping() == mapping_dict
|
||
|
||
def test_get_instance_mask():
    """get_instance_mask returns a 2-D mask with the mock image's extent."""
    with tempfile.TemporaryDirectory() as root_dir:
        prepare_arctique_samples(root_dir)
        arctique = ArctiqueDataset(local_path=root_dir)
        instance_mask = arctique.get_instance_mask(0)
        assert instance_mask.shape == (512, 512)
|
||
|
||
def test_get_semantic_mask():
    """get_semantic_mask returns a 2-D mask with the mock image's extent."""
    with tempfile.TemporaryDirectory() as root_dir:
        prepare_arctique_samples(root_dir)
        arctique = ArctiqueDataset(local_path=root_dir)
        assert arctique.get_semantic_mask(0).shape == (512, 512)
|
||
|
||
# def test_get_sample_name(): | ||
# with tempfile.TemporaryDirectory() as tmp_dir: | ||
# local_path = os.path.join(tmp_dir) | ||
# prepare_arctique_samples(local_path) | ||
# dataset = ArctiqueDataset(local_path=local_path) | ||
# sample_name = dataset.get_sample_name(0) | ||
# assert isinstance(sample_name, int) | ||
|
||
|
||
def test_get_sample_names():
    """get_sample_names lists one entry per mock sample."""
    expected_count = 5
    with tempfile.TemporaryDirectory() as root_dir:
        prepare_arctique_samples(root_dir, num_samples=expected_count)
        arctique = ArctiqueDataset(local_path=root_dir)
        assert len(arctique.get_sample_names()) == expected_count