diff --git a/source/MaaFramework/Vision/NeuralNetworkDetector.cpp b/source/MaaFramework/Vision/NeuralNetworkDetector.cpp index ac864b218..70d8066cd 100644 --- a/source/MaaFramework/Vision/NeuralNetworkDetector.cpp +++ b/source/MaaFramework/Vision/NeuralNetworkDetector.cpp @@ -34,7 +34,7 @@ void NeuralNetworkDetector::analyze() auto start_time = std::chrono::steady_clock::now(); auto results = detect(); - add_results(std::move(results), param_.expected); + add_results(std::move(results), param_.expected, param_.thresholds); cherry_pick(); @@ -146,11 +146,26 @@ NeuralNetworkDetector::ResultsVec NeuralNetworkDetector::detect() const return nms_results; } -void NeuralNetworkDetector::add_results(ResultsVec results, const std::vector& expected) +void NeuralNetworkDetector::add_results(ResultsVec results, const std::vector& expected, const std::vector& thresholds) { - std::ranges::copy_if(results, std::back_inserter(filtered_results_), [&](const auto& res) { - return std::ranges::find(expected, res.cls_index) != expected.end(); - }); + if (expected.size() != thresholds.size()) { + LogError << name_ << VAR(uid_) << "expected.size() != thresholds.size()" << VAR(expected) << VAR(thresholds); + return; + } + + for (size_t i = 0; i != expected.size(); ++i) { + size_t exp = expected.at(i); + auto it = std::ranges::find(results, exp, std::mem_fn(&Result::cls_index)); + if (it == results.end()) { + continue; + } + const Result& res = *it; + double thres = thresholds.at(i); + if (res.score < thres) { + continue; + } + filtered_results_.emplace_back(res); + } merge_vector_(all_results_, std::move(results)); } diff --git a/source/MaaFramework/Vision/NeuralNetworkDetector.h b/source/MaaFramework/Vision/NeuralNetworkDetector.h index 96f1a26ff..46427ce54 100644 --- a/source/MaaFramework/Vision/NeuralNetworkDetector.h +++ b/source/MaaFramework/Vision/NeuralNetworkDetector.h @@ -41,7 +41,7 @@ class NeuralNetworkDetector ResultsVec detect() const; - void add_results(ResultsVec results, const std::vector& expected); + void add_results(ResultsVec results, const std::vector& expected, const std::vector& thresholds); void cherry_pick(); private: