Skip to content

Commit

Permalink
Fix: the new decod
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedmokeddem committed Feb 5, 2025
1 parent cd8d0ab commit 3f2abe9
Showing 1 changed file with 40 additions and 86 deletions.
126 changes: 40 additions & 86 deletions instageo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,124 +118,78 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class PrithviSeg(nn.Module):
    """Prithvi segmentation model: ViT encoder backbone plus an upsampling decoder.

    The decoder is a plain stack of four bilinear-upsample + conv blocks followed
    by a 1x1 classification conv. It has no skip connections and computes no
    spectral indices.
    """

    def __init__(
        self,
        temporal_step: int = 1,
        image_size: int = 224,
        num_classes: int = 2,
        freeze_backbone: bool = True,
    ) -> None:
        """Initialize the PrithviSeg model.

        Downloads the Prithvi weights and configuration (via ``download_file``)
        into ``~/.instageo/prithvi``, builds a ``ViTEncoder`` backbone from the
        configuration, loads the encoder weights, and attaches the decoder.

        Args:
            temporal_step (int): Size of the temporal dimension; written to
                ``model_args["num_frames"]``.
            image_size (int): Size of the input image; written to
                ``model_args["img_size"]``.
            num_classes (int): Number of output channels of the final 1x1 conv.
            freeze_backbone (bool): If True, disable gradients for every
                parameter of the ViT backbone.
        """
        super().__init__()

        # Fetch the pretrained weights and configuration into a per-user cache.
        weights_dir = Path.home() / ".instageo" / "prithvi"
        weights_dir.mkdir(parents=True, exist_ok=True)
        weights_path = weights_dir / "Prithvi_EO_V1_100M.pt"
        cfg_path = weights_dir / "config.yaml"
        download_file(
            "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M/resolve/main/Prithvi_EO_V1_100M.pt?download=true",  # noqa
            weights_path,
        )
        download_file(
            "https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/raw/main/config.yaml",  # noqa
            cfg_path,
        )

        # NOTE(review): torch.load unpickles the downloaded file; consider
        # weights_only=True if the minimum supported torch version allows it.
        checkpoint = torch.load(weights_path, map_location="cpu")
        with open(cfg_path) as f:
            model_config = yaml.safe_load(f)

        model_args = model_config["model_args"]
        model_args["num_frames"] = temporal_step
        model_args["img_size"] = image_size
        self.model_args = model_args

        # Instantiate the ViT encoder backbone.
        model = ViTEncoder(**model_args)
        if freeze_backbone:
            for param in model.parameters():
                param.requires_grad = False

        # Keep only the encoder weights and strip their "encoder." prefix.
        filtered_checkpoint_state_dict = {
            key[len("encoder.") :]: value
            for key, value in checkpoint.items()
            if key.startswith("encoder.")
        }
        # The positional embedding is replaced by zeros sized for the requested
        # temporal/spatial grid (one slot per 16x16 patch per frame, plus the
        # cls token). The width follows the configured embedding size instead
        # of a hard-coded 768.
        filtered_checkpoint_state_dict["pos_embed"] = torch.zeros(
            1,
            temporal_step * (image_size // 16) ** 2 + 1,
            model_args["embed_dim"],
        )
        _ = model.load_state_dict(filtered_checkpoint_state_dict)

        self.prithvi_100M_backbone = model

        # forward() folds the temporal axis into the channel axis, so the
        # decoder must accept embed_dim * num_frames channels. Hard-coding 768
        # here breaks every temporal_step > 1.
        self.decoder = self._build_decoder(
            model_args["embed_dim"] * model_args["num_frames"], num_classes
        )

    def _build_decoder(self, in_channels: int, num_classes: int) -> nn.Sequential:
        """Build four 2x upsampling blocks (halving channels) plus a 1x1 classifier."""
        widths = [in_channels // (2**i) for i in range(5)]
        blocks = [self.upsample_block(widths[i], widths[i + 1]) for i in range(4)]
        return nn.Sequential(
            *blocks,
            nn.Conv2d(widths[-1], num_classes, kernel_size=1),
        )

    def upsample_block(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Return a 2x upsampling block.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.

        Returns:
            Bilinear 2x upsampling followed by a 3x3 conv, batch norm and ReLU.
        """
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Define the forward pass of the model.

        Args:
            img (torch.Tensor): The input tensor passed to the backbone.

        Returns:
            torch.Tensor: Decoder output with ``num_classes`` channels.
        """
        features = self.prithvi_100M_backbone(img)
        # Drop the cls token (first token along dim 1).
        reshaped_features = features[:, 1:, :]
        # Tokens per frame form a square grid; recover its side length.
        feature_img_side_length = int(
            np.sqrt(reshaped_features.shape[1] // self.model_args["num_frames"])
        )
        # (B, tokens, dim) -> (B, dim, tokens) -> (B, dim * frames, side, side).
        reshaped_features = reshaped_features.permute(0, 2, 1).reshape(
            features.shape[0], -1, feature_img_side_length, feature_img_side_length
        )
        out = self.decoder(reshaped_features)
        return out

0 comments on commit 3f2abe9

Please sign in to comment.