Skip to content

Commit cfcf868

Browse files
committed
refactor: nn
1 parent 7f23867 commit cfcf868

File tree

5 files changed

+167
-136
lines changed

5 files changed

+167
-136
lines changed

source/MaaFramework/Instance/InstanceInternalAPI.hpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,22 @@ class ControllerAgent;
1717
MAA_CTRL_NS_END
1818

1919
MAA_NS_BEGIN
20-
class InstanceStatus;
21-
MAA_NS_END
2220

23-
MAA_VISION_NS_BEGIN
24-
class CustomRecognizer;
25-
using CustomRecognizerPtr = std::shared_ptr<CustomRecognizer>;
26-
MAA_VISION_NS_END
21+
class InstanceStatus;
2722

28-
MAA_TASK_NS_BEGIN
29-
class CustomAction;
30-
using CustomActionPtr = std::shared_ptr<CustomAction>;
31-
MAA_TASK_NS_END
23+
struct CustomRecognizerSession
24+
{
25+
MaaCustomRecognizerHandle recognizer_ = nullptr;
26+
MaaTransparentArg recognizer_arg_ = nullptr;
27+
InstanceInternalAPI* inst_ = nullptr;
28+
};
3229

33-
MAA_NS_BEGIN
30+
struct CustomActionSession
31+
{
32+
MaaCustomActionHandle action_ = nullptr;
33+
MaaTransparentArg action_arg_ = nullptr;
34+
InstanceInternalAPI* inst_ = nullptr;
35+
};
3436

3537
struct InstanceInternalAPI : public NonCopyable
3638
{
@@ -39,8 +41,8 @@ struct InstanceInternalAPI : public NonCopyable
3941
virtual MAA_CTRL_NS::ControllerAgent* inter_controller() = 0;
4042
virtual InstanceStatus* inter_status() = 0;
4143
virtual void notify(std::string_view msg, const json::value& details = json::value()) = 0;
42-
virtual MAA_VISION_NS::CustomRecognizerPtr custom_recognizer(const std::string& name) = 0;
43-
virtual MAA_TASK_NS::CustomActionPtr custom_action(const std::string& name) = 0;
44+
virtual CustomRecognizerSession custom_recognizer(const std::string& name) = 0;
45+
virtual CustomActionSession custom_action(const std::string& name) = 0;
4446
};
4547

46-
MAA_NS_END
48+
MAA_NS_END

source/MaaFramework/Vision/NeuralNetworkClassifier.cpp

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -8,56 +8,63 @@
88

99
MAA_VISION_NS_BEGIN
1010

11-
std::pair<NeuralNetworkClassifier::ResultsVec, size_t> NeuralNetworkClassifier::analyze() const
11+
NeuralNetworkClassifier::NeuralNetworkClassifier(
12+
cv::Mat image,
13+
NeuralNetworkClassifierParam param,
14+
std::shared_ptr<Ort::Session> session,
15+
std::string name)
16+
: VisionBase(std::move(image), std::move(name))
17+
, param_(std::move(param))
18+
, session_(std::move(session))
19+
{
20+
analyze();
21+
}
22+
23+
void NeuralNetworkClassifier::analyze()
1224
{
1325
LogFunc << name_;
1426

1527
if (!session_) {
1628
LogError << "OrtSession not loaded";
17-
return {};
29+
return;
1830
}
1931
if (param_.cls_size == 0) {
2032
LogError << "cls_size == 0";
21-
return {};
33+
return;
2234
}
2335
if (param_.cls_size != param_.labels.size()) {
2436
LogError << "cls_size != labels.size()" << VAR(param_.cls_size)
2537
<< VAR(param_.labels.size());
26-
return {};
38+
return;
2739
}
2840

2941
auto start_time = std::chrono::steady_clock::now();
30-
ResultsVec results = foreach_rois();
31-
auto cost = duration_since(start_time);
32-
LogTrace << name_ << "Raw:" << VAR(results) << VAR(cost);
3342

34-
const auto& expected = param_.expected;
35-
filter(results, expected);
43+
auto results = classify_all_rois();
44+
add_results(std::move(results), param_.expected);
3645

37-
cost = duration_since(start_time);
38-
LogTrace << name_ << "Filter:" << VAR(results) << VAR(expected) << VAR(cost);
46+
sort();
3947

40-
sort(results);
41-
size_t index = preferred_index(results);
42-
return { results, index };
48+
auto cost = duration_since(start_time);
49+
LogTrace << name_ << VAR(all_results_) << VAR(filtered_results_) << VAR(cost);
4350
}
4451

