feat: handle pillow images as input (#35)
Co-authored-by: Johan Edstedt <johan.edstedt@liu.se>
dgcnz and Parskatt committed Jun 16, 2024
1 parent 2d869bb commit 36389ef
Showing 2 changed files with 8 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,6 +3,8 @@
*__pycache__*
vis*
workspace*
.venv
.DS_Store
jobs/*
*ignore_me*
*.pth
12 changes: 6 additions & 6 deletions roma/models/matcher.py
@@ -14,6 +14,7 @@
from roma.utils.local_correlation import local_correlation
from roma.utils.utils import cls_to_flow_refine
from roma.utils.kde import kde
from typing import Union

class ConvRefiner(nn.Module):
    def __init__(
@@ -610,8 +611,8 @@ def recrop(self, certainty, image_path):
    @torch.inference_mode()
    def match(
        self,
        im_A_path,
        im_B_path,
        im_A_path: Union[str, os.PathLike, Image.Image],
        im_B_path: Union[str, os.PathLike, Image.Image],
        *args,
        batched=False,
        device = None,
@@ -621,8 +622,8 @@
        if isinstance(im_A_path, (str, os.PathLike)):
            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
        else:
            # Assume its not a path
            im_A, im_B = im_A_path, im_B_path
            im_A, im_B = im_A_path, im_B_path

symmetric = self.symmetric
self.train(False)
with torch.no_grad():
@@ -672,13 +673,12 @@ def match(
                    resize=(hs, ws), normalize=True
                )
                if self.recrop_upsample:
                    raise NotImplementedError("recrop_upsample not implemented")
                    certainty = corresps[finest_scale]["certainty"]
                    print(certainty.shape)
                    im_A = self.recrop(certainty[0,0], im_A_path)
                    im_B = self.recrop(certainty[1,0], im_B_path)
                    #TODO: need to adjust corresps when doing this
                else:
                    im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
                im_A, im_B = test_transform((im_A, im_B))
                im_A, im_B = im_A[None].to(device), im_B[None].to(device)
                scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
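
With this change, match() accepts either filesystem paths or already-loaded PIL images for im_A_path and im_B_path. A minimal usage sketch (assuming the roma_outdoor loader shown in the repository README; the image filenames are placeholders):

from PIL import Image
from roma import roma_outdoor

# Load the pretrained outdoor model (assumption: the roma_outdoor entry point from the README).
roma_model = roma_outdoor(device="cuda")

# Paths still work as before.
warp, certainty = roma_model.match("im_A.jpg", "im_B.jpg", device="cuda")

# After this commit, PIL images are accepted directly as well.
im_A = Image.open("im_A.jpg").convert("RGB")
im_B = Image.open("im_B.jpg").convert("RGB")
warp, certainty = roma_model.match(im_A, im_B, device="cuda")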
