diff --git a/scripts_and_studies/coregistration_study/coregistration_study_utils.py b/scripts_and_studies/coregistration_study/coregistration_study_utils.py index 2974bd3..cf32781 100644 --- a/scripts_and_studies/coregistration_study/coregistration_study_utils.py +++ b/scripts_and_studies/coregistration_study/coregistration_study_utils.py @@ -15,6 +15,9 @@ from pyraws.utils.visualization_utils import equalize_tensor import matplotlib.cm as cm import torch +import kornia +from kornia.feature import * + def get_shift_SuperGlue_profiling( @@ -112,6 +115,120 @@ def get_shift_SuperGlue_profiling( return shift_mean +def get_shift_SIFT_profiling( + b0, + b1, + equalize=True, + n_std=2, + device=torch.device("cpu"), +): + """Get a shift between two bands of a specific event by using SuperGlue. + + Args: + b0 (torch.tensor): tensor containing band 0 to coregister. + b1 (torch.tensor): tensor containing band 1 to coregister. + equalize (bool, optional): if True, equalization is performed. Defaults to True. + n_std (int, optional): Outliers are saturated for equalization at histogram_mean*- n_std * histogram_std. + Defaults to 2. + device (torch.device, optional): torch.device. Defaults to torch.device("cpu"). + + Returns: + float: mean value of the shift. + torch.tensor: band 0. + torch.tensor: band 1. + dict: granule info. + float: number of matched kyepoints. + """ + # Aux: + def compute_offsets(image1, image2, device, verbose = False): + assert len(image1.shape) == 2, 'Error with shapes of image1' + assert len(image2.shape) == 2, 'Error with shapes of image2' + + PS = 16 + # Initialize SIFT descriptor + sift = kornia.feature.SIFTDescriptor(PS, rootsift=True).to(device) + descriptor = sift + + # Set up components for feature detection + resp = kornia.feature.BlobDoG() + scale_pyr = kornia.geometry.ScalePyramid(3, 1.6, PS, double_image=True) + nms = kornia.geometry.ConvQuadInterp3d(10) + n_features = 4000 + detector = kornia.feature.ScaleSpaceDetector( + n_features, + resp_module=resp, + scale_space_response=True, # Required for DoG + nms_module=nms, + scale_pyr_module=scale_pyr, + ori_module=kornia.feature.LAFOrienter(19), + mr_size=6.0, + minima_are_also_good=True + ).to(device) + + # Process each image + def process_image(img): + with torch.no_grad(): + lafs, _ = detector(img) + patches = kornia.feature.extract_patches_from_pyramid(img, lafs, PS) + B, N, CH, H, W = patches.size() + descs = descriptor(patches.view(B * N, CH, H, W)).view(B, N, -1) + return lafs, descs + + lafs1, descs1 = process_image(image1.unsqueeze(0).unsqueeze(0)) + lafs2, descs2 = process_image(image2.unsqueeze(0).unsqueeze(0)) + # Match features between the two images + scores, matches = kornia.feature.match_snn(descs1[0], descs2[0], 0.95) + + # Compute Homography and inliers + src_pts = lafs1[0, matches[:, 0], :, 2].data.cpu().numpy() + dst_pts = lafs2[0, matches[:, 1], :, 2].data.cpu().numpy() + return src_pts, dst_pts + + bands = torch.zeros([b0.shape[0], b0.shape[1], 2], device=device) + bands[:, :, 0] = b0 + bands[:, :, 1] = b1 + if equalize: + l0_granule_tensor_equalized = equalize_tensor(bands[:, :, :2], n_std) + b0 = ( + l0_granule_tensor_equalized[:, :, 0] + / l0_granule_tensor_equalized[:, :, 0].max() + ) + b1 = ( + l0_granule_tensor_equalized[:, :, 1] + / l0_granule_tensor_equalized[:, :, 1].max() + ) + else: + b0 = bands[:, :, 0] / bands[:, :, 0].max() + b1 = bands[:, :, 1] / bands[:, :, 1].max() + + mkpts0, mkpts1 = compute_offsets(b0, b1, device=device) + + if len(mkpts1) and len(mkpts0): + shift = torch.tensor([x - y for (x, y) in zip(mkpts1, mkpts0)], device=device) + shift_v, shift_h = shift[:, 0], shift[:, 1] + shift_v_mean, shift_v_std = torch.mean(shift_v), torch.std(shift_v) + shift_h_mean, shift_h_std = torch.mean(shift_h), torch.std(shift_h) + shift_v = shift_v[ + torch.logical_and( + shift_v > shift_v_mean - shift_v_std, + shift_v < shift_v_mean + shift_v_std, + ) + ] + shift_h = shift_h[ + torch.logical_and( + shift_h > shift_h_mean - shift_h_std, + shift_h < shift_h_mean + shift_h_std, + ) + ] + shift_mean = torch.round( + torch.tensor([-shift_h.mean(), -shift_v.mean()], device=device) + ) + else: + return [None, None] + return shift_mean + + + def get_shift_SuperGlue( event_name, raw_granule,