Merge pull request #4 from juglab/ms/feat/model_adapter
Added a model adapter as a way to use different models for feature extraction.

Former-commit-id: 83afea1
mese79 authored Jun 11, 2024
2 parents 254716f + 3e2a330 commit 689b031
Showing 16 changed files with 245 additions and 132 deletions.
20 changes: 0 additions & 20 deletions src/featureforest/SAM/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion src/featureforest/SAM/models/MobileSAM/__init__.py

This file was deleted.


28 changes: 0 additions & 28 deletions src/featureforest/SAM/setup_model.py

This file was deleted.

8 changes: 4 additions & 4 deletions src/featureforest/__init__.py
@@ -1,13 +1,13 @@
 __version__ = "0.0.2"

 from ._embedding_extractor_widget import EmbeddingExtractorWidget
-from ._sam_predictor_widget import SAMPredictorWidget
+# from ._sam_predictor_widget import SAMPredictorWidget
 from ._sam_rf_segmentation_widget import SAMRFSegmentationWidget
-from ._sam_prompt_segmentation_widget import SAMPromptSegmentationWidget
+# from ._sam_prompt_segmentation_widget import SAMPromptSegmentationWidget

 __all__ = (
     "EmbeddingExtractorWidget",
-    "SAMPredictorWidget",
+    # "SAMPredictorWidget",
     "SAMRFSegmentationWidget",
-    "SAMPromptSegmentationWidget"
+    # "SAMPromptSegmentationWidget"
 )
14 changes: 7 additions & 7 deletions src/featureforest/_embedding_extractor_widget.py
@@ -18,7 +18,7 @@
     ScrollWidgetWrapper,
     get_layer,
 )
-from . import SAM
+from .models import MobileSAM
 from .utils import (
     config
 )
@@ -27,7 +27,7 @@
     get_patch_size
 )
 from .utils.extract import (
-    get_sam_embeddings_for_slice
+    get_slice_features
 )


@@ -166,8 +166,8 @@ def extract_embeddings(self):
     def get_stack_sam_embeddings(
         self, image_layer, storage_path, patch_size, overlap
     ):
-        # initial sam model
-        sam_model, device = SAM.setup_mobile_sam_model()
+        # initial mobile-sam model
+        sam_model_adapter, device = MobileSAM.get_model(patch_size, overlap)
         # initial storage hdf5 file
         self.storage = h5py.File(storage_path, "w")
         # get sam embeddings slice by slice and save them into storage file
@@ -179,13 +179,13 @@
         self.storage.attrs["overlap"] = overlap

         for slice_index in np_progress(
-            range(num_slices), desc="get embeddings for slices"
+            range(num_slices), desc="extract features for slices"
         ):
             image = image_layer.data[slice_index] if num_slices > 1 else image_layer.data
             slice_grp = self.storage.create_group(str(slice_index))
-            get_sam_embeddings_for_slice(
+            get_slice_features(
                 image, patch_size, overlap,
-                sam_model.image_encoder, device, slice_grp
+                sam_model_adapter, device, slice_grp
             )

             yield (slice_index, num_slices)
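For orientation: the storage written above is a plain HDF5 file with one group per slice index and patch_size/overlap recorded as file attributes; the per-slice dataset name "sam" appears in the RF widget below. A minimal sketch of reading the features back, assuming that layout (the file name here is hypothetical):

import h5py

# a minimal sketch, assuming the layout written by the extractor above:
# one group per slice index, patch_size/overlap as file attributes,
# features stored under the "sam" dataset (file name is hypothetical)
with h5py.File("embeddings.h5", "r") as storage:
    patch_size = storage.attrs["patch_size"]
    overlap = storage.attrs["overlap"]
    slice_features = storage["0"]["sam"][:]  # N x target_size x target_size x C
    print(patch_size, overlap, slice_features.shape)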
18 changes: 9 additions & 9 deletions src/featureforest/_sam_rf_segmentation_widget.py
@@ -20,7 +20,7 @@
 import numpy as np
 from sklearn.ensemble import RandomForestClassifier

-from . import SAM
+from .models import MobileSAM
 from .widgets import (
     ScrollWidgetWrapper,
     get_layer,
@@ -48,16 +48,13 @@ def __init__(self, napari_viewer: napari.Viewer):
         self.storage = None
         self.rf_model = None
         self.device = None
-        self.sam_model = None
+        self.feature_model = None
         self.patch_size = 512
         self.overlap = 384
         self.stride = self.patch_size - self.overlap

         self.prepare_widget()

-        # init sam model & predictor
-        self.sam_model, self.device = self.get_model_on_device()
-
     def closeEvent(self, event):
         print("closing")
         self.viewer.layers.events.inserted.disconnect(self.check_input_layers)
@@ -365,6 +362,9 @@ def select_storage(self):
             "overlap", self.overlap)
         self.stride, _ = get_stride_margin(self.patch_size, self.overlap)

+        # init feature model
+        self.feature_model, self.device = self.get_model_on_device()
+
     def add_labels_layer(self):
         self.image_layer = get_layer(
             self.viewer,
@@ -410,7 +410,7 @@ def analyze_labels(self):
         self.each_class_label.setText("Labels per class:\n" + each_class)

     def get_model_on_device(self):
-        return SAM.setup_mobile_sam_model()
+        return MobileSAM.get_model(self.patch_size, self.overlap)

     def get_train_data(self):
         # get ground truth class labels
@@ -423,7 +423,7 @@

         num_slices, img_height, img_width = get_stack_dims(self.image_layer.data)
         num_labels = sum([len(v) for v in labels_dict.values()])
-        total_channels = SAM.ENCODER_OUT_CHANNELS + SAM.EMBED_PATCH_CHANNELS
+        total_channels = self.feature_model.get_total_output_channels()
         train_data = np.zeros((num_labels, total_channels))
         labels = np.zeros(num_labels, dtype="int32") - 1
         count = 0
@@ -602,7 +602,7 @@ def predict_slice(self, rf_model, slice_index, img_height, img_width):
         # shape: N x target_size x target_size x C
         feature_patches = self.storage[str(slice_index)]["sam"][:]
         num_patches = feature_patches.shape[0]
-        total_channels = SAM.ENCODER_OUT_CHANNELS + SAM.EMBED_PATCH_CHANNELS
+        total_channels = self.feature_model.get_total_output_channels()
         for i in np_progress(range(num_patches), desc="Predicting slice patches"):
             input_data = feature_patches[i].reshape(-1, total_channels)
             predictions = rf_model.predict(input_data).astype(np.uint8)
@@ -633,7 +633,7 @@ def predict_slice(self, rf_model, slice_index, img_height, img_width):
         area_threshold = float(self.area_threshold_textbox.text()) / 100
         if self.sam_post_checkbox.checkState() == Qt.Checked:
             segmentation_image = postprocess_segmentations_with_sam(
-                self.sam_model, segmentation_image, area_threshold
+                self.feature_model, segmentation_image, area_threshold
             )
         else:
             segmentation_image = postprocess_segmentation(
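With total_channels now queried from the adapter, the random forest is agnostic to the backbone: it only ever sees a flat (num_samples, total_channels) matrix. A self-contained sketch of that data flow with made-up sizes (not the widget's actual code):

import numpy as np
from sklearn.ensemble import RandomForestClassifier

# made-up sizes for illustration; MobileSAM reports 256 + 64 = 320 channels
total_channels = 320
num_labels = 1000
train_data = np.random.rand(num_labels, total_channels)
labels = np.random.randint(0, 3, size=num_labels)

rf_model = RandomForestClassifier(n_estimators=100)
rf_model.fit(train_data, labels)

# prediction flattens each feature patch to one row per pixel
patch = np.random.rand(512, 512, total_channels)
predictions = rf_model.predict(patch.reshape(-1, total_channels)).astype(np.uint8)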
2 changes: 2 additions & 0 deletions src/featureforest/models/MobileSAM/__init__.py
@@ -0,0 +1,2 @@
+from .model import get_model
+from .adapter import MobileSAMAdapter
52 changes: 52 additions & 0 deletions src/featureforest/models/MobileSAM/adapter.py
@@ -0,0 +1,52 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torchvision.transforms import v2 as tv_transforms2
+
+from featureforest.models.base import BaseModelAdapter
+from featureforest.utils.data import (
+    get_nonoverlapped_patches,
+)
+
+
+class MobileSAMAdapter(BaseModelAdapter):
+    """MobileSAM model adapter
+    """
+    def __init__(
+        self,
+        model: nn.Module,
+        input_transforms: tv_transforms2.Compose,
+        patch_size: int,
+        overlap: int,
+    ) -> None:
+        super().__init__(model, input_transforms, patch_size, overlap)
+        # we need sam image encoder part
+        self.encoder = self.model.image_encoder
+        self.encoder_num_channels = 256
+        self.embed_layer_num_channels = 64
+
+    def get_features_patches(
+        self, in_patches: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        # get the mobile-sam encoder and embedding layer outputs
+        with torch.no_grad():
+            output, embed_output, _ = self.encoder(
+                self.input_transforms(in_patches)
+            )
+
+        # get non-overlapped feature patches
+        out_feature_patches = get_nonoverlapped_patches(
+            self.embedding_transform(output.cpu()),
+            self.patch_size, self.overlap
+        )
+        embed_feature_patches = get_nonoverlapped_patches(
+            self.embedding_transform(embed_output.cpu()),
+            self.patch_size, self.overlap
+        )
+
+        return out_feature_patches, embed_feature_patches
+
+    def get_total_output_channels(self) -> int:
+        return self.encoder_num_channels + self.embed_layer_num_channels
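Unlike the base adapter, this one returns two tensors (encoder output and embedding-layer output); a downstream consumer would concatenate them along the channel axis to get the full 320-channel feature vector. A sketch with dummy tensors, assuming the channels-last patch layout noted in the RF widget ("N x target_size x target_size x C"):

import torch

# dummy shapes for illustration only
out_patches = torch.zeros(4, 128, 128, 256)   # encoder features
embed_patches = torch.zeros(4, 128, 128, 64)  # embedding-layer features
features = torch.cat([out_patches, embed_patches], dim=-1)
assert features.shape[-1] == 256 + 64  # matches get_total_output_channels()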
45 changes: 44 additions & 1 deletion src/featureforest/models/MobileSAM/model.py
@@ -1,8 +1,50 @@
+import torch
+from torchvision.transforms import v2 as tv_transforms2
+
 from .tiny_vit_sam import TinyViT
 from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer

+from featureforest.utils.downloader import download_model
+from .adapter import MobileSAMAdapter
+

-def setup_model():
+def get_model(patch_size: int, overlap: int):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"running on {device}")
+    # get the model
+    model = setup_model().to(device)
+    # download model's weights
+    model_url = "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
+    model_file = download_model(
+        model_url=model_url,
+        model_name="mobile_sam.pt"
+    )
+    if model_file is None:
+        raise ValueError(f"Could not download the model from {model_url}.")
+
+    # load weights
+    weights = torch.load(model_file, map_location=device)
+    model.load_state_dict(weights, strict=True)
+    model.eval()
+
+    # input transform for sam
+    sam_input_dim = 1024
+    input_transforms = tv_transforms2.Compose([
+        tv_transforms2.Resize(
+            (sam_input_dim, sam_input_dim),
+            interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+            antialias=True
+        ),
+    ])
+    # create the model adapter
+    sam_model_adapter = MobileSAMAdapter(
+        model, input_transforms, patch_size, overlap
+    )
+
+    return sam_model_adapter, device
+
+
+def setup_model() -> Sam:
     prompt_embed_dim = 256
     image_size = 1024
     vit_patch_size = 16
@@ -43,4 +85,5 @@ def setup_model():
         pixel_mean=[123.675, 116.28, 103.53],
         pixel_std=[58.395, 57.12, 57.375],
     )
+
     return mobile_sam
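Callers now receive the adapter and the device in a single call; a usage sketch, using the same patch_size/overlap defaults as the RF widget above (note that the first call downloads the MobileSAM weights):

from featureforest.models import MobileSAM

adapter, device = MobileSAM.get_model(patch_size=512, overlap=384)
print(device, adapter.get_total_output_channels())  # 256 + 64 = 320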
73 changes: 73 additions & 0 deletions src/featureforest/models/base.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torchvision.transforms import v2 as tv_transforms2
+
+from ..utils.data import (
+    get_nonoverlapped_patches,
+)
+
+
+class BaseModelAdapter:
+    """Base class for adapting any models in featureforest.
+    """
+    def __init__(
+        self,
+        model: nn.Module,
+        input_transforms: tv_transforms2.Compose,
+        patch_size: int,
+        overlap: int,
+    ) -> None:
+        """Initialization function
+        Args:
+            model (nn.Module): the pytorch model (e.g. a ViT encoder)
+            input_transforms (tv_transforms2.Compose): input transformations for the specific model
+            patch_size (int): input patch size
+            overlap (int): input patch overlap
+        """
+        self.model = model
+        self.input_transforms = input_transforms
+        self.patch_size = patch_size
+        self.overlap = overlap
+        # to transform feature patches to the original patch size
+        self.embedding_transform = tv_transforms2.Compose([
+            tv_transforms2.Resize(
+                (self.patch_size, self.patch_size),
+                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                antialias=True
+            ),
+        ])
+
+    def get_features_patches(
+        self, in_patches: Tensor
+    ) -> Tensor:
+        """Returns a tensor of model's extracted features.
+        This function is more like an abstract function, and should be overridden.
+        Args:
+            in_patches (Tensor): input patches
+        Returns:
+            Tensor: model's extracted features
+        """
+        # get the model output
+        with torch.no_grad():
+            out_features = self.model(self.input_transforms(in_patches))
+        # assert self.patch_size == out_features.shape[-1]
+
+        # get non-overlapped feature patches
+        feature_patches = get_nonoverlapped_patches(
+            self.embedding_transform(out_features.cpu()),
+            self.patch_size, self.overlap
+        )
+
+        return feature_patches
+
+    def get_total_output_channels(self) -> int:
+        """Returns total number of model output channels (a.k.a. number of feature maps).
+        Returns:
+            int: total number of output channels
+        """
+        return 256
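This base class is the extension point the commit message refers to: supporting a new feature-extraction backbone means subclassing it and overriding get_total_output_channels (and, when the model returns more than one tensor, get_features_patches), as MobileSAMAdapter does above. A skeletal sketch for a hypothetical single-output encoder (the class name and channel count are made up):

import torch.nn as nn
from torchvision.transforms import v2 as tv_transforms2

from featureforest.models.base import BaseModelAdapter


class MyEncoderAdapter(BaseModelAdapter):
    """Hypothetical adapter for a backbone returning a single feature tensor."""
    def __init__(
        self,
        model: nn.Module,
        input_transforms: tv_transforms2.Compose,
        patch_size: int,
        overlap: int,
    ) -> None:
        super().__init__(model, input_transforms, patch_size, overlap)
        self.encoder_num_channels = 128  # made-up channel count

    # the inherited get_features_patches already handles a model that
    # returns one tensor; only the channel count needs overriding here
    def get_total_output_channels(self) -> int:
        return self.encoder_num_channels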
