diff --git a/src/featureforest/SAM/__init__.py b/src/featureforest/SAM/__init__.py
deleted file mode 100644
index 03b068e..0000000
--- a/src/featureforest/SAM/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from torchvision import transforms
-
-from featureforest.SAM.setup_model import setup_mobile_sam_model
-from segment_anything import SamPredictor
-
-
-INPUT_SIZE = 1024
-FEATURE_H = FEATURE_W = 64
-ENCODER_OUT_CHANNELS = 256
-PATCH_SIZE = 256
-EMBED_PATCH_CHANNELS = 64
-
-
-sam_transform = transforms.Compose([
-    transforms.Resize(
-        (INPUT_SIZE, INPUT_SIZE),
-        interpolation=transforms.InterpolationMode.BICUBIC,
-        antialias=True
-    ),
-])
diff --git a/src/featureforest/SAM/models/MobileSAM/__init__.py b/src/featureforest/SAM/models/MobileSAM/__init__.py
deleted file mode 100644
index fcd9f29..0000000
--- a/src/featureforest/SAM/models/MobileSAM/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .setup_mobile_sam import setup_model
diff --git a/src/featureforest/SAM/models/weights/mobile_sam.pt.REMOVED.git-id b/src/featureforest/SAM/models/weights/mobile_sam.pt.REMOVED.git-id
deleted file mode 100644
index e7e7dc5..0000000
--- a/src/featureforest/SAM/models/weights/mobile_sam.pt.REMOVED.git-id
+++ /dev/null
@@ -1 +0,0 @@
-7ef2d090979fd9853adf17b7a99c8b94a1c5a6a7
\ No newline at end of file
diff --git a/src/featureforest/SAM/setup_model.py b/src/featureforest/SAM/setup_model.py
deleted file mode 100644
index 1e33d9a..0000000
--- a/src/featureforest/SAM/setup_model.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from pathlib import Path
-
-import torch
-
-from ..utils.downloader import download_model
-from .models import MobileSAM
-
-
-def setup_mobile_sam_model():
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(f"running on {device}")
-    # sam model (light hq sam)
-    model = MobileSAM.setup_model().to(device)
-    # download model's weights
-    model_url = "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
-    model_file = download_model(
-        model_url=model_url,
-        model_name="mobile_sam.pt"
-    )
-    if model_file is None:
-        raise ValueError(f"Could not download the model from {model_url}.")
-
-    # load weights
-    weights = torch.load(model_file, map_location=device)
-    model.load_state_dict(weights, strict=True)
-    model.eval()
-
-    return model, device
diff --git a/src/featureforest/__init__.py b/src/featureforest/__init__.py
index 03ced45..27046b3 100644
--- a/src/featureforest/__init__.py
+++ b/src/featureforest/__init__.py
@@ -1,13 +1,13 @@
 __version__ = "0.0.2"
 
 from ._embedding_extractor_widget import EmbeddingExtractorWidget
-from ._sam_predictor_widget import SAMPredictorWidget
+# from ._sam_predictor_widget import SAMPredictorWidget
 from ._sam_rf_segmentation_widget import SAMRFSegmentationWidget
-from ._sam_prompt_segmentation_widget import SAMPromptSegmentationWidget
+# from ._sam_prompt_segmentation_widget import SAMPromptSegmentationWidget
 
 __all__ = (
     "EmbeddingExtractorWidget",
-    "SAMPredictorWidget",
+    # "SAMPredictorWidget",
     "SAMRFSegmentationWidget",
-    "SAMPromptSegmentationWidget"
+    # "SAMPromptSegmentationWidget"
 )
diff --git a/src/featureforest/_embedding_extractor_widget.py b/src/featureforest/_embedding_extractor_widget.py
index 2d4c1b5..4dccd4a 100644
--- a/src/featureforest/_embedding_extractor_widget.py
+++ b/src/featureforest/_embedding_extractor_widget.py
@@ -18,7 +18,7 @@
     ScrollWidgetWrapper,
     get_layer,
 )
-from . import SAM
+from .models import MobileSAM
 from .utils import (
     config
 )
@@ -27,7 +27,7 @@
     get_patch_size
 )
 from .utils.extract import (
-    get_sam_embeddings_for_slice
+    get_slice_features
 )
 
 
@@ -166,8 +166,8 @@ def extract_embeddings(self):
     def get_stack_sam_embeddings(
         self, image_layer, storage_path, patch_size, overlap
     ):
-        # initial sam model
-        sam_model, device = SAM.setup_mobile_sam_model()
+        # initial mobile-sam model
+        sam_model_adapter, device = MobileSAM.get_model(patch_size, overlap)
        # initial storage hdf5 file
         self.storage = h5py.File(storage_path, "w")
         # get sam embeddings slice by slice and save them into storage file
@@ -179,13 +179,13 @@ def get_stack_sam_embeddings(
         self.storage.attrs["overlap"] = overlap
 
         for slice_index in np_progress(
-            range(num_slices), desc="get embeddings for slices"
+            range(num_slices), desc="extract features for slices"
         ):
             image = image_layer.data[slice_index] if num_slices > 1 else image_layer.data
             slice_grp = self.storage.create_group(str(slice_index))
-            get_sam_embeddings_for_slice(
+            get_slice_features(
                 image, patch_size, overlap,
-                sam_model.image_encoder, device, slice_grp
+                sam_model_adapter, device, slice_grp
             )
             yield (slice_index, num_slices)
 
diff --git a/src/featureforest/_sam_rf_segmentation_widget.py b/src/featureforest/_sam_rf_segmentation_widget.py
index b59fc27..c396e93 100644
--- a/src/featureforest/_sam_rf_segmentation_widget.py
+++ b/src/featureforest/_sam_rf_segmentation_widget.py
@@ -20,7 +20,7 @@
 import numpy as np
 from sklearn.ensemble import RandomForestClassifier
 
-from . import SAM
+from .models import MobileSAM
 from .widgets import (
     ScrollWidgetWrapper,
     get_layer,
@@ -48,16 +48,13 @@ def __init__(self, napari_viewer: napari.Viewer):
         self.storage = None
         self.rf_model = None
         self.device = None
-        self.sam_model = None
+        self.feature_model = None
         self.patch_size = 512
         self.overlap = 384
         self.stride = self.patch_size - self.overlap
 
         self.prepare_widget()
 
-        # init sam model & predictor
-        self.sam_model, self.device = self.get_model_on_device()
-
     def closeEvent(self, event):
         print("closing")
         self.viewer.layers.events.inserted.disconnect(self.check_input_layers)
@@ -365,6 +362,9 @@ def select_storage(self):
             "overlap", self.overlap)
         self.stride, _ = get_stride_margin(self.patch_size, self.overlap)
 
+        # init feature model
+        self.feature_model, self.device = self.get_model_on_device()
+
     def add_labels_layer(self):
         self.image_layer = get_layer(
             self.viewer,
@@ -410,7 +410,7 @@ def analyze_labels(self):
         self.each_class_label.setText("Labels per class:\n" + each_class)
 
     def get_model_on_device(self):
-        return SAM.setup_mobile_sam_model()
+        return MobileSAM.get_model(self.patch_size, self.overlap)
 
     def get_train_data(self):
         # get ground truth class labels
@@ -423,7 +423,7 @@ def get_train_data(self):
         num_slices, img_height, img_width = get_stack_dims(self.image_layer.data)
 
         num_labels = sum([len(v) for v in labels_dict.values()])
-        total_channels = SAM.ENCODER_OUT_CHANNELS + SAM.EMBED_PATCH_CHANNELS
+        total_channels = self.feature_model.get_total_output_channels()
         train_data = np.zeros((num_labels, total_channels))
         labels = np.zeros(num_labels, dtype="int32") - 1
         count = 0
@@ -602,7 +602,7 @@ def predict_slice(self, rf_model, slice_index, img_height, img_width):
         # shape: N x target_size x target_size x C
         feature_patches = self.storage[str(slice_index)]["sam"][:]
         num_patches = feature_patches.shape[0]
-        total_channels = SAM.ENCODER_OUT_CHANNELS + SAM.EMBED_PATCH_CHANNELS
+        total_channels = self.feature_model.get_total_output_channels()
         for i in np_progress(range(num_patches), desc="Predicting slice patches"):
             input_data = feature_patches[i].reshape(-1, total_channels)
             predictions = rf_model.predict(input_data).astype(np.uint8)
@@ -633,7 +633,7 @@ def predict_slice(self, rf_model, slice_index, img_height, img_width):
         area_threshold = float(self.area_threshold_textbox.text()) / 100
         if self.sam_post_checkbox.checkState() == Qt.Checked:
             segmentation_image = postprocess_segmentations_with_sam(
-                self.sam_model, segmentation_image, area_threshold
+                self.feature_model, segmentation_image, area_threshold
            )
         else:
             segmentation_image = postprocess_segmentation(
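Not part of the patch, just to make the channel bookkeeping explicit: the widget no longer sums the old `SAM.ENCODER_OUT_CHANNELS + SAM.EMBED_PATCH_CHANNELS` constants but asks the adapter via `get_total_output_channels()`, and that value has to match the last dimension of the stored "sam" dataset. A minimal, self-contained sketch of how `predict_slice` consumes one stored patch; all values here are stand-ins:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# stand-ins for the widget state; for the MobileSAM adapter introduced below,
# get_total_output_channels() returns 256 + 64 = 320, the same value the old
# SAM.ENCODER_OUT_CHANNELS + SAM.EMBED_PATCH_CHANNELS expression produced
total_channels = 256 + 64
stride = 512 - 384  # patch_size - overlap, as stored in the hdf5 attributes

# a dummy stored patch (stride, stride, total_channels) and a dummy forest
patch = np.random.rand(stride, stride, total_channels)
rf_model = RandomForestClassifier(n_estimators=5).fit(
    np.random.rand(20, total_channels), np.random.randint(1, 3, 20)
)

# the random forest sees one row of total_channels features per pixel
input_data = patch.reshape(-1, total_channels)
predictions = rf_model.predict(input_data).astype(np.uint8)
print(predictions.shape)  # (stride * stride,)
```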
diff --git a/src/featureforest/models/MobileSAM/__init__.py b/src/featureforest/models/MobileSAM/__init__.py
new file mode 100644
index 0000000..efb22c8
--- /dev/null
+++ b/src/featureforest/models/MobileSAM/__init__.py
@@ -0,0 +1,2 @@
+from .model import get_model
+from .adapter import MobileSAMAdapter
diff --git a/src/featureforest/models/MobileSAM/adapter.py b/src/featureforest/models/MobileSAM/adapter.py
new file mode 100644
index 0000000..9a8fc4e
--- /dev/null
+++ b/src/featureforest/models/MobileSAM/adapter.py
@@ -0,0 +1,52 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torchvision.transforms import v2 as tv_transforms2
+
+from featureforest.models.base import BaseModelAdapter
+from featureforest.utils.data import (
+    get_nonoverlapped_patches,
+)
+
+
+class MobileSAMAdapter(BaseModelAdapter):
+    """MobileSAM model adapter
+    """
+    def __init__(
+        self,
+        model: nn.Module,
+        input_transforms: tv_transforms2.Compose,
+        patch_size: int,
+        overlap: int,
+    ) -> None:
+        super().__init__(model, input_transforms, patch_size, overlap)
+        # we need sam image encoder part
+        self.encoder = self.model.image_encoder
+        self.encoder_num_channels = 256
+        self.embed_layer_num_channels = 64
+
+    def get_features_patches(
+        self, in_patches: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        # get the mobile-sam encoder and embedding layer outputs
+        with torch.no_grad():
+            output, embed_output, _ = self.encoder(
+                self.input_transforms(in_patches)
+            )
+
+        # get non-overlapped feature patches
+        out_feature_patches = get_nonoverlapped_patches(
+            self.embedding_transform(output.cpu()),
+            self.patch_size, self.overlap
+        )
+        embed_feature_patches = get_nonoverlapped_patches(
+            self.embedding_transform(embed_output.cpu()),
+            self.patch_size, self.overlap
+        )
+
+        return out_feature_patches, embed_feature_patches
+
+    def get_total_output_channels(self) -> int:
+        return self.encoder_num_channels + self.embed_layer_num_channels
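A rough usage sketch of the new adapter (not part of this patch). It relies on the `get_model` factory defined in `model.py` right below, which downloads the MobileSAM weights on first use; the patch size, overlap, and batch size here are made-up values:

```python
import torch

from featureforest.models.MobileSAM import get_model

# hypothetical settings; any valid patch size / overlap combination should work
patch_size, overlap = 512, 384
adapter, device = get_model(patch_size, overlap)

# a dummy batch of float RGB patches in [0, 1], shaped B x 3 x patch_size x patch_size
patches = torch.rand(2, 3, patch_size, patch_size).to(device)

# MobileSAMAdapter returns two tensors: encoder features (256 channels) and
# patch-embedding features (64 channels), both resized back towards the input
# patch size and cropped to the non-overlapped region
encoder_feats, embed_feats = adapter.get_features_patches(patches)

print(adapter.get_total_output_channels())  # 256 + 64 = 320
```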
diff --git a/src/featureforest/SAM/models/MobileSAM/setup_mobile_sam.py b/src/featureforest/models/MobileSAM/model.py
similarity index 52%
rename from src/featureforest/SAM/models/MobileSAM/setup_mobile_sam.py
rename to src/featureforest/models/MobileSAM/model.py
index 50d9186..ad8074d 100644
--- a/src/featureforest/SAM/models/MobileSAM/setup_mobile_sam.py
+++ b/src/featureforest/models/MobileSAM/model.py
@@ -1,8 +1,50 @@
+import torch
+from torchvision.transforms import v2 as tv_transforms2
+
 from .tiny_vit_sam import TinyViT
 from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+from featureforest.utils.downloader import download_model
+from .adapter import MobileSAMAdapter
+
+
+def get_model(patch_size: int, overlap: int):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"running on {device}")
+    # get the model
+    model = setup_model().to(device)
+    # download model's weights
+    model_url = "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
+    model_file = download_model(
+        model_url=model_url,
+        model_name="mobile_sam.pt"
+    )
+    if model_file is None:
+        raise ValueError(f"Could not download the model from {model_url}.")
+
+    # load weights
+    weights = torch.load(model_file, map_location=device)
+    model.load_state_dict(weights, strict=True)
+    model.eval()
 
 
-def setup_model():
+    # input transform for sam
+    sam_input_dim = 1024
+    input_transforms = tv_transforms2.Compose([
+        tv_transforms2.Resize(
+            (sam_input_dim, sam_input_dim),
+            interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+            antialias=True
+        ),
+    ])
+    # create the model adapter
+    sam_model_adapter = MobileSAMAdapter(
+        model, input_transforms, patch_size, overlap
+    )
+
+    return sam_model_adapter, device
+
+
+def setup_model() -> Sam:
     prompt_embed_dim = 256
     image_size = 1024
     vit_patch_size = 16
@@ -43,4 +85,5 @@ def setup_model():
         pixel_mean=[123.675, 116.28, 103.53],
         pixel_std=[58.395, 57.12, 57.375],
     )
+
     return mobile_sam
diff --git a/src/featureforest/SAM/models/MobileSAM/tiny_vit_sam.py b/src/featureforest/models/MobileSAM/tiny_vit_sam.py
similarity index 100%
rename from src/featureforest/SAM/models/MobileSAM/tiny_vit_sam.py
rename to src/featureforest/models/MobileSAM/tiny_vit_sam.py
diff --git a/src/featureforest/models/base.py b/src/featureforest/models/base.py
new file mode 100644
index 0000000..2749ef4
--- /dev/null
+++ b/src/featureforest/models/base.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torchvision.transforms import v2 as tv_transforms2
+
+from ..utils.data import (
+    get_nonoverlapped_patches,
+)
+
+
+class BaseModelAdapter:
+    """Base class for adapting any models in featureforest.
+    """
+    def __init__(
+        self,
+        model: nn.Module,
+        input_transforms: tv_transforms2.Compose,
+        patch_size: int,
+        overlap: int,
+    ) -> None:
+        """Initialization function
+
+        Args:
+            model (nn.Module): the pytorch model (e.g. a ViT encoder)
+            input_transforms (tv_transforms2.Compose): input transformations for the specific model
+            patch_size (int): input patch size
+            overlap (int): input patch overlap
+        """
+        self.model = model
+        self.input_transforms = input_transforms
+        self.patch_size = patch_size
+        self.overlap = overlap
+        # to transform feature patches to the original patch size
+        self.embedding_transform = tv_transforms2.Compose([
+            tv_transforms2.Resize(
+                (self.patch_size, self.patch_size),
+                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                antialias=True
+            ),
+        ])
+
+    def get_features_patches(
+        self, in_patches: Tensor
+    ) -> Tensor:
+        """Returns a tensor of model's extracted features.
+        This function is more like an abstract function, and should be overridden.
+
+        Args:
+            in_patches (Tensor): input patches
+
+        Returns:
+            Tensor: model's extracted features
+        """
+        # get the model output
+        with torch.no_grad():
+            out_features = self.model(self.input_transforms(in_patches))
+        # assert self.patch_size == out_features.shape[-1]
+
+        # get non-overlapped feature patches
+        feature_patches = get_nonoverlapped_patches(
+            self.embedding_transform(out_features.cpu()),
+            self.patch_size, self.overlap
+        )
+
+        return feature_patches
+
+    def get_total_output_channels(self) -> int:
+        """Returns total number of model output channels (a.k.a. number of feature maps).
+
+        Returns:
+            int: total number of output channels
+        """
+        return 256
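To illustrate the intent of the base class, here is a hypothetical adapter for a single-output encoder (not part of the patch): the toy convolutional "encoder", its channel count, and the no-op-like transform are all made up; only the `BaseModelAdapter` API is taken from the file above.

```python
import torch
import torch.nn as nn
from torchvision.transforms import v2 as tv_transforms2

from featureforest.models.base import BaseModelAdapter


class DummyEncoderAdapter(BaseModelAdapter):
    """Hypothetical adapter for a single-output feature encoder."""

    def __init__(self, model, input_transforms, patch_size, overlap):
        super().__init__(model, input_transforms, patch_size, overlap)
        self.num_channels = 32  # assumed channel count of the wrapped encoder

    # the inherited get_features_patches already covers a single-output model,
    # so only the channel count needs to be overridden
    def get_total_output_channels(self) -> int:
        return self.num_channels


# usage sketch with a toy convolutional "encoder"
encoder = nn.Conv2d(3, 32, kernel_size=3, padding=1)
transforms = tv_transforms2.Compose([
    tv_transforms2.Resize((512, 512), antialias=True),
])
adapter = DummyEncoderAdapter(encoder, transforms, patch_size=512, overlap=384)

patches = torch.rand(2, 3, 512, 512)
feats = adapter.get_features_patches(patches)  # a single feature tensor
```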
diff --git a/src/featureforest/napari.yaml b/src/featureforest/napari.yaml
index 2cc2de1..889218e 100644
--- a/src/featureforest/napari.yaml
+++ b/src/featureforest/napari.yaml
@@ -11,18 +11,18 @@ contributions:
     - id: featureforest.make_rf_segmentation_widget
       python_name: featureforest:SAMRFSegmentationWidget
       title: Make SAM-RF segmentation widget
-    - id: featureforest.make_prompt_segmentation_widget
-      python_name: featureforest:SAMPromptSegmentationWidget
-      title: Make SAM Prompt segmentation widget
-    - id: featureforest.make_predictor_widget
-      python_name: featureforest:SAMPredictorWidget
-      title: Make SAM predictor widget
+    # - id: featureforest.make_prompt_segmentation_widget
+    #   python_name: featureforest:SAMPromptSegmentationWidget
+    #   title: Make SAM Prompt segmentation widget
+    # - id: featureforest.make_predictor_widget
+    #   python_name: featureforest:SAMPredictorWidget
+    #   title: Make SAM predictor widget
   widgets:
     - command: featureforest.make_extractor_widget
       display_name: SAM Embedding Extractor
     - command: featureforest.make_rf_segmentation_widget
      display_name: SAM-RF Segmentation Widget
-    - command: featureforest.make_prompt_segmentation_widget
-      display_name: SAM Prompt Segmentation Widget
-    - command: featureforest.make_predictor_widget
-      display_name: SAM Predictor Widget
+    # - command: featureforest.make_prompt_segmentation_widget
+    #   display_name: SAM Prompt Segmentation Widget
+    # - command: featureforest.make_predictor_widget
+    #   display_name: SAM Predictor Widget
diff --git a/src/featureforest/utils/data.py b/src/featureforest/utils/data.py
index e02a441..bd2ff8b 100644
--- a/src/featureforest/utils/data.py
+++ b/src/featureforest/utils/data.py
@@ -134,7 +134,7 @@ def get_num_patches(
     return int(num_patches_h), int(num_patches_w)
 
 
-def get_nonoverlap_patches(patches: Tensor, patch_size: int, overlap: int) -> Tensor:
+def get_nonoverlapped_patches(patches: Tensor, patch_size: int, overlap: int) -> Tensor:
     """Extracts and returns non-overlap patches from patches with overlap.
 
     Args:
diff --git a/src/featureforest/utils/extract.py b/src/featureforest/utils/extract.py
index b895399..9232639 100644
--- a/src/featureforest/utils/extract.py
+++ b/src/featureforest/utils/extract.py
@@ -3,24 +3,32 @@
 import numpy as np
 import h5py
 import torch
-from torchvision import transforms
 
 from .data import (
     patchify,
     get_stride_margin,
-    get_nonoverlap_patches,
     is_image_rgb,
 )
-from featureforest.SAM import (
-    ENCODER_OUT_CHANNELS, EMBED_PATCH_CHANNELS,
-    sam_transform
-)
+from featureforest.models.base import BaseModelAdapter
+
+def get_slice_features(
+    image: np.ndarray,
+    patch_size: int,
+    overlap: int,
+    model_adapter: BaseModelAdapter,
+    device: torch.device,
+    storage_group: h5py.Group
+) -> None:
+    """Extract the model features for one slice and save them into storage file.
 
-def get_sam_embeddings_for_slice(
-    image, patch_size, overlap,
-    sam_encoder, device, storage_group: h5py.Group
-):
-    """get sam features for one slice."""
+    Args:
+        image (np.ndarray): the input slice image
+        patch_size (int): the patch size
+        overlap (int): the patch overlap in pixels
+        model_adapter (BaseModelAdapter): the model adapter used to extract features
+        device (torch.device): the device the model runs on
+        storage_group (h5py.Group): the hdf5 group to save the slice features into
+    """
     img_height, img_width = image.shape[:2]
     # image to torch tensor
     img_data = torch.from_numpy(image).to(torch.float32) / 255.0
@@ -31,23 +39,13 @@
     else:
         img_data = img_data.unsqueeze(0).unsqueeze(0).expand(-1, 3, -1, -1)
 
-    # to resize encoder output back to the input patch size
-    embedding_transform = transforms.Compose([
-        transforms.Resize(
-            (patch_size, patch_size),
-            interpolation=transforms.InterpolationMode.BICUBIC,
-            antialias=True
-        ),
-        # transforms.GaussianBlur(kernel_size=3, sigma=1.0)
-    ])
-
     # get input patches
     data_patches = patchify(img_data, patch_size, overlap)
     num_patches = len(data_patches)
     batch_size = 10
     num_batches = int(np.ceil(num_patches / batch_size))
     # prepare storage for the slice embeddings
-    total_channels = ENCODER_OUT_CHANNELS + EMBED_PATCH_CHANNELS
+    total_channels = model_adapter.get_total_output_channels()
     stride, _ = get_stride_margin(patch_size, overlap)
     dataset = storage_group.create_dataset(
         "sam", shape=(
@@ -56,30 +54,25 @@
     )
 
     # get sam encoder output for image patches
-    with torch.no_grad():
-        print("\ngetting SAM encoder & patch_embed output:")
-        for b_idx in np_progress(
-            range(num_batches), desc="getting SAM encoder & patch_embed outputs"
-        ):
-            print(f"batch #{b_idx + 1} of {num_batches}")
-            start = b_idx * batch_size
-            end = start + batch_size
-            output, embed_output, _ = sam_encoder(
-                sam_transform(data_patches[start: end]).to(device)
-            )
-            # output: Bx256x64x64, embed_output: Bx64x256x256
-            # after transform: Bx256x512x512, embed_output: Bx64x512x512
-            # target patch: B, target_size, target_size, C
-            num_out = output.shape[0]
-            dataset[
-                start: start + num_out, :, :, :ENCODER_OUT_CHANNELS
-            ] = get_nonoverlap_patches(
-                embedding_transform(output.cpu()),
-                patch_size, overlap
-            )
-            dataset[
-                start: start + num_out, :, :, ENCODER_OUT_CHANNELS:
-            ] = get_nonoverlap_patches(
-                embedding_transform(embed_output.cpu()),
-                patch_size, overlap
-            )
+    print("\nextracting slice features:")
+    for b_idx in np_progress(
+        range(num_batches), desc="extracting slice feature:"
+    ):
+        print(f"batch #{b_idx + 1} of {num_batches}")
+        start = b_idx * batch_size
+        end = start + batch_size
+        slice_features = model_adapter.get_features_patches(
+            data_patches[start: end].to(device)
+        )
+        if not isinstance(slice_features, tuple):
+            # model has only one output
+            num_out = slice_features.shape[0]  # to take care of the last batch size
+            dataset[start: start + num_out] = slice_features
+        else:
+            # model has more than one output: put them into storage one by one
+            ch_start = 0
+            for feat in slice_features:
+                num_out = feat.shape[0]
+                ch_end = ch_start + feat.shape[-1]  # number of features
+                dataset[start: start + num_out, :, :, ch_start: ch_end] = feat
+                ch_start = ch_end
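Not part of the patch: a small end-to-end sketch tying `get_model`, the hdf5 storage layout, and `get_slice_features` together. The file name, image size, and patch settings are made up; the calls themselves are the ones added above.

```python
import h5py
import numpy as np

from featureforest.models.MobileSAM import get_model
from featureforest.utils.extract import get_slice_features

# hypothetical settings; the widgets read these from the UI / storage attributes
patch_size, overlap = 512, 384
image = np.random.randint(0, 255, (1024, 1024), dtype=np.uint8)  # one dummy slice

adapter, device = get_model(patch_size, overlap)

with h5py.File("features_example.h5", "w") as storage:
    storage.attrs["patch_size"] = patch_size
    storage.attrs["overlap"] = overlap
    slice_grp = storage.create_group("0")
    # creates the "sam" dataset inside the group and fills it batch by batch;
    # its last dimension equals adapter.get_total_output_channels() (320 here)
    get_slice_features(image, patch_size, overlap, adapter, device, slice_grp)
    print(storage["0"]["sam"].shape)
```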
diff --git a/src/featureforest/utils/postprocess_with_sam.py b/src/featureforest/utils/postprocess_with_sam.py
index ee82156..63115f6 100644
--- a/src/featureforest/utils/postprocess_with_sam.py
+++ b/src/featureforest/utils/postprocess_with_sam.py
@@ -2,7 +2,7 @@
 import cv2
 import torch
 
-from featureforest.SAM import SamPredictor
+from segment_anything import SamPredictor
 
 
 def get_watershed_bboxes(image):