Skip to content

Commit

Permalink
fix: nndetect threshold not work
Browse files Browse the repository at this point in the history
  • Loading branch information
MistEO committed Sep 29, 2024
1 parent d184a65 commit ef707e4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
25 changes: 20 additions & 5 deletions source/MaaFramework/Vision/NeuralNetworkDetector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void NeuralNetworkDetector::analyze()
auto start_time = std::chrono::steady_clock::now();

auto results = detect();
add_results(std::move(results), param_.expected);
add_results(std::move(results), param_.expected, param_.thresholds);

cherry_pick();

Expand Down Expand Up @@ -146,11 +146,26 @@ NeuralNetworkDetector::ResultsVec NeuralNetworkDetector::detect() const
return nms_results;
}

void NeuralNetworkDetector::add_results(ResultsVec results, const std::vector<size_t>& expected)
void NeuralNetworkDetector::add_results(ResultsVec results, const std::vector<size_t>& expected, const std::vector<double>& thresholds)
{
std::ranges::copy_if(results, std::back_inserter(filtered_results_), [&](const auto& res) {
return std::ranges::find(expected, res.cls_index) != expected.end();
});
if (expected.size() != thresholds.size()) {
LogError << name_ << VAR(uid_) << "expected.size() != thresholds.size()" << VAR(expected) << VAR(thresholds);
return;
}

for (size_t i = 0; i != expected.size(); ++i) {
size_t exp = expected.at(i);
auto it = std::ranges::find(results, exp, std::mem_fn(&Result::cls_index));
if (it == results.end()) {
continue;
}
const Result& res = *it;
double thres = thresholds.at(i);
if (res.score < thres) {
continue;
}
filtered_results_.emplace_back(res);
}

merge_vector_(all_results_, std::move(results));
}
Expand Down
2 changes: 1 addition & 1 deletion source/MaaFramework/Vision/NeuralNetworkDetector.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class NeuralNetworkDetector

ResultsVec detect() const;

void add_results(ResultsVec results, const std::vector<size_t>& expected);
void add_results(ResultsVec results, const std::vector<size_t>& expected, const std::vector<double>& thresholds);
void cherry_pick();

private:
Expand Down

0 comments on commit ef707e4

Please sign in to comment.