diff --git a/tests/metrics/test_reference_recall.py b/tests/metrics/test_reference_recall.py index 1ecf6ef..a44bae8 100644 --- a/tests/metrics/test_reference_recall.py +++ b/tests/metrics/test_reference_recall.py @@ -11,12 +11,11 @@ def test_reference_recall(): the recall value should be equal to 0.5 """ localization_pipeline = create_localization_pipeline() - recall, mask = avl.reference_recall( + recall = avl.reference_recall( queries, localization_pipeline, k_closest=2, threshold=10 ) assert np.isclose(recall, 0.5) - assert mask == [True, False] def test_reference_recall_low_threshold(): @@ -26,9 +25,8 @@ def test_reference_recall_low_threshold(): so the recall value should be equal to 0 """ localization_pipeline = create_localization_pipeline() - recall, mask = avl.reference_recall( + recall = avl.reference_recall( queries, localization_pipeline, k_closest=2, threshold=1 ) assert np.isclose(recall, 0) - assert mask == [False, False] diff --git a/tests/metrics/test_retrieval_recall.py b/tests/metrics/test_retrieval_recall.py index d9ab110..6e97210 100644 --- a/tests/metrics/test_retrieval_recall.py +++ b/tests/metrics/test_retrieval_recall.py @@ -12,9 +12,8 @@ def test_retrieval_recall(): """ localization_pipeline = create_localization_pipeline() retrieval_system = localization_pipeline.retrieval_system - recall, mask = avl.retrieval_recall( + recalls = avl.retrieval_recall( queries, retrieval_system, vpr_k_closest=2, feature_matcher_k_closest=1 ) - assert np.isclose(recall, 0.5) - assert mask == [True, False] + assert np.isclose(recalls[0], 0.5)