Skip to content

Commit

Permalink
chore: change naming
Browse files Browse the repository at this point in the history
  • Loading branch information
VojtechCermak committed Nov 12, 2024
1 parent eea8a04 commit 3f4028d
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions wildlife_tools/similarity/wildfusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,49 +129,49 @@ class WildFusion:
WildFusion can be used with a limited computational budget by applying it only B times per query
image. It uses a fast-to-compute similarity score (e.g., cosine similarity of deep features) provided
by the priority_matcher to construct a shortlist of the most promising matches for a given query.
by the priority_pipeline to construct a shortlist of the most promising matches for a given query.
Final ranking is then based on WildFusion scores calculated for the pairs in the shortlist.
"""

def __init__(
self,
calibrated_matchers: list[SimilarityPipeline],
priority_matcher: SimilarityPipeline | None = None,
calibrated_pipelines: list[SimilarityPipeline],
priority_pipeline: SimilarityPipeline | None = None,
):
"""
Args:
calibrated_matchers (list[SimilarityPipeline]): List of SimilarityPipeline objects.
priority_matcher (SimilarityPipeline, optional): Fast-to-compute similarity matcher
calibrated_pipelines (list[SimilarityPipeline]): List of SimilarityPipeline objects.
priority_pipeline (SimilarityPipeline, optional): Fast-to-compute similarity matcher
used for shortlisting.
"""

self.calibrated_matchers = calibrated_matchers
self.priority_matcher = priority_matcher
self.calibrated_pipelines = calibrated_pipelines
self.priority_pipeline = priority_pipeline

def fit_calibration(self, dataset0: ImageDataset, dataset1: ImageDataset):
"""
Fit the all calibration models for all matchers in `calibrated_matchers`.
Fit the all calibration models for all matchers in `calibrated_pipelines`.
Args:
dataset0 (ImageDataset): The first dataset (e.g., part of training set).
dataset1 (ImageDataset): The second dataset (e.g., part of training set).
"""

for matcher in self.calibrated_matchers:
for matcher in self.calibrated_pipelines:
matcher.fit_calibration(dataset0, dataset1)

if self.priority_matcher is not None:
self.priority_matcher.fit_calibration(dataset0, dataset1)
if self.priority_pipeline is not None:
self.priority_pipeline.fit_calibration(dataset0, dataset1)

def get_priority_pairs(
self, dataset0: ImageDataset, dataset1: ImageDataset, B: int
) -> np.ndarray:
"""Implements shortlisting strategy for selection of most relevant pairs."""

if self.priority_matcher is None:
if self.priority_pipeline is None:
raise ValueError("Priority matcher is not assigned.")

priority = self.priority_matcher(dataset0, dataset1)
priority = self.priority_pipeline(dataset0, dataset1)
_, idx1 = torch.topk(torch.tensor(priority), min(B, priority.shape[1]))
idx0 = np.indices(idx1.numpy().shape)[0]
grid_indices = np.stack([idx0.flatten(), idx1.flatten()]).T
Expand All @@ -197,7 +197,7 @@ def __call__(
pairs (list of tuples, optional): Specific pairs of images to compute similarity scores.
If None, compute similarity scores for all pairs.
Is ignored if `B` is provided.
B (int, optional): Number of pairs to compute similarity scores for. Required `priority_matcher` to be assigned.
B (int, optional): Number of pairs to compute similarity scores for. Required `priority_pipeline` to be assigned.
If None, compute similarity scores for all pairs.
Returns:
Expand All @@ -209,7 +209,7 @@ def __call__(
pairs = self.get_priority_pairs(dataset0, dataset1, B=B)

scores = []
for matcher in self.calibrated_matchers:
for matcher in self.calibrated_pipelines:
scores.append(matcher(dataset0, dataset1, pairs=pairs))

score_combined = np.mean(scores, axis=0)
Expand Down

0 comments on commit 3f4028d

Please sign in to comment.