Skip to content

Commit

Permalink
Pushing the function get_shift_SIFT_profiling
Browse files Browse the repository at this point in the history
  • Loading branch information
sirbastiano authored Nov 16, 2023
1 parent 57458d8 commit 671b98e
Showing 1 changed file with 117 additions and 0 deletions.
117 changes: 117 additions & 0 deletions scripts_and_studies/coregistration_study/coregistration_study_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from pyraws.utils.visualization_utils import equalize_tensor
import matplotlib.cm as cm
import torch
import kornia
from kornia.feature import *



def get_shift_SuperGlue_profiling(
Expand Down Expand Up @@ -112,6 +115,120 @@ def get_shift_SuperGlue_profiling(
return shift_mean


def get_shift_SIFT_profiling(
b0,
b1,
equalize=True,
n_std=2,
device=torch.device("cpu"),
):
"""Get a shift between two bands of a specific event by using SuperGlue.
Args:
b0 (torch.tensor): tensor containing band 0 to coregister.
b1 (torch.tensor): tensor containing band 1 to coregister.
equalize (bool, optional): if True, equalization is performed. Defaults to True.
n_std (int, optional): Outliers are saturated for equalization at histogram_mean*- n_std * histogram_std.
Defaults to 2.
device (torch.device, optional): torch.device. Defaults to torch.device("cpu").
Returns:
float: mean value of the shift.
torch.tensor: band 0.
torch.tensor: band 1.
dict: granule info.
float: number of matched kyepoints.
"""
# Aux:
def compute_offsets(image1, image2, device, verbose = False):
assert len(image1.shape) == 2, 'Error with shapes of image1'
assert len(image2.shape) == 2, 'Error with shapes of image2'

PS = 16
# Initialize SIFT descriptor
sift = kornia.feature.SIFTDescriptor(PS, rootsift=True).to(device)
descriptor = sift

# Set up components for feature detection
resp = kornia.feature.BlobDoG()
scale_pyr = kornia.geometry.ScalePyramid(3, 1.6, PS, double_image=True)
nms = kornia.geometry.ConvQuadInterp3d(10)
n_features = 4000
detector = kornia.feature.ScaleSpaceDetector(
n_features,
resp_module=resp,
scale_space_response=True, # Required for DoG
nms_module=nms,
scale_pyr_module=scale_pyr,
ori_module=kornia.feature.LAFOrienter(19),
mr_size=6.0,
minima_are_also_good=True
).to(device)

# Process each image
def process_image(img):
with torch.no_grad():
lafs, _ = detector(img)
patches = kornia.feature.extract_patches_from_pyramid(img, lafs, PS)
B, N, CH, H, W = patches.size()
descs = descriptor(patches.view(B * N, CH, H, W)).view(B, N, -1)
return lafs, descs

lafs1, descs1 = process_image(image1.unsqueeze(0).unsqueeze(0))
lafs2, descs2 = process_image(image2.unsqueeze(0).unsqueeze(0))
# Match features between the two images
scores, matches = kornia.feature.match_snn(descs1[0], descs2[0], 0.95)

# Compute Homography and inliers
src_pts = lafs1[0, matches[:, 0], :, 2].data.cpu().numpy()
dst_pts = lafs2[0, matches[:, 1], :, 2].data.cpu().numpy()
return src_pts, dst_pts

bands = torch.zeros([b0.shape[0], b0.shape[1], 2], device=device)
bands[:, :, 0] = b0
bands[:, :, 1] = b1
if equalize:
l0_granule_tensor_equalized = equalize_tensor(bands[:, :, :2], n_std)
b0 = (
l0_granule_tensor_equalized[:, :, 0]
/ l0_granule_tensor_equalized[:, :, 0].max()
)
b1 = (
l0_granule_tensor_equalized[:, :, 1]
/ l0_granule_tensor_equalized[:, :, 1].max()
)
else:
b0 = bands[:, :, 0] / bands[:, :, 0].max()
b1 = bands[:, :, 1] / bands[:, :, 1].max()

mkpts0, mkpts1 = compute_offsets(b0, b1, device=device)

if len(mkpts1) and len(mkpts0):
shift = torch.tensor([x - y for (x, y) in zip(mkpts1, mkpts0)], device=device)
shift_v, shift_h = shift[:, 0], shift[:, 1]
shift_v_mean, shift_v_std = torch.mean(shift_v), torch.std(shift_v)
shift_h_mean, shift_h_std = torch.mean(shift_h), torch.std(shift_h)
shift_v = shift_v[
torch.logical_and(
shift_v > shift_v_mean - shift_v_std,
shift_v < shift_v_mean + shift_v_std,
)
]
shift_h = shift_h[
torch.logical_and(
shift_h > shift_h_mean - shift_h_std,
shift_h < shift_h_mean + shift_h_std,
)
]
shift_mean = torch.round(
torch.tensor([-shift_h.mean(), -shift_v.mean()], device=device)
)
else:
return [None, None]
return shift_mean



def get_shift_SuperGlue(
event_name,
raw_granule,
Expand Down

0 comments on commit 671b98e

Please sign in to comment.