From bac067330c092e263270791cc7110d9041075ab2 Mon Sep 17 00:00:00 2001 From: Vojtech Cermak Date: Wed, 27 Nov 2024 11:43:07 +0100 Subject: [PATCH] refactor: move visualisation function to utils and clean up imports --- wildlife_tools/similarity/pairwise/base.py | 23 ----------- wildlife_tools/similarity/pairwise/utils.py | 44 +++++++++++++++++++++ 2 files changed, 44 insertions(+), 23 deletions(-) create mode 100644 wildlife_tools/similarity/pairwise/utils.py diff --git a/wildlife_tools/similarity/pairwise/base.py b/wildlife_tools/similarity/pairwise/base.py index bd89c87..460a7c7 100644 --- a/wildlife_tools/similarity/pairwise/base.py +++ b/wildlife_tools/similarity/pairwise/base.py @@ -1,7 +1,5 @@ import itertools -import cv2 -import matplotlib.pyplot as plt import numpy as np import torch from tqdm import tqdm @@ -10,27 +8,6 @@ from .collectors import CollectCounts -def visualise_matches(img0, keypoints0, img1, keypoints1): - keypoints0 = [cv2.KeyPoint(int(x[0]), int(x[1]), 1) for x in keypoints0] - keypoints1 = [cv2.KeyPoint(int(x[0]), int(x[1]), 1) for x in keypoints1] - - # Create dummy matches (DMatch objects) - matches = [cv2.DMatch(i, i, 0) for i in range(len(keypoints0))] - - # Draw matches - img_matches = cv2.drawMatches( - img0, - keypoints0, - img1, - keypoints1, - matches, - None, - flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS, - ) - plt.imshow(img_matches) - plt.show() - - class PairDataset(torch.utils.data.IterableDataset): """ Create iterable style dataset from two mapping style datasets. diff --git a/wildlife_tools/similarity/pairwise/utils.py b/wildlife_tools/similarity/pairwise/utils.py new file mode 100644 index 0000000..0a252a8 --- /dev/null +++ b/wildlife_tools/similarity/pairwise/utils.py @@ -0,0 +1,44 @@ +import cv2 +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image + + +def visualise_matches(img0: Image, keypoints0: np.ndarray, img1: Image, keypoints1: list, ax=None): + """ + Visualise matches between two images. + + Args: + img0 (np.array or PIL Image): First image. + keypoints0 (np.array): Keypoints in the first image. + img1 (np.array): Second image. + keypoints1 (np.array): Keypoints in the second image. + ax (matplotlib.axes.Axes, optional): Matplotlib axis to draw on. If None, a new axis is created. + """ + + # Convert images to numpy arrays + img0 = np.array(img0) + img1 = np.array(img1) + + keypoints0 = [cv2.KeyPoint(int(x[0]), int(x[1]), 1) for x in keypoints0] + keypoints1 = [cv2.KeyPoint(int(x[0]), int(x[1]), 1) for x in keypoints1] + + # Create dummy matches (DMatch objects) + matches = [cv2.DMatch(i, i, 0) for i in range(len(keypoints0))] + + # Draw matches + img_matches = cv2.drawMatches( + img0, + keypoints0, + img1, + keypoints1, + matches, + None, + flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS, + ) + + # Plotting + if ax is None: + _, ax = plt.subplots() + ax.imshow(img_matches) + ax.axis("off")