diff --git a/docs/en_us/3.1-PipelineProtocol.md b/docs/en_us/3.1-PipelineProtocol.md index d66d57b0e..c26e8e6ff 100644 --- a/docs/en_us/3.1-PipelineProtocol.md +++ b/docs/en_us/3.1-PipelineProtocol.md @@ -252,9 +252,8 @@ This task property requires additional fields: Recognition area coordinates. Optional, default is [0, 0, 0, 0], which represents the full screen. The four values are [x, y, w, h]. -- `template`: *string* - Path to the template image, relative to the "image" folder. Required. - Currently, only a single image is supported. +- `template`: *string* | *list* + Path to the template image, relative to the "image" folder. Required. - `count`: *int* The number of required matching feature points (threshold), default is 4. diff --git "a/docs/zh_cn/3.1-\344\273\273\345\212\241\346\265\201\346\260\264\347\272\277\345\215\217\350\256\256.md" "b/docs/zh_cn/3.1-\344\273\273\345\212\241\346\265\201\346\260\264\347\272\277\345\215\217\350\256\256.md" index ff1e26dd8..39be05c53 100644 --- "a/docs/zh_cn/3.1-\344\273\273\345\212\241\346\265\201\346\260\264\347\272\277\345\215\217\350\256\256.md" +++ "b/docs/zh_cn/3.1-\344\273\273\345\212\241\346\265\201\346\260\264\347\272\277\345\215\217\350\256\256.md" @@ -257,9 +257,8 @@ graph LR; 识别区域坐标。可选,默认 [0, 0, 0, 0],即全屏。 四个值分别为 [x, y, w, h]。 -- `template`: *string* +- `template`: *string* | *list* 模板图片路径,需要 `image` 文件夹的相对路径。必选。 - 目前仅支持单张图片。 - `count`: *int* 匹配的特征点的数量要求(阈值),默认 4. diff --git a/source/MaaFramework/Instance/InstanceInternalAPI.hpp b/source/MaaFramework/Instance/InstanceInternalAPI.hpp index 6e08f511d..7dafc2835 100644 --- a/source/MaaFramework/Instance/InstanceInternalAPI.hpp +++ b/source/MaaFramework/Instance/InstanceInternalAPI.hpp @@ -6,6 +6,8 @@ #include #include "Conf/Conf.h" +#include "MaaFramework/Task/MaaCustomAction.h" +#include "MaaFramework/Task/MaaCustomRecognizer.h" #include "Utils/NonCopyable.hpp" MAA_RES_NS_BEGIN @@ -17,20 +19,20 @@ class ControllerAgent; MAA_CTRL_NS_END MAA_NS_BEGIN -class InstanceStatus; -MAA_NS_END -MAA_VISION_NS_BEGIN -class CustomRecognizer; -using CustomRecognizerPtr = std::shared_ptr; -MAA_VISION_NS_END +class InstanceStatus; -MAA_TASK_NS_BEGIN -class CustomAction; -using CustomActionPtr = std::shared_ptr; -MAA_TASK_NS_END +struct CustomRecognizerSession +{ + MaaCustomRecognizerHandle recognizer = nullptr; + MaaTransparentArg recognizer_arg = nullptr; +}; -MAA_NS_BEGIN +struct CustomActionSession +{ + MaaCustomActionHandle action = nullptr; + MaaTransparentArg action_arg = nullptr; +}; struct InstanceInternalAPI : public NonCopyable { @@ -39,8 +41,8 @@ struct InstanceInternalAPI : public NonCopyable virtual MAA_CTRL_NS::ControllerAgent* inter_controller() = 0; virtual InstanceStatus* inter_status() = 0; virtual void notify(std::string_view msg, const json::value& details = json::value()) = 0; - virtual MAA_VISION_NS::CustomRecognizerPtr custom_recognizer(const std::string& name) = 0; - virtual MAA_TASK_NS::CustomActionPtr custom_action(const std::string& name) = 0; + virtual CustomRecognizerSession* custom_recognizer_session(const std::string& name) = 0; + virtual CustomActionSession* custom_action_session(const std::string& name) = 0; }; -MAA_NS_END \ No newline at end of file +MAA_NS_END diff --git a/source/MaaFramework/Instance/InstanceMgr.cpp b/source/MaaFramework/Instance/InstanceMgr.cpp index 87aa24216..c16a989d9 100644 --- a/source/MaaFramework/Instance/InstanceMgr.cpp +++ b/source/MaaFramework/Instance/InstanceMgr.cpp @@ -140,21 +140,20 @@ bool InstanceMgr::register_custom_recognizer( return false; } - auto recognizer_ptr = - std::make_shared(handle, handle_arg, this); - return custom_recognizers_.insert_or_assign(std::move(name), std::move(recognizer_ptr)).second; + CustomRecognizerSession session { handle, handle_arg }; + return custom_recognizer_sessions_.insert_or_assign(std::move(name), std::move(session)).second; } bool InstanceMgr::unregister_custom_recognizer(std::string name) { LogInfo << VAR(name); - return custom_recognizers_.erase(name) > 0; + return custom_recognizer_sessions_.erase(name) > 0; } void InstanceMgr::clear_custom_recognizer() { LogInfo; - custom_recognizers_.clear(); + custom_recognizer_sessions_.clear(); } bool InstanceMgr::register_custom_action( @@ -167,20 +166,20 @@ bool InstanceMgr::register_custom_action( LogError << "Invalid handle"; return false; } - auto action_ptr = std::make_shared(handle, handle_arg, this); - return custom_actions_.insert_or_assign(std::move(name), std::move(action_ptr)).second; + CustomActionSession session { handle, handle_arg }; + return custom_action_sessions_.insert_or_assign(std::move(name), std::move(session)).second; } bool InstanceMgr::unregister_custom_action(std::string name) { LogInfo << VAR(name); - return custom_actions_.erase(name) > 0; + return custom_action_sessions_.erase(name) > 0; } void InstanceMgr::clear_custom_action() { LogInfo; - custom_actions_.clear(); + custom_action_sessions_.clear(); } MaaStatus InstanceMgr::task_status(MaaTaskId task_id) const @@ -263,24 +262,24 @@ void InstanceMgr::notify(std::string_view msg, const json::value& details) notifier.notify(msg, details); } -MAA_VISION_NS::CustomRecognizerPtr InstanceMgr::custom_recognizer(const std::string& name) +CustomRecognizerSession* InstanceMgr::custom_recognizer_session(const std::string& name) { - auto it = custom_recognizers_.find(name); - if (it == custom_recognizers_.end()) { + auto it = custom_recognizer_sessions_.find(name); + if (it == custom_recognizer_sessions_.end()) { LogError << "Custom recognizer not found:" << name; return nullptr; } - return it->second; + return &it->second; } -MAA_TASK_NS::CustomActionPtr InstanceMgr::custom_action(const std::string& name) +CustomActionSession* InstanceMgr::custom_action_session(const std::string& name) { - auto it = custom_actions_.find(name); - if (it == custom_actions_.end()) { + auto it = custom_action_sessions_.find(name); + if (it == custom_action_sessions_.end()) { LogError << "Custom action not found:" << name; return nullptr; } - return it->second; + return &it->second; } bool InstanceMgr::run_task(TaskId id, TaskPtr task_ptr) diff --git a/source/MaaFramework/Instance/InstanceMgr.h b/source/MaaFramework/Instance/InstanceMgr.h index 6a1969079..e99b393d0 100644 --- a/source/MaaFramework/Instance/InstanceMgr.h +++ b/source/MaaFramework/Instance/InstanceMgr.h @@ -57,8 +57,8 @@ class InstanceMgr virtual MAA_CTRL_NS::ControllerAgent* inter_controller() override; virtual InstanceStatus* inter_status() override; virtual void notify(std::string_view msg, const json::value& details = json::value()) override; - virtual MAA_VISION_NS::CustomRecognizerPtr custom_recognizer(const std::string& name) override; - virtual MAA_TASK_NS::CustomActionPtr custom_action(const std::string& name) override; + virtual CustomRecognizerSession* custom_recognizer_session(const std::string& name) override; + virtual CustomActionSession* custom_action_session(const std::string& name) override; private: using TaskPtr = std::shared_ptr; @@ -73,8 +73,8 @@ class InstanceMgr InstanceStatus status_; bool need_to_stop_ = false; - std::unordered_map custom_recognizers_; - std::unordered_map custom_actions_; + std::unordered_map custom_recognizer_sessions_; + std::unordered_map custom_action_sessions_; std::unique_ptr> task_runner_ = nullptr; MessageNotifier notifier; diff --git a/source/MaaFramework/Resource/PipelineResMgr.cpp b/source/MaaFramework/Resource/PipelineResMgr.cpp index 78e64568c..4df7aaef4 100644 --- a/source/MaaFramework/Resource/PipelineResMgr.cpp +++ b/source/MaaFramework/Resource/PipelineResMgr.cpp @@ -599,12 +599,16 @@ bool PipelineResMgr::parse_feature_matcher_param( return false; } - if (!get_and_check_value( + if (!get_and_check_value_or_array( input, "template", - output.template_path, - default_value.template_path)) { - LogError << "failed to get_and_check_value template_path" << VAR(input); + output.template_paths, + default_value.template_paths)) { + LogError << "failed to get_and_check_value_or_array templates" << VAR(input); + return false; + } + if (output.template_paths.empty()) { + LogError << "templates is empty" << VAR(input); return false; } diff --git a/source/MaaFramework/Task/Actuator.cpp b/source/MaaFramework/Task/Actuator.cpp index e4f505738..198396b40 100644 --- a/source/MaaFramework/Task/Actuator.cpp +++ b/source/MaaFramework/Task/Actuator.cpp @@ -128,15 +128,13 @@ void Actuator::wait_freezes(const MAA_RES_NS::WaitFreezesParam& param, const cv: LogFunc << "Wait freezes:" << VAR(param.time) << VAR(param.threshold) << VAR(param.method); cv::Rect target = get_target_rect(param.target, cur_box); + cv::Mat pre_image = controller()->screencap(); - TemplateComparator comp; - comp.set_param({ + TemplateComparatorParam comp_param { .roi = { target }, .threshold = param.threshold, .method = param.method, - }); - - cv::Mat pre_image = controller()->screencap(); + }; auto pre_time = std::chrono::steady_clock::now(); @@ -148,7 +146,9 @@ void Actuator::wait_freezes(const MAA_RES_NS::WaitFreezesParam& param, const cv: break; } - auto ret = comp.analyze(pre_image, cur_image); + TemplateComparator comparator(pre_image, cur_image, comp_param); + + auto ret = comparator.filtered_results(); if (ret.empty()) { pre_image = cur_image; pre_time = std::chrono::steady_clock::now(); @@ -199,13 +199,13 @@ bool Actuator::custom_action( LogError << "Inst is null"; return false; } - auto action = inst_->custom_action(param.name); - if (!action) { + auto* session = inst_->custom_action_session(param.name); + if (!session) { LogError << "Custom task not found" << VAR(param.name); return false; } - return action->run(task_name, param, cur_box, cur_rec_detail); + return CustomAction(*session, inst_).run(task_name, param, cur_box, cur_rec_detail); } cv::Rect Actuator::get_target_rect(const MAA_RES_NS::Action::Target target, const cv::Rect& cur_box) @@ -264,4 +264,4 @@ void Actuator::sleep(std::chrono::milliseconds ms) const LogTrace << "end of sleep" << ms << VAR(interval); } -MAA_TASK_NS_END \ No newline at end of file +MAA_TASK_NS_END diff --git a/source/MaaFramework/Task/CustomAction.cpp b/source/MaaFramework/Task/CustomAction.cpp index d69c2de44..3ac8b7777 100644 --- a/source/MaaFramework/Task/CustomAction.cpp +++ b/source/MaaFramework/Task/CustomAction.cpp @@ -8,13 +8,8 @@ MAA_TASK_NS_BEGIN -CustomAction::CustomAction( - MaaCustomActionHandle handle, - MaaTransparentArg handle_arg_, - InstanceInternalAPI* inst) - : action_(handle) - , action_arg_(handle_arg_) - , inst_(inst) +CustomAction::CustomAction(CustomActionSession session, InstanceInternalAPI* inst) + : session_(std::move(session)), inst_(inst) { } @@ -24,10 +19,11 @@ bool CustomAction::run( const cv::Rect& cur_box, const json::value& cur_rec_detail) { - LogFunc << VAR(task_name) << VAR_VOIDP(action_) << VAR(param.custom_param) << VAR(cur_box); + LogFunc << VAR(task_name) << VAR_VOIDP(session_.action) << VAR(param.custom_param) + << VAR(cur_box); - if (!action_ || !action_->run) { - LogError << "Action is null" << VAR_VOIDP(action_) << VAR_VOIDP(action_->run); + if (!session_.action || !session_.action->run) { + LogError << "Action is null" << VAR(task_name); return false; } @@ -39,16 +35,16 @@ bool CustomAction::run( .height = cur_box.height }; std::string cur_rec_detail_string = cur_rec_detail.to_string(); - bool ret = action_->run( + bool ret = session_.action->run( &sync_ctx, task_name.c_str(), custom_param_string.c_str(), &box, cur_rec_detail_string.c_str(), - action_arg_); - LogTrace << VAR_VOIDP(action_) << VAR_VOIDP(action_->run) << VAR(ret); + session_.action_arg); + LogTrace << VAR_VOIDP(session_.action) << VAR_VOIDP(session_.action->run) << VAR(ret); return ret; } -MAA_TASK_NS_END \ No newline at end of file +MAA_TASK_NS_END diff --git a/source/MaaFramework/Task/CustomAction.h b/source/MaaFramework/Task/CustomAction.h index be0a26554..ca9820e9a 100644 --- a/source/MaaFramework/Task/CustomAction.h +++ b/source/MaaFramework/Task/CustomAction.h @@ -11,10 +11,7 @@ MAA_TASK_NS_BEGIN class CustomAction { public: - CustomAction( - MaaCustomActionHandle handle, - MaaTransparentArg handle_arg, - InstanceInternalAPI* inst); + CustomAction(CustomActionSession session, InstanceInternalAPI* inst); bool run(const std::string& task_name, @@ -23,8 +20,7 @@ class CustomAction const json::value& cur_rec_detail); private: - MaaCustomActionHandle action_ = nullptr; - MaaTransparentArg action_arg_ = nullptr; + CustomActionSession session_; InstanceInternalAPI* inst_ = nullptr; }; diff --git a/source/MaaFramework/Task/Recognizer.cpp b/source/MaaFramework/Task/Recognizer.cpp index e230116bd..48956c91d 100644 --- a/source/MaaFramework/Task/Recognizer.cpp +++ b/source/MaaFramework/Task/Recognizer.cpp @@ -61,14 +61,14 @@ std::optional break; case Type::NeuralNetworkClassify: - result = classify( + result = nn_classify( image, std::get(task_data.rec_param), task_data.name); break; case Type::NeuralNetworkDetect: - result = detect( + result = nn_detect( image, std::get(task_data.rec_param), task_data.name); @@ -120,11 +120,6 @@ std::optional Recognizer::template_match( return std::nullopt; } - TemplateMatcher matcher; - matcher.set_image(image); - matcher.set_name(name); - matcher.set_param(param); - std::vector> templates; for (const auto& path : param.template_paths) { auto templ = resource()->template_res().image(path); @@ -134,15 +129,19 @@ std::optional Recognizer::template_match( } templates.emplace_back(std::move(templ)); } - matcher.set_templates(std::move(templates)); - auto [results, index] = matcher.analyze(); + TemplateMatcher matcher(image, param, templates, name); + + auto results = std::move(matcher).filtered_results(); + size_t index = matcher.preferred_index(); + auto draws = std::move(matcher).draws(); + if (index >= results.size()) { return std::nullopt; } - const cv::Rect& box = results[index].box; - return Result { .box = box, .detail = std::move(results) }; + + return Result { .box = box, .detail = std::move(results), .draws = std::move(draws) }; } std::optional Recognizer::feature_match( @@ -157,21 +156,28 @@ std::optional Recognizer::feature_match( return std::nullopt; } - FeatureMatcher matcher; - matcher.set_image(image); - matcher.set_name(name); - matcher.set_param(param); + std::vector> templates; + for (const auto& path : param.template_paths) { + auto templ = resource()->template_res().image(path); + if (!templ) { + LogWarn << "Template not found:" << path; + continue; + } + templates.emplace_back(std::move(templ)); + } + + FeatureMatcher matcher(image, param, templates, name); - std::shared_ptr templ = resource()->template_res().image(param.template_path); - matcher.set_template(std::move(templ)); + auto results = std::move(matcher).filtered_results(); + size_t index = matcher.preferred_index(); + auto draws = std::move(matcher).draws(); - auto [results, index] = matcher.analyze(); if (index >= results.size()) { return std::nullopt; } - const cv::Rect& box = results[index].box; - return Result { .box = box, .detail = std::move(results) }; + + return Result { .box = box, .detail = std::move(results), .draws = std::move(draws) }; } std::optional Recognizer::color_match( @@ -186,18 +192,18 @@ std::optional Recognizer::color_match( return std::nullopt; } - ColorMatcher matcher; - matcher.set_image(image); - matcher.set_name(name); - matcher.set_param(param); + ColorMatcher matcher(image, param, name); + + auto results = std::move(matcher).filtered_results(); + size_t index = matcher.preferred_index(); + auto draws = std::move(matcher).draws(); - auto [results, index] = matcher.analyze(); if (index >= results.size()) { return std::nullopt; } - const cv::Rect& box = results[index].box; - return Result { .box = box, .detail = std::move(results) }; + + return Result { .box = box, .detail = std::move(results), .draws = std::move(draws) }; } std::optional Recognizer::ocr( @@ -212,28 +218,25 @@ std::optional Recognizer::ocr( return std::nullopt; } - OCRer ocrer; - ocrer.set_image(image); - ocrer.set_name(name); - ocrer.set_param(param); - auto det_session = resource()->ocr_res().deter(param.model); auto rec_session = resource()->ocr_res().recer(param.model); auto ocr_session = resource()->ocr_res().ocrer(param.model); - ocrer.set_session(std::move(det_session), std::move(rec_session), std::move(ocr_session)); - ocrer.set_status(status()); + OCRer ocrer(image, param, det_session, rec_session, ocr_session, status(), name); + + auto results = std::move(ocrer).filtered_results(); + size_t index = ocrer.preferred_index(); + auto draws = std::move(ocrer).draws(); - auto [results, index] = ocrer.analyze(); if (index >= results.size()) { return std::nullopt; } - const cv::Rect& box = results[index].box; - return Result { .box = box, .detail = std::move(results) }; + + return Result { .box = box, .detail = std::move(results), .draws = std::move(draws) }; } -std::optional Recognizer::classify( +std::optional Recognizer::nn_classify( const cv::Mat& image, const MAA_VISION_NS::NeuralNetworkClassifierParam& param, const std::string& name) @@ -245,24 +248,23 @@ std::optional Recognizer::classify( return std::nullopt; } - NeuralNetworkClassifier classifier; - classifier.set_image(image); - classifier.set_name(name); - classifier.set_param(param); - auto session = resource()->onnx_res().classifier(param.model); - classifier.set_session(std::move(session)); - auto [results, index] = classifier.analyze(); + NeuralNetworkClassifier classifier(image, param, session, name); + + auto results = std::move(classifier).filtered_results(); + size_t index = classifier.preferred_index(); + auto draws = std::move(classifier).draws(); + if (index >= results.size()) { return std::nullopt; } - const cv::Rect& box = results[index].box; - return Result { .box = box, .detail = std::move(results) }; + + return Result { .box = box, .detail = std::move(results), .draws = std::move(draws) }; } -std::optional Recognizer::detect( +std::optional Recognizer::nn_detect( const cv::Mat& image, const MAA_VISION_NS::NeuralNetworkDetectorParam& param, const std::string& name) @@ -274,21 +276,20 @@ std::optional Recognizer::detect( return std::nullopt; } - NeuralNetworkDetector detector; - detector.set_image(image); - detector.set_name(name); - detector.set_param(param); - auto session = resource()->onnx_res().detector(param.model); - detector.set_session(std::move(session)); - auto [results, index] = detector.analyze(); + NeuralNetworkDetector detector(image, param, session, name); + + auto results = std::move(detector).filtered_results(); + size_t index = detector.preferred_index(); + auto draws = std::move(detector).draws(); + if (index >= results.size()) { return std::nullopt; } - const cv::Rect& box = results[index].box; - return Result { .box = box, .detail = std::move(results) }; + + return Result { .box = box, .detail = std::move(results), .draws = std::move(draws) }; } std::optional Recognizer::custom_recognize( @@ -303,22 +304,22 @@ std::optional Recognizer::custom_recognize( return std::nullopt; } - auto recognizer = inst_->custom_recognizer(param.name); - if (!recognizer) { + auto* session = inst_->custom_recognizer_session(param.name); + if (!session) { LogError << "Custom recognizer not found:" << param.name; return std::nullopt; } - recognizer->set_image(image); - recognizer->set_param(param); - recognizer->set_name(name); - auto result_opt = recognizer->analyze(); - if (!result_opt) { + CustomRecognizer recognizer(image, param, *session, inst_, name); + auto results = std::move(recognizer).result(); + bool ret = recognizer.ret(); + + if (!ret) { return std::nullopt; } - const cv::Rect& box = result_opt->box; - return Result { .box = box, .detail = std::move(*result_opt) }; + const cv::Rect& box = results.box; + return Result { .box = box, .detail = std::move(results) }; } void Recognizer::show_hit_draw( @@ -342,4 +343,4 @@ void Recognizer::show_hit_draw( cv::destroyWindow(kWinName); } -MAA_TASK_NS_END \ No newline at end of file +MAA_TASK_NS_END diff --git a/source/MaaFramework/Task/Recognizer.h b/source/MaaFramework/Task/Recognizer.h index a21160573..e7395ec42 100644 --- a/source/MaaFramework/Task/Recognizer.h +++ b/source/MaaFramework/Task/Recognizer.h @@ -22,6 +22,7 @@ class Recognizer { cv::Rect box {}; json::value detail; + std::vector draws; }; public: @@ -46,11 +47,11 @@ class Recognizer const std::string& name); std::optional ocr(const cv::Mat& image, const MAA_VISION_NS::OCRerParam& param, const std::string& name); - std::optional classify( + std::optional nn_classify( const cv::Mat& image, const MAA_VISION_NS::NeuralNetworkClassifierParam& param, const std::string& name); - std::optional detect( + std::optional nn_detect( const cv::Mat& image, const MAA_VISION_NS::NeuralNetworkDetectorParam& param, const std::string& name); diff --git a/source/MaaFramework/Vision/ColorMatcher.cpp b/source/MaaFramework/Vision/ColorMatcher.cpp index 7a2ad1bca..15a632098 100644 --- a/source/MaaFramework/Vision/ColorMatcher.cpp +++ b/source/MaaFramework/Vision/ColorMatcher.cpp @@ -8,62 +8,45 @@ MAA_VISION_NS_BEGIN -std::pair ColorMatcher::analyze() const +ColorMatcher::ColorMatcher(cv::Mat image, ColorMatcherParam param, std::string name) + : VisionBase(std::move(image), std::move(name)) + , param_(std::move(param)) { - ResultsVec all_results; - - for (const auto& range : param_.range) { - auto start_time = std::chrono::steady_clock::now(); - - bool connected = param_.connected; - ResultsVec results = foreach_rois(range, connected); - - auto cost = duration_since(start_time); - LogTrace << name_ << "Raw:" << VAR(results) << VAR(range.first) << VAR(range.second) - << VAR(connected) << VAR(cost); - - int count = param_.count; - filter(results, count); + analyze(); +} - cost = duration_since(start_time); - LogTrace << name_ << "Filter:" << VAR(results) << VAR(range.first) << VAR(range.second) - << VAR(count) << VAR(connected) << VAR(cost); +void ColorMatcher::analyze() +{ + auto start_time = std::chrono::steady_clock::now(); - all_results.insert( - all_results.end(), - std::make_move_iterator(results.begin()), - std::make_move_iterator(results.end())); + for (const auto& range : param_.range) { + auto results = match_all_rois(range); + add_results(std::move(results), param_.count); } - sort(all_results); - size_t index = preferred_index(all_results); + sort(); - return { all_results, index }; + auto cost = duration_since(start_time); + LogTrace << name_ << VAR(all_results_) << VAR(filtered_results_) << VAR(cost); } -ColorMatcher::ResultsVec - ColorMatcher::foreach_rois(const ColorMatcherParam::Range& range, bool connected) const +ColorMatcher::ResultsVec ColorMatcher::match_all_rois(const ColorMatcherParam::Range& range) { if (param_.roi.empty()) { - return { color_match(cv::Rect(0, 0, image_.cols, image_.rows), range, connected) }; + return color_match(cv::Rect(0, 0, image_.cols, image_.rows), range); } - - ResultsVec results; - for (const cv::Rect& roi : param_.roi) { - ResultsVec res = color_match(roi, range, connected); - results.insert( - results.end(), - std::make_move_iterator(res.begin()), - std::make_move_iterator(res.end())); + else { + ResultsVec results; + for (const cv::Rect& roi : param_.roi) { + auto res = color_match(roi, range); + merge_vector_(results, std::move(res)); + } + return results; } - - return results; } -ColorMatcher::ResultsVec ColorMatcher::color_match( - const cv::Rect& roi, - const ColorMatcherParam::Range& range, - bool connected) const +ColorMatcher::ResultsVec + ColorMatcher::color_match(const cv::Rect& roi, const ColorMatcherParam::Range& range) { cv::Mat image = image_with_roi(roi); cv::Mat color; @@ -71,13 +54,34 @@ ColorMatcher::ResultsVec ColorMatcher::color_match( cv::Mat bin; cv::inRange(color, range.first, range.second, bin); - ResultsVec results = - connected ? count_non_zero_with_connected(bin, roi.tl()) : count_non_zero(bin, roi.tl()); + ResultsVec results = param_.connected ? count_non_zero_with_connected(bin, roi.tl()) + : count_non_zero(bin, roi.tl()); + + if (debug_draw_) { + auto draw = draw_result(roi, color, bin, results); + handle_draw(draw); + } - draw_result(roi, color, bin, results); return results; } +void ColorMatcher::add_results(ResultsVec results, int count) +{ + std::ranges::copy_if(results, std::back_inserter(filtered_results_), [&](const auto& res) { + return res.count >= count; + }); + + merge_vector_(all_results_, std::move(results)); +} + +void ColorMatcher::sort() +{ + sort_(all_results_); + sort_(filtered_results_); + + handle_index(filtered_results_.size(), param_.result_index); +} + ColorMatcher::ResultsVec ColorMatcher::count_non_zero(const cv::Mat& bin, const cv::Point& tl) const { int count = cv::countNonZero(bin); @@ -114,16 +118,12 @@ ColorMatcher::ResultsVec return NMS_for_count(std::move(results), 0.7); } -void ColorMatcher::draw_result( +cv::Mat ColorMatcher::draw_result( const cv::Rect& roi, const cv::Mat& color, const cv::Mat& bin, const ResultsVec& results) const { - if (!debug_draw_) { - return; - } - cv::Mat image_draw = draw_roi(roi); const auto color_draw = cv::Scalar(0, 0, 255); @@ -174,15 +174,10 @@ void ColorMatcher::draw_result( // cv::line(image_draw, cv::Point(raw_width + color.cols, 0), res.box.tl(), color_draw, 1); - handle_draw(image_draw); + return image_draw; } -void ColorMatcher::filter(ResultsVec& results, int count) const -{ - std::erase_if(results, [count](const auto& res) { return res.count < count; }); -} - -void ColorMatcher::sort(ResultsVec& results) const +void ColorMatcher::sort_(ResultsVec& results) const { switch (param_.order_by) { case ResultOrderBy::Horizontal: @@ -206,14 +201,4 @@ void ColorMatcher::sort(ResultsVec& results) const } } -size_t ColorMatcher::preferred_index(const ResultsVec& results) const -{ - auto index_opt = pythonic_index(results.size(), param_.result_index); - if (!index_opt) { - return SIZE_MAX; - } - - return *index_opt; -} - MAA_VISION_NS_END \ No newline at end of file diff --git a/source/MaaFramework/Vision/ColorMatcher.h b/source/MaaFramework/Vision/ColorMatcher.h index fd53aaa58..670a8027f 100644 --- a/source/MaaFramework/Vision/ColorMatcher.h +++ b/source/MaaFramework/Vision/ColorMatcher.h @@ -20,29 +20,41 @@ class ColorMatcher : public VisionBase using ResultsVec = std::vector; public: - void set_param(ColorMatcherParam param) { param_ = std::move(param); } + ColorMatcher(cv::Mat image, ColorMatcherParam param, std::string name = ""); - std::pair analyze() const; + const ResultsVec& all_results() const& { return all_results_; } + + ResultsVec&& all_results() && { return std::move(all_results_); } + + const ResultsVec& filtered_results() const& { return filtered_results_; } + + ResultsVec filtered_results() && { return std::move(filtered_results_); } + +private: + void analyze(); + ResultsVec match_all_rois(const ColorMatcherParam::Range& range); + ResultsVec color_match(const cv::Rect& roi, const ColorMatcherParam::Range& range); + + void add_results(ResultsVec results, int count); + void sort(); private: - ResultsVec foreach_rois(const ColorMatcherParam::Range& range, bool connected) const; - ResultsVec color_match( - const cv::Rect& roi, - const ColorMatcherParam::Range& range, - bool connected) const; ResultsVec count_non_zero(const cv::Mat& bin, const cv::Point& tl) const; ResultsVec count_non_zero_with_connected(const cv::Mat& bin, const cv::Point& tl) const; - void draw_result( + cv::Mat draw_result( const cv::Rect& roi, const cv::Mat& color, const cv::Mat& bin, const ResultsVec& results) const; - void filter(ResultsVec& results, int count) const; - void sort(ResultsVec& results) const; - size_t preferred_index(const ResultsVec& results) const; + void sort_(ResultsVec& results) const; - ColorMatcherParam param_; +private: + const ColorMatcherParam param_; + +private: + ResultsVec all_results_; + ResultsVec filtered_results_; }; MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/CustomRecognizer.cpp b/source/MaaFramework/Vision/CustomRecognizer.cpp index ada3c7960..3ff20759a 100644 --- a/source/MaaFramework/Vision/CustomRecognizer.cpp +++ b/source/MaaFramework/Vision/CustomRecognizer.cpp @@ -11,24 +11,26 @@ MAA_VISION_NS_BEGIN CustomRecognizer::CustomRecognizer( - MaaCustomRecognizerHandle handle, - MaaTransparentArg handle_arg, - InstanceInternalAPI* inst) - : VisionBase() - , recognizer_(handle) - , recognizer_arg_(handle_arg) + cv::Mat image, + CustomRecognizerParam param, + CustomRecognizerSession session, + InstanceInternalAPI* inst, + std::string name) + : VisionBase(std::move(image), std::move(name)) + , param_(std::move(param)) + , session_(std::move(session)) , inst_(inst) { + analyze(); } -std::optional CustomRecognizer::analyze() const +void CustomRecognizer::analyze() { - LogFunc << VAR_VOIDP(recognizer_) << VAR_VOIDP(recognizer_->analyze) - << VAR(param_.custom_param); + LogFunc << VAR_VOIDP(session_.recognizer) << VAR(param_.custom_param); - if (!recognizer_ || !recognizer_->analyze) { - LogError << "Recognizer is null"; - return std::nullopt; + if (!session_.recognizer || !session_.recognizer->analyze) { + LogError << "Recognizer is nullptr"; + return; } auto start_time = std::chrono::steady_clock::now(); @@ -43,27 +45,23 @@ std::optional CustomRecognizer::analyze() const MaaRect maa_box { 0 }; StringBuffer detail_buffer; - bool ret = recognizer_->analyze( + ret_ = session_.recognizer->analyze( &sync_ctx, &image_buffer, name_.c_str(), custom_param_str.c_str(), - recognizer_arg_, + session_.recognizer_arg, &maa_box, &detail_buffer); cv::Rect box { maa_box.x, maa_box.y, maa_box.width, maa_box.height }; std::string detail(detail_buffer.data(), detail_buffer.size()); - auto cost = duration_since(start_time); - LogTrace << VAR(ret) << VAR(box) << VAR(detail) << VAR(cost); - - if (!ret) { - return std::nullopt; - } - auto jdetail = json::parse(detail).value_or(detail); - return Result { .box = box, .detail = std::move(jdetail) }; + result_ = Result { .box = box, .detail = std::move(jdetail) }; + + auto cost = duration_since(start_time); + LogTrace << name_ << VAR(ret_) << VAR(result_) << VAR(cost); } -MAA_VISION_NS_END \ No newline at end of file +MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/CustomRecognizer.h b/source/MaaFramework/Vision/CustomRecognizer.h index fa09dd445..e8b50a8a7 100644 --- a/source/MaaFramework/Vision/CustomRecognizer.h +++ b/source/MaaFramework/Vision/CustomRecognizer.h @@ -22,20 +22,29 @@ class CustomRecognizer : public VisionBase public: CustomRecognizer( - MaaCustomRecognizerHandle handle, - MaaTransparentArg handle_arg, - InstanceInternalAPI* inst); + cv::Mat image, + CustomRecognizerParam param, + CustomRecognizerSession session, + InstanceInternalAPI* inst, + std::string name = ""); - void set_param(CustomRecognizerParam param) { param_ = std::move(param); } + bool ret() const { return ret_; } - std::optional analyze() const; + const Result& result() const& { return result_; } + + Result result() && { return std::move(result_); } + +private: + void analyze(); private: - MaaCustomRecognizerHandle recognizer_ = nullptr; - MaaTransparentArg recognizer_arg_ = nullptr; + const CustomRecognizerParam param_; + CustomRecognizerSession session_; InstanceInternalAPI* inst_ = nullptr; - CustomRecognizerParam param_; +private: + bool ret_ = false; + Result result_; }; MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/FeatureMatcher.cpp b/source/MaaFramework/Vision/FeatureMatcher.cpp index b1eff4876..f14dc7f2d 100644 --- a/source/MaaFramework/Vision/FeatureMatcher.cpp +++ b/source/MaaFramework/Vision/FeatureMatcher.cpp @@ -3,6 +3,7 @@ MAA_SUPPRESS_CV_WARNINGS_BEGIN #include #include + #ifdef MAA_VISION_HAS_XFEATURES2D #include #endif @@ -13,64 +14,73 @@ MAA_SUPPRESS_CV_WARNINGS_END MAA_VISION_NS_BEGIN -std::pair FeatureMatcher::analyze() const +FeatureMatcher::FeatureMatcher( + cv::Mat image, + FeatureMatcherParam param, + std::vector> templates, + std::string name) + : VisionBase(std::move(image), std::move(name)) + , param_(std::move(param)) + , templates_(std::move(templates)) { - if (!template_) { - LogError << name_ << "template_ is empty" << VAR(param_.template_path); - return {}; - } + analyze(); +} - const cv::Mat& templ = *template_; +void FeatureMatcher::analyze() +{ + if (templates_.empty()) { + LogError << name_ << "templates is empty" << VAR(param_.template_paths); + return; + } auto start_time = std::chrono::steady_clock::now(); - ResultsVec results = foreach_rois(templ); - auto cost = duration_since(start_time); - LogTrace << name_ << "Raw:" << VAR(results) << VAR(param_.template_path) << VAR(cost); - - int count = param_.count; - filter(results, count); + for (const auto& templ : templates_) { + if (!templ) { + continue; + } - cost = duration_since(start_time); - LogTrace << name_ << "Filter:" << VAR(results) << VAR(param_.template_path) << VAR(count) - << VAR(cost); + auto results = match_all_rois(*templ); + add_results(std::move(results), param_.count); + } - sort(results); - size_t index = preferred_index(results); + sort(); - return { results, index }; + auto cost = duration_since(start_time); + LogTrace << name_ << VAR(all_results_) << VAR(filtered_results_) << VAR(cost); } -FeatureMatcher::ResultsVec FeatureMatcher::foreach_rois(const cv::Mat& templ) const +FeatureMatcher::ResultsVec FeatureMatcher::match_all_rois(const cv::Mat& templ) { if (templ.empty()) { - LogWarn << name_ << "template is empty" << VAR(param_.template_path); + LogWarn << name_ << "template is empty" << VAR(param_.template_paths); return {}; } auto [keypoints_1, descriptors_1] = detect(templ, create_mask(templ, param_.green_mask)); if (param_.roi.empty()) { - cv::Rect roi = cv::Rect(0, 0, image_.cols, image_.rows); - return match_roi(keypoints_1, descriptors_1, roi); + return feature_match( + templ, + keypoints_1, + descriptors_1, + cv::Rect(0, 0, image_.cols, image_.rows)); } - - ResultsVec results; - for (const cv::Rect& roi : param_.roi) { - ResultsVec res = match_roi(keypoints_1, descriptors_1, roi); - results.insert( - results.end(), - std::make_move_iterator(res.begin()), - std::make_move_iterator(res.end())); + else { + ResultsVec results; + for (const cv::Rect& roi : param_.roi) { + auto res = feature_match(templ, keypoints_1, descriptors_1, roi); + merge_vector_(results, std::move(res)); + } + return results; } - - return results; } -FeatureMatcher::ResultsVec FeatureMatcher::match_roi( +FeatureMatcher::ResultsVec FeatureMatcher::feature_match( + const cv::Mat& templ, const std::vector& keypoints_1, const cv::Mat& descriptors_1, - const cv::Rect& roi_2) const + const cv::Rect& roi_2) { if (roi_2.empty()) { LogError << name_ << "roi_2 is empty"; @@ -81,7 +91,21 @@ FeatureMatcher::ResultsVec FeatureMatcher::match_roi( auto match_points = match(descriptors_1, descriptors_2); - return postproc(match_points, keypoints_1, keypoints_2, roi_2); + std::vector good_matches; + ResultsVec results = feature_postproc( + match_points, + keypoints_1, + keypoints_2, + templ.cols, + templ.rows, + good_matches); + + if (debug_draw_) { + auto draw = draw_result(templ, keypoints_1, roi_2, keypoints_2, good_matches, results); + handle_draw(draw); + } + + return results; } cv::Ptr FeatureMatcher::create_detector() const @@ -167,13 +191,14 @@ std::vector> return match_points; } -FeatureMatcher::ResultsVec FeatureMatcher::postproc( +FeatureMatcher::ResultsVec FeatureMatcher::feature_postproc( const std::vector>& match_points, const std::vector& keypoints_1, const std::vector& keypoints_2, - const cv::Rect& roi_2) const + int templ_cols, + int templ_rows, + std::vector& good_matches) const { - std::vector good_matches; std::vector obj; std::vector scene; @@ -194,57 +219,48 @@ FeatureMatcher::ResultsVec FeatureMatcher::postproc( LogTrace << name_ << "Match:" << VAR(good_matches.size()) << VAR(match_points.size()) << VAR(param_.distance_ratio); - ResultsVec results; - if (good_matches.size() >= 4) { - cv::Mat H = cv::findHomography(obj, scene, cv::RANSAC); - - std::array obj_corners = { cv::Point2d(0, 0), - cv::Point2d(template_->cols, 0), - cv::Point2d(template_->cols, template_->rows), - cv::Point2d(0, template_->rows) }; - std::array scene_corners; - cv::perspectiveTransform(obj_corners, scene_corners, H); - - double x = std::min( - { scene_corners[0].x, scene_corners[1].x, scene_corners[2].x, scene_corners[3].x }); - double y = std::min( - { scene_corners[0].y, scene_corners[1].y, scene_corners[2].y, scene_corners[3].y }); - double w = - std::max( - { scene_corners[0].x, scene_corners[1].x, scene_corners[2].x, scene_corners[3].x }) - - x; - double h = - std::max( - { scene_corners[0].y, scene_corners[1].y, scene_corners[2].y, scene_corners[3].y }) - - y; - cv::Rect box { static_cast(x), - static_cast(y), - static_cast(w), - static_cast(h) }; - - size_t count = - std::ranges::count_if(scene, [&box](const auto& point) { return box.contains(point); }); - - results.emplace_back(Result { .box = box, .count = static_cast(count) }); + if (good_matches.size() < 4) { + return {}; } - draw_result(*template_, keypoints_1, roi_2, keypoints_2, good_matches, results); - - return results; + cv::Mat H = cv::findHomography(obj, scene, cv::RANSAC); + + std::array obj_corners = { cv::Point2d(0, 0), + cv::Point2d(templ_cols, 0), + cv::Point2d(templ_cols, templ_rows), + cv::Point2d(0, templ_rows) }; + std::array scene_corners; + cv::perspectiveTransform(obj_corners, scene_corners, H); + + double x = std::min( + { scene_corners[0].x, scene_corners[1].x, scene_corners[2].x, scene_corners[3].x }); + double y = std::min( + { scene_corners[0].y, scene_corners[1].y, scene_corners[2].y, scene_corners[3].y }); + double w = + std::max({ scene_corners[0].x, scene_corners[1].x, scene_corners[2].x, scene_corners[3].x }) + - x; + double h = + std::max({ scene_corners[0].y, scene_corners[1].y, scene_corners[2].y, scene_corners[3].y }) + - y; + cv::Rect box { static_cast(x), + static_cast(y), + static_cast(w), + static_cast(h) }; + + size_t count = + std::ranges::count_if(scene, [&box](const auto& point) { return box.contains(point); }); + + return { Result { .box = box, .count = static_cast(count) } }; } -void FeatureMatcher::draw_result( +cv::Mat FeatureMatcher::draw_result( const cv::Mat& templ, const std::vector& keypoints_1, const cv::Rect& roi, const std::vector& keypoints_2, const std::vector& good_matches, - ResultsVec& results) const + const ResultsVec& results) const { - if (!debug_draw_) { - return; - } - // const auto color = cv::Scalar(0, 0, 255); cv::Mat matches_draw; cv::drawMatches(image_, keypoints_2, templ, keypoints_1, good_matches, matches_draw); @@ -273,15 +289,27 @@ void FeatureMatcher::draw_result( 1); } - handle_draw(image_draw); + return image_draw; } -void FeatureMatcher::filter(ResultsVec& results, int count) const +void FeatureMatcher::add_results(ResultsVec results, int count) { - std::erase_if(results, [count](const auto& res) { return res.count < count; }); + std::ranges::copy_if(results, std::back_inserter(filtered_results_), [&](const auto& res) { + return res.count >= count; + }); + + merge_vector_(all_results_, std::move(results)); } -void FeatureMatcher::sort(ResultsVec& results) const +void FeatureMatcher::sort() +{ + sort_(all_results_); + sort_(filtered_results_); + + handle_index(filtered_results_.size(), param_.result_index); +} + +void FeatureMatcher::sort_(ResultsVec& results) const { switch (param_.order_by) { case ResultOrderBy::Horizontal: @@ -305,14 +333,4 @@ void FeatureMatcher::sort(ResultsVec& results) const } } -size_t FeatureMatcher::preferred_index(const ResultsVec& results) const -{ - auto index_opt = pythonic_index(results.size(), param_.result_index); - if (!index_opt) { - return SIZE_MAX; - } - - return *index_opt; -} - MAA_VISION_NS_END \ No newline at end of file diff --git a/source/MaaFramework/Vision/FeatureMatcher.h b/source/MaaFramework/Vision/FeatureMatcher.h index 2b9e1412b..18792e638 100644 --- a/source/MaaFramework/Vision/FeatureMatcher.h +++ b/source/MaaFramework/Vision/FeatureMatcher.h @@ -29,19 +29,33 @@ class FeatureMatcher : public VisionBase using ResultsVec = std::vector; public: - void set_template(std::shared_ptr templ) { template_ = std::move(templ); } + FeatureMatcher( + cv::Mat image, + FeatureMatcherParam param, + std::vector> templates, + std::string name = ""); - void set_param(FeatureMatcherParam param) { param_ = std::move(param); } + const ResultsVec& all_results() const& { return all_results_; } - std::pair analyze() const; + ResultsVec&& all_results() && { return std::move(all_results_); } + + const ResultsVec& filtered_results() const& { return filtered_results_; } + + ResultsVec filtered_results() && { return std::move(filtered_results_); } private: - ResultsVec foreach_rois(const cv::Mat& templ) const; - ResultsVec match_roi( + void analyze(); + ResultsVec match_all_rois(const cv::Mat& templ); + ResultsVec feature_match( + const cv::Mat& templ, const std::vector& keypoints_1, const cv::Mat& descriptors_1, - const cv::Rect& roi_2) const; + const cv::Rect& roi_2); + void add_results(ResultsVec results, int count); + void sort(); + +private: cv::Ptr create_detector() const; std::pair, cv::Mat> detect(const cv::Mat& image, const cv::Mat& mask) const; @@ -50,25 +64,31 @@ class FeatureMatcher : public VisionBase std::vector> match(const cv::Mat& descriptors_1, const cv::Mat& descriptors_2) const; - ResultsVec postproc( + ResultsVec feature_postproc( const std::vector>& match_points, const std::vector& keypoints_1, const std::vector& keypoints_2, - const cv::Rect& roi_2) const; + int templ_cols, + int templ_rows, + std::vector& good_matches) const; - void draw_result( + cv::Mat draw_result( const cv::Mat& templ, const std::vector& keypoints_1, const cv::Rect& roi, const std::vector& keypoints_2, const std::vector& good_matches, - ResultsVec& results) const; - void filter(ResultsVec& results, int count) const; - void sort(ResultsVec& results) const; - size_t preferred_index(const ResultsVec& results) const; + const ResultsVec& results) const; - FeatureMatcherParam param_; - std::shared_ptr template_; + void sort_(ResultsVec& results) const; + +private: + const FeatureMatcherParam param_; + const std::vector> templates_; + +private: + ResultsVec all_results_; + ResultsVec filtered_results_; }; MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/NeuralNetworkClassifier.cpp b/source/MaaFramework/Vision/NeuralNetworkClassifier.cpp index a237199bc..05928253b 100644 --- a/source/MaaFramework/Vision/NeuralNetworkClassifier.cpp +++ b/source/MaaFramework/Vision/NeuralNetworkClassifier.cpp @@ -8,56 +8,63 @@ MAA_VISION_NS_BEGIN -std::pair NeuralNetworkClassifier::analyze() const +NeuralNetworkClassifier::NeuralNetworkClassifier( + cv::Mat image, + NeuralNetworkClassifierParam param, + std::shared_ptr session, + std::string name) + : VisionBase(std::move(image), std::move(name)) + , param_(std::move(param)) + , session_(std::move(session)) +{ + analyze(); +} + +void NeuralNetworkClassifier::analyze() { LogFunc << name_; if (!session_) { LogError << "OrtSession not loaded"; - return {}; + return; } if (param_.cls_size == 0) { LogError << "cls_size == 0"; - return {}; + return; } if (param_.cls_size != param_.labels.size()) { LogError << "cls_size != labels.size()" << VAR(param_.cls_size) << VAR(param_.labels.size()); - return {}; + return; } auto start_time = std::chrono::steady_clock::now(); - ResultsVec results = foreach_rois(); - auto cost = duration_since(start_time); - LogTrace << name_ << "Raw:" << VAR(results) << VAR(cost); - const auto& expected = param_.expected; - filter(results, expected); + auto results = classify_all_rois(); + add_results(std::move(results), param_.expected); - cost = duration_since(start_time); - LogTrace << name_ << "Filter:" << VAR(results) << VAR(expected) << VAR(cost); + sort(); - sort(results); - size_t index = preferred_index(results); - return { results, index }; + auto cost = duration_since(start_time); + LogTrace << name_ << VAR(all_results_) << VAR(filtered_results_) << VAR(cost); } -NeuralNetworkClassifier::ResultsVec NeuralNetworkClassifier::foreach_rois() const +NeuralNetworkClassifier::ResultsVec NeuralNetworkClassifier::classify_all_rois() { if (param_.roi.empty()) { return { classify(cv::Rect(0, 0, image_.cols, image_.rows)) }; } - - ResultsVec results; - for (const cv::Rect& roi : param_.roi) { - Result res = classify(roi); - results.emplace_back(std::move(res)); + else { + ResultsVec results; + for (const cv::Rect& roi : param_.roi) { + Result res = classify(roi); + results.emplace_back(std::move(res)); + } + return results; } - - return results; } -NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify(const cv::Rect& roi) const +NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify(const cv::Rect& roi) { if (!session_) { LogError << "OrtSession not loaded"; @@ -114,17 +121,33 @@ NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify(const cv::Rect result.label = param_.labels[result.cls_index]; result.box = roi; - draw_result(result); + if (debug_draw_) { + auto draw = draw_result(result); + handle_draw(draw); + } return result; } -void NeuralNetworkClassifier::draw_result(const Result& res) const +void NeuralNetworkClassifier::add_results(ResultsVec results, const std::vector& expected) { - if (!debug_draw_) { - return; - } + std::ranges::copy_if(results, std::back_inserter(filtered_results_), [&](const auto& res) { + return std::ranges::find(expected, res.cls_index) != expected.end(); + }); + + merge_vector_(all_results_, std::move(results)); +} + +void NeuralNetworkClassifier::sort() +{ + sort_(all_results_); + sort_(filtered_results_); + + handle_index(filtered_results_.size(), param_.result_index); +} +cv::Mat NeuralNetworkClassifier::draw_result(const Result& res) const +{ cv::Mat image_draw = draw_roi(res.box); cv::Point pt(res.box.x + res.box.width + 5, res.box.y + 20); @@ -140,21 +163,10 @@ void NeuralNetworkClassifier::draw_result(const Result& res) const pt.y += 20; } - handle_draw(image_draw); -} - -void NeuralNetworkClassifier::filter(ResultsVec& results, const std::vector& expected) const -{ - if (expected.empty()) { - return; - } - - std::erase_if(results, [&](const Result& res) { - return std::find(expected.begin(), expected.end(), res.cls_index) == expected.end(); - }); + return image_draw; } -void NeuralNetworkClassifier::sort(ResultsVec& results) const +void NeuralNetworkClassifier::sort_(ResultsVec& results) const { switch (param_.order_by) { case ResultOrderBy::Horizontal: @@ -178,14 +190,4 @@ void NeuralNetworkClassifier::sort(ResultsVec& results) const } } -size_t NeuralNetworkClassifier::preferred_index(const ResultsVec& results) const -{ - auto index_opt = pythonic_index(results.size(), param_.result_index); - if (!index_opt) { - return SIZE_MAX; - } - - return *index_opt; -} - MAA_VISION_NS_END \ No newline at end of file diff --git a/source/MaaFramework/Vision/NeuralNetworkClassifier.h b/source/MaaFramework/Vision/NeuralNetworkClassifier.h index 7d3153b1f..63f97dc26 100644 --- a/source/MaaFramework/Vision/NeuralNetworkClassifier.h +++ b/source/MaaFramework/Vision/NeuralNetworkClassifier.h @@ -1,4 +1,5 @@ #pragma once +#pragma once #include #include @@ -29,23 +30,40 @@ class NeuralNetworkClassifier : public VisionBase using ResultsVec = std::vector; public: - void set_param(NeuralNetworkClassifierParam param) { param_ = std::move(param); } + NeuralNetworkClassifier( + cv::Mat image, + NeuralNetworkClassifierParam param, + std::shared_ptr session, + std::string name = ""); + + const ResultsVec& all_results() const& { return all_results_; } + + ResultsVec&& all_results() && { return std::move(all_results_); } - void set_session(std::shared_ptr session) { session_ = std::move(session); } + const ResultsVec& filtered_results() const& { return filtered_results_; } - std::pair analyze() const; + ResultsVec filtered_results() && { return std::move(filtered_results_); } private: - ResultsVec foreach_rois() const; - Result classify(const cv::Rect& roi) const; - void draw_result(const Result& res) const; + void analyze(); - void filter(ResultsVec& results, const std::vector& expected) const; - void sort(ResultsVec& results) const; - size_t preferred_index(const ResultsVec& results) const; + ResultsVec classify_all_rois(); + Result classify(const cv::Rect& roi); - NeuralNetworkClassifierParam param_; + void add_results(ResultsVec results, const std::vector& expected); + void sort(); + +private: + cv::Mat draw_result(const Result& res) const; + void sort_(ResultsVec& results) const; + +private: + const NeuralNetworkClassifierParam param_; std::shared_ptr session_ = nullptr; + +private: + ResultsVec all_results_; + ResultsVec filtered_results_; }; MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/NeuralNetworkDetector.cpp b/source/MaaFramework/Vision/NeuralNetworkDetector.cpp index aa0f6a620..e97428a90 100644 --- a/source/MaaFramework/Vision/NeuralNetworkDetector.cpp +++ b/source/MaaFramework/Vision/NeuralNetworkDetector.cpp @@ -9,59 +9,63 @@ MAA_VISION_NS_BEGIN -std::pair NeuralNetworkDetector::analyze() const +NeuralNetworkDetector::NeuralNetworkDetector( + cv::Mat image, + NeuralNetworkDetectorParam param, + std::shared_ptr session, + std::string name) + : VisionBase(std::move(image), std::move(name)) + , param_(std::move(param)) + , session_(std::move(session)) +{ + analyze(); +} + +void NeuralNetworkDetector::analyze() { LogFunc << name_; if (!session_) { LogError << "OrtSession not loaded"; - return {}; + return; } if (param_.cls_size == 0) { LogError << "cls_size == 0"; - return {}; + return; } if (param_.cls_size != param_.labels.size()) { LogError << "cls_size != labels.size()" << VAR(param_.cls_size) << VAR(param_.labels.size()); - return {}; + return; } auto start_time = std::chrono::steady_clock::now(); - ResultsVec results = foreach_rois(); - auto cost = duration_since(start_time); - LogTrace << name_ << "Raw:" << VAR(results) << VAR(cost); - const auto& expected = param_.expected; - filter(results, expected); + auto results = detect_all_rois(); + add_results(std::move(results), param_.expected); - cost = duration_since(start_time); - LogTrace << name_ << "Filter:" << VAR(results) << VAR(expected) << VAR(cost); + sort(); - sort(results); - size_t index = preferred_index(results); - return { results, index }; + auto cost = duration_since(start_time); + LogTrace << name_ << VAR(all_results_) << VAR(filtered_results_) << VAR(cost); } -NeuralNetworkDetector::ResultsVec NeuralNetworkDetector::foreach_rois() const +NeuralNetworkDetector::ResultsVec NeuralNetworkDetector::detect_all_rois() { if (param_.roi.empty()) { return detect(cv::Rect(0, 0, image_.cols, image_.rows)); } - - ResultsVec results; - for (const cv::Rect& roi : param_.roi) { - ResultsVec res = detect(roi); - results.insert( - results.end(), - std::make_move_iterator(res.begin()), - std::make_move_iterator(res.end())); + else { + ResultsVec results; + for (const cv::Rect& roi : param_.roi) { + auto res = detect(roi); + merge_vector_(results, std::move(res)); + } + return results; } - - return results; } -NeuralNetworkDetector::ResultsVec NeuralNetworkDetector::detect(const cv::Rect& roi) const +NeuralNetworkDetector::ResultsVec NeuralNetworkDetector::detect(const cv::Rect& roi) { if (!session_) { LogError << "OrtSession not loaded"; @@ -157,28 +161,33 @@ NeuralNetworkDetector::ResultsVec NeuralNetworkDetector::detect(const cv::Rect& std::make_move_iterator(nms_results.end())); } - draw_result(roi, all_nms_results); + if (debug_draw_) { + auto draw = draw_result(roi, all_nms_results); + handle_draw(draw); + } return all_nms_results; } -void NeuralNetworkDetector::filter(ResultsVec& results, const std::vector& expected) const +void NeuralNetworkDetector::add_results(ResultsVec results, const std::vector& expected) { - if (expected.empty()) { - return; - } - - std::erase_if(results, [&](const Result& res) { - return std::find(expected.begin(), expected.end(), res.cls_index) == expected.end(); + std::ranges::copy_if(results, std::back_inserter(filtered_results_), [&](const auto& res) { + return std::ranges::find(expected, res.cls_index) != expected.end(); }); + + merge_vector_(all_results_, std::move(results)); } -void NeuralNetworkDetector::draw_result(const cv::Rect& roi, const ResultsVec& results) const +void NeuralNetworkDetector::sort() { - if (!debug_draw_) { - return; - } + sort_(all_results_); + sort_(filtered_results_); + + handle_index(filtered_results_.size(), param_.result_index); +} +cv::Mat NeuralNetworkDetector::draw_result(const cv::Rect& roi, const ResultsVec& results) const +{ cv::Mat image_draw = draw_roi(roi); for (const Result& res : results) { @@ -205,10 +214,10 @@ void NeuralNetworkDetector::draw_result(const cv::Rect& roi, const ResultsVec& r 1); } - handle_draw(image_draw); + return image_draw; } -void NeuralNetworkDetector::sort(ResultsVec& results) const +void NeuralNetworkDetector::sort_(ResultsVec& results) const { switch (param_.order_by) { case ResultOrderBy::Horizontal: @@ -232,14 +241,4 @@ void NeuralNetworkDetector::sort(ResultsVec& results) const } } -size_t NeuralNetworkDetector::preferred_index(const ResultsVec& results) const -{ - auto index_opt = pythonic_index(results.size(), param_.result_index); - if (!index_opt) { - return SIZE_MAX; - } - - return *index_opt; -} - MAA_VISION_NS_END \ No newline at end of file diff --git a/source/MaaFramework/Vision/NeuralNetworkDetector.h b/source/MaaFramework/Vision/NeuralNetworkDetector.h index a2b3f049f..37d030bae 100644 --- a/source/MaaFramework/Vision/NeuralNetworkDetector.h +++ b/source/MaaFramework/Vision/NeuralNetworkDetector.h @@ -27,23 +27,40 @@ class NeuralNetworkDetector : public VisionBase using ResultsVec = std::vector; public: - void set_session(std::shared_ptr session) { session_ = std::move(session); } + NeuralNetworkDetector( + cv::Mat image, + NeuralNetworkDetectorParam param, + std::shared_ptr session, + std::string name = ""); - void set_param(NeuralNetworkDetectorParam param) { param_ = std::move(param); } + const ResultsVec& all_results() const& { return all_results_; } - std::pair analyze() const; + ResultsVec&& all_results() && { return std::move(all_results_); } + + const ResultsVec& filtered_results() const& { return filtered_results_; } + + ResultsVec filtered_results() && { return std::move(filtered_results_); } private: - ResultsVec foreach_rois() const; - ResultsVec detect(const cv::Rect& roi) const; - void draw_result(const cv::Rect& roi, const ResultsVec& results) const; + void analyze(); + + ResultsVec detect_all_rois(); + ResultsVec detect(const cv::Rect& roi); + + void add_results(ResultsVec results, const std::vector& expected); + void sort(); - void filter(ResultsVec& results, const std::vector& expected) const; - void sort(ResultsVec& results) const; - size_t preferred_index(const ResultsVec& results) const; +private: + cv::Mat draw_result(const cv::Rect& roi, const ResultsVec& results) const; + void sort_(ResultsVec& results) const; - NeuralNetworkDetectorParam param_; +private: + const NeuralNetworkDetectorParam param_; std::shared_ptr session_ = nullptr; + +private: + ResultsVec all_results_; + ResultsVec filtered_results_; }; MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/OCRer.cpp b/source/MaaFramework/Vision/OCRer.cpp index 8b3d62a7c..944281d00 100644 --- a/source/MaaFramework/Vision/OCRer.cpp +++ b/source/MaaFramework/Vision/OCRer.cpp @@ -10,63 +10,73 @@ MAA_VISION_NS_BEGIN -std::pair OCRer::analyze() const +OCRer::OCRer( + cv::Mat image, + OCRerParam param, + std::shared_ptr deter, + std::shared_ptr recer, + std::shared_ptr ocrer, + InstanceStatus* status, + std::string name) + : VisionBase(std::move(image), std::move(name)) + , param_(std::move(param)) + , deter_(std::move(deter)) + , recer_(std::move(recer)) + , ocrer_(std::move(ocrer)) + , status_(status) { - auto start_time = std::chrono::steady_clock::now(); - - ResultsVec results = foreach_rois(); - - auto cost = duration_since(start_time); - LogTrace << name_ << "Raw:" << VAR(results) << VAR(param_.model) << VAR(cost); + analyze(); +} - const auto& expected = param_.text; - postproc_and_filter(results, expected); +void OCRer::analyze() +{ + auto start_time = std::chrono::steady_clock::now(); - cost = duration_since(start_time); - LogTrace << name_ << "Proc:" << VAR(results) << VAR(expected) << VAR(param_.model) << VAR(cost); + auto results = predict_all_rois(); + add_results(std::move(results), param_.text); - sort(results); - size_t index = preferred_index(results); + sort(); - return { results, index }; + auto cost = duration_since(start_time); + LogTrace << name_ << VAR(all_results_) << VAR(filtered_results_) << VAR(cost); } -OCRer::ResultsVec OCRer::foreach_rois() const +OCRer::ResultsVec OCRer::predict_all_rois() { if (param_.roi.empty()) { - cv::Rect roi(0, 0, image_.cols, image_.rows); - return predict(roi); + return predict(cv::Rect(0, 0, image_.cols, image_.rows)); } - - ResultsVec results; - for (const cv::Rect& roi : param_.roi) { - ResultsVec res = predict(roi); - results.insert( - results.end(), - std::make_move_iterator(res.begin()), - std::make_move_iterator(res.end())); + else { + ResultsVec results; + for (const cv::Rect& roi : param_.roi) { + auto res = predict(roi); + merge_vector_(results, std::move(res)); + } + return results; } - return results; } -OCRer::ResultsVec OCRer::predict(const cv::Rect& roi) const +OCRer::ResultsVec OCRer::predict(const cv::Rect& roi) { auto image_roi = image_with_roi(roi); - if (!status_) { - LogError << "status_ is null"; - return {}; - } - ResultsVec results; - if (auto results_opt = status_->get_ocr_cache(image_roi)) { - LogTrace << "Hit OCR cache" << VAR(roi); - results = std::any_cast(*std::move(results_opt)); + bool hit_cache = false; + + if (status_) { + if (auto results_opt = status_->get_ocr_cache(image_roi)) { + LogTrace << "Hit OCR cache" << VAR(roi); + hit_cache = true; + results = std::any_cast(*std::move(results_opt)); + } } - else { + + if (!hit_cache) { results = param_.only_rec ? ResultsVec { predict_only_rec(image_roi) } : predict_det_and_rec(image_roi); - status_->set_ocr_cache(image_roi, results); + if (status_) { + status_->set_ocr_cache(image_roi, results); + } } std::ranges::for_each(results, [&](auto& res) { @@ -74,7 +84,11 @@ OCRer::ResultsVec OCRer::predict(const cv::Rect& roi) const res.box.y += roi.y; }); - draw_result(roi, results); + if (debug_draw_) { + auto draw = draw_result(roi, results); + handle_draw(draw); + } + return results; } @@ -158,12 +172,8 @@ OCRer::Result OCRer::predict_only_rec(const cv::Mat& image_roi) const return result; } -void OCRer::draw_result(const cv::Rect& roi, const ResultsVec& results) const +cv::Mat OCRer::draw_result(const cv::Rect& roi, const ResultsVec& results) const { - if (!debug_draw_) { - return; - } - cv::Mat image_draw = draw_roi(roi); for (size_t i = 0; i != results.size(); ++i) { @@ -183,25 +193,32 @@ void OCRer::draw_result(const cv::Rect& roi, const ResultsVec& results) const 1); } - handle_draw(image_draw); + return image_draw; } -void OCRer::postproc_and_filter(ResultsVec& results, const std::vector& expected) - const +void OCRer::add_results(ResultsVec results, const std::vector& expected) { - for (auto iter = results.begin(); iter != results.end();) { - auto& res = *iter; - + auto copied = results; + for (auto& res : copied) { postproc_trim_(res); postproc_replace_(res); if (!filter_by_required(res, expected)) { - iter = results.erase(iter); continue; } - ++iter; + filtered_results_.emplace_back(std::move(res)); } + + merge_vector_(all_results_, std::move(results)); +} + +void OCRer::sort() +{ + sort_(all_results_); + sort_(filtered_results_); + + handle_index(filtered_results_.size(), param_.result_index); } void OCRer::postproc_trim_(Result& res) const @@ -231,7 +248,7 @@ bool OCRer::filter_by_required(const Result& res, const std::vector; public: - void set_session( + OCRer( + cv::Mat image, + OCRerParam param, std::shared_ptr deter, std::shared_ptr recer, - std::shared_ptr ocrer) - { - deter_ = std::move(deter); - recer_ = std::move(recer); - ocrer_ = std::move(ocrer); - } + std::shared_ptr ocrer, + InstanceStatus* status = nullptr, + std::string name = ""); + + const ResultsVec& all_results() const& { return all_results_; } + + ResultsVec&& all_results() && { return std::move(all_results_); } - void set_status(InstanceStatus* status) { status_ = status; } + const ResultsVec& filtered_results() const& { return filtered_results_; } + + ResultsVec filtered_results() && { return std::move(filtered_results_); } + +private: + void analyze(); - void set_param(OCRerParam param) { param_ = std::move(param); } + ResultsVec predict_all_rois(); + ResultsVec predict(const cv::Rect& roi); - std::pair analyze() const; + void add_results(ResultsVec results, const std::vector& expected); + void sort(); private: - ResultsVec foreach_rois() const; - ResultsVec predict(const cv::Rect& roi) const; ResultsVec predict_det_and_rec(const cv::Mat& image_roi) const; Result predict_only_rec(const cv::Mat& image_roi) const; - void draw_result(const cv::Rect& roi, const ResultsVec& results) const; - void postproc_and_filter(ResultsVec& results, const std::vector& expected) const; + cv::Mat draw_result(const cv::Rect& roi, const ResultsVec& results) const; + void postproc_trim_(Result& res) const; void postproc_replace_(Result& res) const; bool filter_by_required(const Result& res, const std::vector& expected) const; - void sort(ResultsVec& results) const; - size_t preferred_index(const ResultsVec& results) const; + void sort_(ResultsVec& results) const; + +private: + const OCRerParam param_; - OCRerParam param_; std::shared_ptr deter_ = nullptr; std::shared_ptr recer_ = nullptr; std::shared_ptr ocrer_ = nullptr; InstanceStatus* status_ = nullptr; + +private: + ResultsVec all_results_; + ResultsVec filtered_results_; }; MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/TemplateComparator.cpp b/source/MaaFramework/Vision/TemplateComparator.cpp index 1c1a7e65c..f50bec164 100644 --- a/source/MaaFramework/Vision/TemplateComparator.cpp +++ b/source/MaaFramework/Vision/TemplateComparator.cpp @@ -6,55 +6,62 @@ MAA_VISION_NS_BEGIN -TemplateComparator::ResultsVec - TemplateComparator::analyze(const cv::Mat& lhs, const cv::Mat& rhs) const +TemplateComparator::TemplateComparator( + cv::Mat lhs, + cv::Mat rhs, + TemplateComparatorParam param, + std::string name) + : VisionBase(std::move(lhs), std::move(name)) + , rhs_image_(std::move(rhs)) + , param_(std::move(param)) { - if (lhs.size() != rhs.size()) { - LogError << "lhs.size() != rhs.size()" << VAR(lhs) << VAR(rhs); - return {}; + analyze(); +} + +void TemplateComparator::analyze() +{ + if (image_.size() != rhs_image_.size()) { + LogError << "lhs_image_.size() != rhs_image_.size()" << VAR(image_) << VAR(rhs_image_); + return; } auto start_time = std::chrono::steady_clock::now(); - ResultsVec results = foreach_rois(lhs, rhs); + auto results = compare_all_rois(); + add_results(std::move(results), param_.threshold); auto cost = duration_since(start_time); - LogTrace << "Raw:" << VAR(results) << VAR(cost); - - double threshold = param_.threshold; - filter(results, threshold); - - cost = duration_since(start_time); - LogTrace << "Proc:" << VAR(results) << VAR(threshold) << VAR(cost); - return results; + LogTrace << name_ << VAR(all_results_) << VAR(filtered_results_) << VAR(cost); } -TemplateComparator::ResultsVec - TemplateComparator::foreach_rois(const cv::Mat& lhs, const cv::Mat& rhs) const +TemplateComparator::ResultsVec TemplateComparator::compare_all_rois() { auto method = param_.method; if (param_.roi.empty()) { - double score = comp(lhs, rhs, method); - return { Result { .box = cv::Rect(0, 0, lhs.cols, lhs.rows), .score = score } }; + double score = comp(image_, rhs_image_, method); + return { Result { .box = cv::Rect(0, 0, image_.cols, image_.rows), .score = score } }; } - - ResultsVec results; - for (const cv::Rect& roi : param_.roi) { - cv::Mat lhs_roi = lhs(correct_roi(roi, lhs)); - cv::Mat rhs_roi = rhs(correct_roi(roi, rhs)); - - double score = comp(lhs_roi, rhs_roi, method); - Result res { .box = roi, .score = score }; - results.emplace_back(std::move(res)); + else { + ResultsVec results; + for (const cv::Rect& roi : param_.roi) { + cv::Mat lhs_roi = image_(correct_roi(roi, image_)); + cv::Mat rhs_roi = rhs_image_(correct_roi(roi, rhs_image_)); + + double score = comp(lhs_roi, rhs_roi, method); + results.emplace_back(Result { .box = roi, .score = score }); + } + return results; } - - return results; } -void TemplateComparator::filter(ResultsVec& results, double threshold) const +void TemplateComparator::add_results(ResultsVec results, double threshold) { - std::erase_if(results, [threshold](const auto& res) { return res.score < threshold; }); + std::ranges::copy_if(results, std::back_inserter(filtered_results_), [&](const auto& res) { + return res.score > threshold; + }); + + merge_vector_(all_results_, std::move(results)); } double TemplateComparator::comp(const cv::Mat& lhs, const cv::Mat& rhs, int method) diff --git a/source/MaaFramework/Vision/TemplateComparator.h b/source/MaaFramework/Vision/TemplateComparator.h index 957e5426a..1366dd2c9 100644 --- a/source/MaaFramework/Vision/TemplateComparator.h +++ b/source/MaaFramework/Vision/TemplateComparator.h @@ -1,11 +1,12 @@ #pragma once #include "Utils/JsonExt.hpp" +#include "VisionBase.h" #include "VisionTypes.h" MAA_VISION_NS_BEGIN -class TemplateComparator +class TemplateComparator : public VisionBase { public: struct Result @@ -19,19 +20,35 @@ class TemplateComparator using ResultsVec = std::vector; public: - TemplateComparator() = default; + TemplateComparator( + cv::Mat lhs, + cv::Mat rhs, + TemplateComparatorParam param, + std::string name = ""); - void set_param(TemplateComparatorParam param) { param_ = std::move(param); } + const ResultsVec& all_results() const& { return all_results_; } - ResultsVec analyze(const cv::Mat& lhs, const cv::Mat& rhs) const; + ResultsVec&& all_results() && { return std::move(all_results_); } + + const ResultsVec& filtered_results() const& { return filtered_results_; } + + ResultsVec filtered_results() && { return std::move(filtered_results_); } private: - ResultsVec foreach_rois(const cv::Mat& lhs, const cv::Mat& rhs) const; - void filter(ResultsVec& results, double threshold) const; + void analyze(); + ResultsVec compare_all_rois(); + + void add_results(ResultsVec results, double threshold); static double comp(const cv::Mat& lhs, const cv::Mat& rhs, int method); - TemplateComparatorParam param_; +private: + const cv::Mat rhs_image_ = {}; + const TemplateComparatorParam param_; + +private: + ResultsVec all_results_; + ResultsVec filtered_results_; }; MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/TemplateMatcher.cpp b/source/MaaFramework/Vision/TemplateMatcher.cpp index 8e3140a3f..b8310bda8 100644 --- a/source/MaaFramework/Vision/TemplateMatcher.cpp +++ b/source/MaaFramework/Vision/TemplateMatcher.cpp @@ -7,56 +7,50 @@ MAA_VISION_NS_BEGIN -std::pair TemplateMatcher::analyze() const +TemplateMatcher::TemplateMatcher( + cv::Mat image, + TemplateMatcherParam param, + std::vector> templates, + std::string name) + : VisionBase(std::move(image), std::move(name)) + , param_(std::move(param)) + , templates_(std::move(templates)) +{ + analyze(); +} + +void TemplateMatcher::analyze() { if (templates_.empty()) { LogError << name_ << "templates is empty" << VAR(param_.template_paths); - return {}; + return; } if (templates_.size() != param_.thresholds.size()) { LogError << name_ << "templates.size() != thresholds.size()" << VAR(templates_.size()) << VAR(param_.thresholds.size()); - return {}; + return; } - ResultsVec all_results; + auto start_time = std::chrono::steady_clock::now(); + for (size_t i = 0; i != templates_.size(); ++i) { - const auto& image_ptr = templates_.at(i); - if (!image_ptr) { - LogWarn << name_ << "template is empty" << VAR(param_.template_paths.at(i)) - << VAR(image_ptr); + const auto& templ = templates_.at(i); + if (!templ) { continue; } - const cv::Mat& templ = *image_ptr; - - auto start_time = std::chrono::steady_clock::now(); - ResultsVec results = foreach_rois(templ); - - auto cost = duration_since(start_time); - const std::string& path = param_.template_paths.at(i); - LogTrace << name_ << "Raw:" << VAR(results) << VAR(path) << VAR(cost); - - double threshold = param_.thresholds.at(i); - filter(results, threshold); - - cost = duration_since(start_time); - LogTrace << name_ << "Filter:" << VAR(results) << VAR(path) << VAR(threshold) << VAR(cost); - - all_results.insert( - all_results.end(), - std::make_move_iterator(results.begin()), - std::make_move_iterator(results.end())); + auto results = match_all_rois(*templ); + add_results(std::move(results), param_.thresholds.at(i)); } - sort(all_results); - size_t index = preferred_index(all_results); + sort(); - return { all_results, index }; + auto cost = duration_since(start_time); + LogTrace << name_ << VAR(all_results_) << VAR(filtered_results_) << VAR(cost); } -TemplateMatcher::ResultsVec TemplateMatcher::foreach_rois(const cv::Mat& templ) const +TemplateMatcher::ResultsVec TemplateMatcher::match_all_rois(const cv::Mat& templ) { if (templ.empty()) { LogWarn << name_ << "template is empty" << VAR(param_.template_paths) << VAR(templ.size()); @@ -64,22 +58,20 @@ TemplateMatcher::ResultsVec TemplateMatcher::foreach_rois(const cv::Mat& templ) } if (param_.roi.empty()) { - return match(cv::Rect(0, 0, image_.cols, image_.rows), templ); + return template_match(cv::Rect(0, 0, image_.cols, image_.rows), templ); } - - ResultsVec results; - for (const cv::Rect& roi : param_.roi) { - ResultsVec res = match(roi, templ); - results.insert( - results.end(), - std::make_move_iterator(res.begin()), - std::make_move_iterator(res.end())); + else { + ResultsVec results; + for (const cv::Rect& roi : param_.roi) { + auto res = template_match(roi, templ); + merge_vector_(results, std::move(res)); + } + return results; } - - return results; } -TemplateMatcher::ResultsVec TemplateMatcher::match(const cv::Rect& roi, const cv::Mat& templ) const +TemplateMatcher::ResultsVec + TemplateMatcher::template_match(const cv::Rect& roi, const cv::Mat& templ) { cv::Mat image = image_with_roi(roi); @@ -121,20 +113,20 @@ TemplateMatcher::ResultsVec TemplateMatcher::match(const cv::Rect& roi, const cv } auto nms_results = NMS(std::move(raw_results)); - draw_result(roi, templ, nms_results); + + if (debug_draw_) { + auto draw = draw_result(roi, templ, nms_results); + handle_draw(draw); + } return nms_results; } -void TemplateMatcher::draw_result( +cv::Mat TemplateMatcher::draw_result( const cv::Rect& roi, const cv::Mat& templ, const ResultsVec& results) const { - if (!debug_draw_) { - return; - } - cv::Mat image_draw = draw_roi(roi); const auto color = cv::Scalar(0, 0, 255); @@ -177,15 +169,27 @@ void TemplateMatcher::draw_result( cv::line(image_draw, cv::Point(raw_width, 0), results.front().box.tl(), color, 1); } - handle_draw(image_draw); + return image_draw; } -void TemplateMatcher::filter(ResultsVec& results, double threshold) const +void TemplateMatcher::add_results(ResultsVec results, double threshold) { - std::erase_if(results, [threshold](const auto& res) { return res.score < threshold; }); + std::ranges::copy_if(results, std::back_inserter(filtered_results_), [&](const auto& res) { + return res.score > threshold; + }); + + merge_vector_(all_results_, std::move(results)); } -void TemplateMatcher::sort(ResultsVec& results) const +void TemplateMatcher::sort() +{ + sort_(all_results_); + sort_(filtered_results_); + + handle_index(filtered_results_.size(), param_.result_index); +} + +void TemplateMatcher::sort_(ResultsVec& results) const { switch (param_.order_by) { case ResultOrderBy::Horizontal: @@ -209,14 +213,4 @@ void TemplateMatcher::sort(ResultsVec& results) const } } -size_t TemplateMatcher::preferred_index(const ResultsVec& results) const -{ - auto index_opt = pythonic_index(results.size(), param_.result_index); - if (!index_opt) { - return SIZE_MAX; - } - - return *index_opt; -} - MAA_VISION_NS_END \ No newline at end of file diff --git a/source/MaaFramework/Vision/TemplateMatcher.h b/source/MaaFramework/Vision/TemplateMatcher.h index d5da69396..9742d7060 100644 --- a/source/MaaFramework/Vision/TemplateMatcher.h +++ b/source/MaaFramework/Vision/TemplateMatcher.h @@ -23,26 +23,40 @@ class TemplateMatcher : public VisionBase using ResultsVec = std::vector; public: - void set_templates(std::vector> templates) - { - templates_ = std::move(templates); - } + TemplateMatcher( + cv::Mat image, + TemplateMatcherParam param, + std::vector> templates, + std::string name = ""); + + const ResultsVec& all_results() const& { return all_results_; } + + ResultsVec&& all_results() && { return std::move(all_results_); } - void set_param(TemplateMatcherParam param) { param_ = std::move(param); } + const ResultsVec& filtered_results() const& { return filtered_results_; } - std::pair analyze() const; + ResultsVec filtered_results() && { return std::move(filtered_results_); } private: - ResultsVec foreach_rois(const cv::Mat& templ) const; - ResultsVec match(const cv::Rect& roi, const cv::Mat& templ) const; - void draw_result(const cv::Rect& roi, const cv::Mat& templ, const ResultsVec& results) const; + void analyze(); + ResultsVec match_all_rois(const cv::Mat& templ); + ResultsVec template_match(const cv::Rect& roi, const cv::Mat& templ); - void filter(ResultsVec& results, double threshold) const; - void sort(ResultsVec& results) const; - size_t preferred_index(const ResultsVec& results) const; + void add_results(ResultsVec results, double threshold); + void sort(); + +private: + cv::Mat draw_result(const cv::Rect& roi, const cv::Mat& templ, const ResultsVec& results) const; - TemplateMatcherParam param_; - std::vector> templates_; + void sort_(ResultsVec& results) const; + +private: + const TemplateMatcherParam param_; + const std::vector> templates_; + +private: + ResultsVec all_results_; + ResultsVec filtered_results_; }; MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/VisionBase.cpp b/source/MaaFramework/Vision/VisionBase.cpp index db4e90d7d..e2209dcfa 100644 --- a/source/MaaFramework/Vision/VisionBase.cpp +++ b/source/MaaFramework/Vision/VisionBase.cpp @@ -11,17 +11,13 @@ MAA_VISION_NS_BEGIN -void VisionBase::set_image(const cv::Mat& image) +VisionBase::VisionBase(cv::Mat image, std::string name) + : image_(std::move(image)) + , name_(std::move(name)) { - image_ = image; init_debug_draw(); } -void VisionBase::set_name(std::string name) -{ - name_ = std::move(name); -} - cv::Mat VisionBase::image_with_roi(const cv::Rect& roi) const { cv::Rect roi_corrected = correct_roi(roi, image_); @@ -56,19 +52,33 @@ cv::Mat VisionBase::draw_roi(const cv::Rect& roi, const cv::Mat& base) const return image_draw; } -void VisionBase::handle_draw(const cv::Mat& draw) const +void VisionBase::handle_draw(const cv::Mat& draw) { + draws_.emplace_back(draw); + if (save_draw_) { - save_image(draw); + draw_paths_.emplace_back(save_image(draw)); } } -void VisionBase::save_image(const cv::Mat& image) const +void VisionBase::handle_index(size_t total, int index) +{ + auto index_opt = pythonic_index(total, index); + if (!index_opt) { + preferred_index_ = SIZE_MAX; + return; + } + + preferred_index_ = *index_opt; +} + +std::filesystem::path VisionBase::save_image(const cv::Mat& image) const { std::string filename = std::format("{}_{}.png", name_, format_now_for_filename()); auto filepath = GlobalOptionMgr::get_instance().log_dir() / "vision" / path(filename); MAA_NS::imwrite(filepath, image); LogDebug << "save image to" << filepath; + return filepath; } void VisionBase::init_debug_draw() diff --git a/source/MaaFramework/Vision/VisionBase.h b/source/MaaFramework/Vision/VisionBase.h index bbf09ca2e..c3819b188 100644 --- a/source/MaaFramework/Vision/VisionBase.h +++ b/source/MaaFramework/Vision/VisionBase.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "Conf/Conf.h" #include "Utils/JsonExt.hpp" #include "Utils/NoWarningCVMat.hpp" @@ -9,28 +11,41 @@ MAA_VISION_NS_BEGIN class VisionBase { public: - void set_image(const cv::Mat& image); - void set_name(std::string name); + VisionBase(cv::Mat image, std::string name); + + const std::vector& draws() const& { return draws_; } + + std::vector draws() && { return std::move(draws_); } + + const std::vector& draw_paths() const& { return draw_paths_; } + + std::vector draw_paths() && { return std::move(draw_paths_); } + + size_t preferred_index() const { return preferred_index_; } protected: cv::Mat image_with_roi(const cv::Rect& roi) const; protected: cv::Mat draw_roi(const cv::Rect& roi, const cv::Mat& base = cv::Mat()) const; - void handle_draw(const cv::Mat& draw) const; + void handle_draw(const cv::Mat& draw); + void handle_index(size_t total, int index); protected: - cv::Mat image_ {}; - std::string name_; + const cv::Mat image_; + const std::string name_; bool debug_draw_ = false; private: void init_debug_draw(); - void save_image(const cv::Mat& image) const; + std::filesystem::path save_image(const cv::Mat& image) const; private: bool save_draw_ = false; + std::vector draws_; + std::vector draw_paths_; + size_t preferred_index_ = SIZE_MAX; }; MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/VisionTypes.h b/source/MaaFramework/Vision/VisionTypes.h index 983811f3f..01b3d65f8 100644 --- a/source/MaaFramework/Vision/VisionTypes.h +++ b/source/MaaFramework/Vision/VisionTypes.h @@ -152,7 +152,7 @@ struct FeatureMatcherParam inline static constexpr int kDefaultCount = 4; std::vector roi; - std::string template_path; + std::vector template_paths; bool green_mask = false; Detector detector = kDefaultDetector; diff --git a/source/MaaFramework/Vision/VisionUtils.hpp b/source/MaaFramework/Vision/VisionUtils.hpp index 5d2bddb1c..3b4893e63 100644 --- a/source/MaaFramework/Vision/VisionUtils.hpp +++ b/source/MaaFramework/Vision/VisionUtils.hpp @@ -162,6 +162,15 @@ inline static T softmax(const T& input) return output; } +template +inline static void merge_vector_(ResultsVec& left, ResultsVec right) +{ + left.insert( + left.end(), + std::make_move_iterator(right.begin()), + std::make_move_iterator(right.end())); +} + inline static cv::Mat hwc_to_chw(const cv::Mat& src) { std::vector rgb_images;