45-
NeuralNetworkClassifier::ResultsVec NeuralNetworkClassifier::foreach_rois() const
52+
NeuralNetworkClassifier::ResultsVec NeuralNetworkClassifier::classify_all_rois()
4653
{
4754
if (param_.roi.empty()) {
4855
return { classify(cv::Rect(0, 0, image_.cols, image_.rows)) };
4956
}
50-
51-
ResultsVec results;
52-
for (const cv::Rect& roi : param_.roi) {
53-
Result res = classify(roi);
54-
results.emplace_back(std::move(res));
57+
else {
58+
ResultsVec results;
59+
for (const cv::Rect& roi : param_.roi) {
60+
Result res = classify(roi);
61+
results.emplace_back(std::move(res));
62+
}
63+
return results;
5564
}
56-
57-
return results;
5865
}
5966

60-
NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify(const cv::Rect& roi) const
67+
NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify(const cv::Rect& roi)
6168
{
6269
if (!session_) {
6370
LogError << "OrtSession not loaded";
@@ -114,17 +121,33 @@ NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify(const cv::Rect
114121
result.label = param_.labels[result.cls_index];
115122
result.box = roi;
116123

117-
draw_result(result);
124+
if (debug_draw_) {
125+
auto draw = draw_result(result);
126+
handle_draw(draw);
127+
}
118128

119129
return result;
120130
}
121131

122-
void NeuralNetworkClassifier::draw_result(const Result& res) const
132+
void NeuralNetworkClassifier::add_results(ResultsVec results, const std::vector<size_t>& expected)
123133
{
124-
if (!debug_draw_) {
125-
return;
126-
}
134+
std::ranges::copy_if(results, std::back_inserter(filtered_results_), [&](const auto& res) {
135+
return std::ranges::find(expected, res.cls_index) != expected.end();
136+
});
137+
138+
merge_vector_(all_results_, std::move(results));
139+
}
140+
141+
void NeuralNetworkClassifier::sort()
142+
{
143+
sort_(all_results_);
144+
sort_(filtered_results_);
145+
146+
handle_index(filtered_results_.size(), param_.result_index);
147+
}
127148

149+
cv::Mat NeuralNetworkClassifier::draw_result(const Result& res) const
150+
{
128151
cv::Mat image_draw = draw_roi(res.box);
129152
cv::Point pt(res.box.x + res.box.width + 5, res.box.y + 20);
130153

@@ -140,21 +163,10 @@ void NeuralNetworkClassifier::draw_result(const Result& res) const
140163
pt.y += 20;
141164
}
142165

143-
handle_draw(image_draw);
144-
}
145-
146-
void NeuralNetworkClassifier::filter(ResultsVec& results, const std::vector<size_t>& expected) const
147-
{
148-
if (expected.empty()) {
149-
return;
150-
}
151-
152-
std::erase_if(results, [&](const Result& res) {
153-
return std::find(expected.begin(), expected.end(), res.cls_index) == expected.end();
154-
});
166+
return image_draw;
155167
}
156168

157-
void NeuralNetworkClassifier::sort(ResultsVec& results) const
169+
void NeuralNetworkClassifier::sort_(ResultsVec& results) const
158170
{
159171
switch (param_.order_by) {
160172
case ResultOrderBy::Horizontal:
@@ -178,14 +190,4 @@ void NeuralNetworkClassifier::sort(ResultsVec& results) const
178190
}
179191
}
180192

181-
size_t NeuralNetworkClassifier::preferred_index(const ResultsVec& results) const
182-
{
183-
auto index_opt = pythonic_index(results.size(), param_.result_index);
184-
if (!index_opt) {
185-
return SIZE_MAX;
186-
}
187-
188-
return *index_opt;
189-
}
190-
191-
MAA_VISION_NS_END
193+
MAA_VISION_NS_END

source/MaaFramework/Vision/NeuralNetworkClassifier.h

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
2+
#pragma once
23

34
#include <ostream>
45
#include <vector>
@@ -29,23 +30,36 @@ class NeuralNetworkClassifier : public VisionBase
2930
using ResultsVec = std::vector<Result>;
3031

3132
public:
32-
void set_param(NeuralNetworkClassifierParam param) { param_ = std::move(param); }
33+
NeuralNetworkClassifier(
34+
cv::Mat image,
35+
NeuralNetworkClassifierParam param,
36+
std::shared_ptr<Ort::Session> session,
37+
std::string name = "");
3338

34-
void set_session(std::shared_ptr<Ort::Session> session) { session_ = std::move(session); }
39+
const ResultsVec& all_results() const { return all_results_; }
3540

36-
std::pair<ResultsVec, size_t> analyze() const;
41+
const ResultsVec& filtered_results() const { return filtered_results_; }
3742

3843
private:
39-
ResultsVec foreach_rois() const;
40-
Result classify(const cv::Rect& roi) const;
41-
void draw_result(const Result& res) const;
44+
void analyze();
45+
46+
ResultsVec classify_all_rois();
47+
Result classify(const cv::Rect& roi);
4248

43-
void filter(ResultsVec& results, const std::vector<size_t>& expected) const;
44-
void sort(ResultsVec& results) const;
45-
size_t preferred_index(const ResultsVec& results) const;
49+
void add_results(ResultsVec results, const std::vector<size_t>& expected);
50+
void sort();
51+
52+
private:
53+
cv::Mat draw_result(const Result& res) const;
54+
void sort_(ResultsVec& results) const;
4655

47-
NeuralNetworkClassifierParam param_;
56+
private:
57+
const NeuralNetworkClassifierParam param_;
4858
std::shared_ptr<Ort::Session> session_ = nullptr;
59+
60+
private:
61+
ResultsVec all_results_;
62+
ResultsVec filtered_results_;
4963
};
5064

5165
MAA_VISION_NS_END

0 commit comments

Comments
 (0)