Skip to content

Commit 0fc57f9

Browse files
committed
added model adapter base
Former-commit-id: f01df76 Former-commit-id: 19244da
1 parent 669d171 commit 0fc57f9

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

src/featureforest/models/base.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch import Tensor
4+
from torchvision.transforms import v2 as tv_transforms2
5+
6+
from ..utils.data import (
7+
get_nonoverlapped_patches,
8+
)
9+
10+
11+
class BaseModelAdapter:
    """Base class for adapting any models in featureforest."""

    def __init__(
        self,
        model: nn.Module,
        input_transforms: tv_transforms2.Compose,
        patch_size: int,
        overlap: int,
    ) -> None:
        """Store the wrapped model and its patch/transform configuration.

        Args:
            model (nn.Module): the pytorch model (e.g. a ViT encoder)
            input_transforms (tv_transforms2.Compose): input transformations
                for the specific model
            patch_size (int): input patch size
            overlap (int): input patch overlap
        """
        self.model = model
        self.input_transforms = input_transforms
        self.patch_size = patch_size
        self.overlap = overlap
        # Resizes extracted feature maps back to the original patch size
        # so they can be aligned with the input patches.
        resize_to_patch = tv_transforms2.Resize(
            (self.patch_size, self.patch_size),
            interpolation=tv_transforms2.InterpolationMode.BICUBIC,
            antialias=True,
        )
        self.embedding_transform = tv_transforms2.Compose([resize_to_patch])

    def get_features_patches(
        self, in_patches: Tensor
    ) -> Tensor:
        """Returns a tensor of model's extracted features.
        This function is more like an abstract function, and should be overridden.

        Args:
            in_patches (Tensor): input patches

        Returns:
            Tensor: model's extracted features
        """
        # Forward pass only — no gradients are needed for feature extraction.
        with torch.no_grad():
            features = self.model(self.input_transforms(in_patches))
        # assert self.patch_size == out_features.shape[-1]

        # Resize features to patch size (on CPU), then strip the overlap
        # so only non-overlapped feature patches remain.
        resized = self.embedding_transform(features.cpu())
        return get_nonoverlapped_patches(resized, self.patch_size, self.overlap)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch import Tensor
4+
from torchvision.transforms import v2 as tv_transforms2
5+
6+
from .base import BaseModelAdapter
7+
from ..utils.data import (
8+
get_nonoverlapped_patches,
9+
)
10+
11+
12+
class MobileSAM(BaseModelAdapter):
    """MobileSAM model adapter."""

    def __init__(
        self,
        model: nn.Module,
        input_transforms: tv_transforms2.Compose,
        patch_size: int,
        overlap: int,
    ) -> None:
        """Initialization function

        Args:
            model (nn.Module): the MobileSAM image encoder
            input_transforms (tv_transforms2.Compose): input transformations
                for the specific model
            patch_size (int): input patch size
            overlap (int): input patch overlap
        """
        super().__init__(model, input_transforms, patch_size, overlap)

    def get_features_patches(
        self, in_patches: Tensor
    ) -> Tensor:
        """Returns a tensor of the model's extracted feature patches.

        Args:
            in_patches (Tensor): input patches

        Returns:
            Tensor: model's extracted features
        """
        # The previous body duplicated the base-class implementation
        # line-for-line (no_grad forward pass, embedding resize,
        # non-overlapped patch extraction). Delegate instead of
        # duplicating, so fixes in the base pipeline apply here too.
        return super().get_features_patches(in_patches)

0 commit comments

Comments
 (0)