13
13
# limitations under the License.
14
14
import numpy as np
15
15
16
- from typing import Tuple
17
-
18
16
from aero_vloc .primitives import UAVSeq
19
17
from aero_vloc .retrieval_system import RetrievalSystem
20
18
@@ -24,7 +22,7 @@ def retrieval_recall(
24
22
retrieval_system : RetrievalSystem ,
25
23
vpr_k_closest : int ,
26
24
feature_matcher_k_closest : int | None ,
27
- ) -> Tuple [ float , list [ bool ]] :
25
+ ) -> np . ndarray :
28
26
"""
29
27
The metric finds the number of correctly matched frames based on retrieval results
30
28
@@ -34,15 +32,18 @@ def retrieval_recall(
34
32
:param feature_matcher_k_closest: Determines how many best images are to be obtained with the feature matcher
35
33
If it is None, then the feature matcher turns off
36
34
37
- :return: Recall value, boolean mask showing which frames were considered as successfully matched
35
+ :return: Array of Recall values for all N < vpr_k_closest,
36
+ or for all N < feature_matcher_k_closest if it is not None
38
37
"""
39
- mask = []
38
+ if feature_matcher_k_closest is not None :
39
+ recalls = np .zeros (feature_matcher_k_closest )
40
+ else :
41
+ recalls = np .zeros (vpr_k_closest )
40
42
for uav_image in uav_seq :
41
- localized = False
42
43
predictions , _ , _ = retrieval_system (
43
44
uav_image , vpr_k_closest , feature_matcher_k_closest
44
45
)
45
- for prediction in predictions :
46
+ for i , prediction in enumerate ( predictions ) :
46
47
map_tile = retrieval_system .sat_map [prediction ]
47
48
if (
48
49
map_tile .top_left_lat
@@ -53,8 +54,9 @@ def retrieval_recall(
53
54
< uav_image .gt_longitude
54
55
< map_tile .bottom_right_lon
55
56
):
56
- localized = True
57
+ recalls [ i :] += 1
57
58
break
58
- mask . append ( localized )
59
+
59
60
retrieval_system .end_of_query_seq ()
60
- return np .sum (mask ) / len (uav_seq .uav_images ), mask
61
+ recalls = recalls / len (uav_seq .uav_images )
62
+ return recalls
0 commit comments