From 36389eff4089d7b3f4a752c054e8fd0714af7177 Mon Sep 17 00:00:00 2001
From: Diego
Date: Mon, 17 Jun 2024 00:52:47 +0200
Subject: [PATCH] feat: handle pillow images as input (#35)

Co-authored-by: Johan Edstedt
---
 .gitignore             |  2 ++
 roma/models/matcher.py | 12 ++++++------
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/.gitignore b/.gitignore
index 58b2f36..7f8801c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,8 @@
 *__pycache__*
 vis*
 workspace*
+.venv
+.DS_Store
 jobs/*
 *ignore_me*
 *.pth
diff --git a/roma/models/matcher.py b/roma/models/matcher.py
index 4497e87..ef5cd01 100644
--- a/roma/models/matcher.py
+++ b/roma/models/matcher.py
@@ -14,6 +14,7 @@
 from roma.utils.local_correlation import local_correlation
 from roma.utils.utils import cls_to_flow_refine
 from roma.utils.kde import kde
+from typing import Union
 
 class ConvRefiner(nn.Module):
     def __init__(
@@ -610,8 +611,8 @@ def recrop(self, certainty, image_path):
     @torch.inference_mode()
     def match(
         self,
-        im_A_path,
-        im_B_path,
+        im_A_path: Union[str, os.PathLike, Image.Image],
+        im_B_path: Union[str, os.PathLike, Image.Image],
         *args,
         batched=False,
         device = None,
@@ -621,8 +622,8 @@ def match(
         if isinstance(im_A_path, (str, os.PathLike)):
             im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
         else:
-            # Assume its not a path
-            im_A, im_B = im_A_path, im_B_path
+            im_A, im_B = im_A_path, im_B_path
+
         symmetric = self.symmetric
         self.train(False)
         with torch.no_grad():
@@ -672,13 +673,12 @@
                     resize=(hs, ws), normalize=True
                 )
                 if self.recrop_upsample:
+                    raise NotImplementedError("recrop_upsample not implemented")
                     certainty = corresps[finest_scale]["certainty"]
                     print(certainty.shape)
                     im_A = self.recrop(certainty[0,0], im_A_path)
                     im_B = self.recrop(certainty[1,0], im_B_path)
                     #TODO: need to adjust corresps when doing this
-                else:
-                    im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
                 im_A, im_B = test_transform((im_A, im_B))
                 im_A, im_B = im_A[None].to(device), im_B[None].to(device)
                 scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
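
Usage note (not part of the patch): a minimal sketch of the new call pattern enabled by this change, passing in-memory PIL images to `match()` instead of file paths. It assumes the `roma_outdoor` entry point and the `sample`/`to_pixel_coordinates` helpers described in the project README; the image file names below are placeholders.

```python
import torch
from PIL import Image
from roma import roma_outdoor

device = "cuda" if torch.cuda.is_available() else "cpu"
roma_model = roma_outdoor(device=device)

# Before this patch only str/os.PathLike paths were accepted; with it, already
# decoded PIL images (e.g. frames from a video stream) can be passed directly.
im_A = Image.open("imA.jpg").convert("RGB")  # placeholder path
im_B = Image.open("imB.jpg").convert("RGB")  # placeholder path

warp, certainty = roma_model.match(im_A, im_B, device=device)
matches, certainty = roma_model.sample(warp, certainty)
kpts_A, kpts_B = roma_model.to_pixel_coordinates(
    matches, im_A.height, im_A.width, im_B.height, im_B.width
)
```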