Skip to content

Commit

Permalink
refactor: move visualisation function to utils and clean up imports
Browse files Browse the repository at this point in the history
  • Loading branch information
VojtechCermak committed Nov 27, 2024
1 parent b72bd0c commit bac0673
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 23 deletions.
23 changes: 0 additions & 23 deletions wildlife_tools/similarity/pairwise/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import itertools

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
Expand All @@ -10,27 +8,6 @@
from .collectors import CollectCounts


def visualise_matches(img0, keypoints0, img1, keypoints1):
keypoints0 = [cv2.KeyPoint(int(x[0]), int(x[1]), 1) for x in keypoints0]
keypoints1 = [cv2.KeyPoint(int(x[0]), int(x[1]), 1) for x in keypoints1]

# Create dummy matches (DMatch objects)
matches = [cv2.DMatch(i, i, 0) for i in range(len(keypoints0))]

# Draw matches
img_matches = cv2.drawMatches(
img0,
keypoints0,
img1,
keypoints1,
matches,
None,
flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
)
plt.imshow(img_matches)
plt.show()


class PairDataset(torch.utils.data.IterableDataset):
"""
Create iterable style dataset from two mapping style datasets.
Expand Down
44 changes: 44 additions & 0 deletions wildlife_tools/similarity/pairwise/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


def visualise_matches(img0: Image, keypoints0: np.ndarray, img1: Image, keypoints1: list, ax=None):
"""
Visualise matches between two images.
Args:
img0 (np.array or PIL Image): First image.
keypoints0 (np.array): Keypoints in the first image.
img1 (np.array): Second image.
keypoints1 (np.array): Keypoints in the second image.
ax (matplotlib.axes.Axes, optional): Matplotlib axis to draw on. If None, a new axis is created.
"""

# Convert images to numpy arrays
img0 = np.array(img0)
img1 = np.array(img1)

keypoints0 = [cv2.KeyPoint(int(x[0]), int(x[1]), 1) for x in keypoints0]
keypoints1 = [cv2.KeyPoint(int(x[0]), int(x[1]), 1) for x in keypoints1]

# Create dummy matches (DMatch objects)
matches = [cv2.DMatch(i, i, 0) for i in range(len(keypoints0))]

# Draw matches
img_matches = cv2.drawMatches(
img0,
keypoints0,
img1,
keypoints1,
matches,
None,
flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
)

# Plotting
if ax is None:
_, ax = plt.subplots()
ax.imshow(img_matches)
ax.axis("off")

0 comments on commit bac0673

Please sign in to comment.