Skip to content

Commit ba5d22b

Browse files
committed
metrics updated
1 parent 75b9742 commit ba5d22b

File tree

2 files changed

+18
-21
lines changed

2 files changed

+18
-21
lines changed

aero_vloc/metrics/reference_recall.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import numpy as np
15-
16-
from typing import Tuple
17-
1814
from aero_vloc.localization_pipeline import LocalizationPipeline
1915
from aero_vloc.metrics.utils import calculate_distance
2016
from aero_vloc.primitives import UAVSeq
@@ -25,7 +21,7 @@ def reference_recall(
2521
localization_pipeline: LocalizationPipeline,
2622
k_closest: int,
2723
threshold: int,
28-
) -> Tuple[float, list[bool]]:
24+
) -> float:
2925
"""
3026
The metric finds the number of correctly matched frames based on georeference error
3127
@@ -36,17 +32,16 @@ def reference_recall(
3632
:param threshold: The distance between query and reference geocoordinates,
3733
below which the frame will be considered correctly matched
3834
39-
:return: Recall value, boolean mask showing which frames were considered as successfully matched
35+
:return: Recall value
4036
"""
41-
mask = []
37+
recall_value = 0
4238
localization_results = localization_pipeline(uav_seq, k_closest)
4339
for loc_res, uav_image in zip(localization_results, uav_seq):
4440
if loc_res is not None:
4541
lat, lon = loc_res
4642
error = calculate_distance(
4743
lat, lon, uav_image.gt_latitude, uav_image.gt_longitude
4844
)
49-
mask.append(error < threshold)
50-
else:
51-
mask.append(False)
52-
return np.sum(mask) / len(uav_seq.uav_images), mask
45+
if error < threshold:
46+
recall_value += 1
47+
return recall_value / len(uav_seq.uav_images)

aero_vloc/metrics/retrieval_recall.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414
import numpy as np
1515

16-
from typing import Tuple
17-
1816
from aero_vloc.primitives import UAVSeq
1917
from aero_vloc.retrieval_system import RetrievalSystem
2018

@@ -24,7 +22,7 @@ def retrieval_recall(
2422
retrieval_system: RetrievalSystem,
2523
vpr_k_closest: int,
2624
feature_matcher_k_closest: int | None,
27-
) -> Tuple[float, list[bool]]:
25+
) -> np.ndarray:
2826
"""
2927
The metric finds the number of correctly matched frames based on retrieval results
3028
@@ -34,15 +32,18 @@ def retrieval_recall(
3432
:param feature_matcher_k_closest: Determines how many best images are to be obtained with the feature matcher
3533
If it is None, then the feature matcher turns off
3634
37-
:return: Recall value, boolean mask showing which frames were considered as successfully matched
35+
:return: Array of Recall values for all N < vpr_k_closest,
36+
or for all N < feature_matcher_k_closest if it is not None
3837
"""
39-
mask = []
38+
if feature_matcher_k_closest is not None:
39+
recalls = np.zeros(feature_matcher_k_closest)
40+
else:
41+
recalls = np.zeros(vpr_k_closest)
4042
for uav_image in uav_seq:
41-
localized = False
4243
predictions, _, _ = retrieval_system(
4344
uav_image, vpr_k_closest, feature_matcher_k_closest
4445
)
45-
for prediction in predictions:
46+
for i, prediction in enumerate(predictions):
4647
map_tile = retrieval_system.sat_map[prediction]
4748
if (
4849
map_tile.top_left_lat
@@ -53,8 +54,9 @@ def retrieval_recall(
5354
< uav_image.gt_longitude
5455
< map_tile.bottom_right_lon
5556
):
56-
localized = True
57+
recalls[i:] += 1
5758
break
58-
mask.append(localized)
59+
5960
retrieval_system.end_of_query_seq()
60-
return np.sum(mask) / len(uav_seq.uav_images), mask
61+
recalls = recalls / len(uav_seq.uav_images)
62+
return recalls

0 commit comments

Comments
 (0)