From cf5647101ccb1218e592b6c83c16b1b2feb0d85f Mon Sep 17 00:00:00 2001 From: MistEO Date: Tue, 16 Jul 2024 00:22:57 +0800 Subject: [PATCH] =?UTF-8?q?perf=20&=20fix:=20NN=20=E7=AE=97=E6=B3=95?= =?UTF-8?q?=E7=A7=BB=E9=99=A4=20cls=5Fsize=20=E5=AD=97=E6=AE=B5=E9=9C=80?= =?UTF-8?q?=E6=B1=82=EF=BC=8C=E8=87=AA=E9=80=82=E5=BA=94=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=BE=93=E5=85=A5=E5=B0=BA=E5=AF=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix https://github.com/MaaXYZ/MaaFramework/issues/281 --- docs/en_us/3.1-PipelineProtocol.md | 8 --- ...64\347\272\277\345\215\217\350\256\256.md" | 8 --- .../MaaFramework/Resource/PipelineResMgr.cpp | 22 ------ .../Vision/NeuralNetworkClassifier.cpp | 70 ++++++++----------- .../Vision/NeuralNetworkDetector.cpp | 54 +++++++------- source/MaaFramework/Vision/VisionTypes.h | 2 - tools/pipeline.schema.json | 7 -- 7 files changed, 60 insertions(+), 111 deletions(-) diff --git a/docs/en_us/3.1-PipelineProtocol.md b/docs/en_us/3.1-PipelineProtocol.md index d4f0fb84d..e284729d3 100644 --- a/docs/en_us/3.1-PipelineProtocol.md +++ b/docs/en_us/3.1-PipelineProtocol.md @@ -373,9 +373,6 @@ This task property requires additional fields: - `roi`: *array* | *list>* Same as `TemplateMatch`.`roi`. -- `cls_size`: *int* - The total number of categories. Required. - - `labels`: *list* Labels, meaning the names of each category. Optional. It only affects debugging images and logs. If not filled, it will be filled with "Unknown." @@ -400,7 +397,6 @@ For example, if you want to recognize whether a cat or a mouse appears in a **fi ```jsonc { - "cls_size": 3, "labels": ["Cat", "Dog", "Mouse"], "expected": [0, 2] } @@ -419,9 +415,6 @@ This task property requires additional fields: - `roi`: *array* | *list>* Same as `TemplateMatch`.`roi`. -- `cls_size`: *int* - The total number of categories. Required. - - `labels`: *list* Labels, meaning the names of each category. Optional. It only affects debugging images and logs. If not filled, it will be filled with "Unknown." @@ -450,7 +443,6 @@ For example, if you want to detect cats, dogs, and mice in an image and only cli ```jsonc { - "cls_size": 3, "labels": ["Cat", "Dog", "Mouse"], "expected": [0, 2] } diff --git "a/docs/zh_cn/3.1-\344\273\273\345\212\241\346\265\201\346\260\264\347\272\277\345\215\217\350\256\256.md" "b/docs/zh_cn/3.1-\344\273\273\345\212\241\346\265\201\346\260\264\347\272\277\345\215\217\350\256\256.md" index 5cca09897..a843b89f2 100644 --- "a/docs/zh_cn/3.1-\344\273\273\345\212\241\346\265\201\346\260\264\347\272\277\345\215\217\350\256\256.md" +++ "b/docs/zh_cn/3.1-\344\273\273\345\212\241\346\265\201\346\260\264\347\272\277\345\215\217\350\256\256.md" @@ -381,9 +381,6 @@ graph LR; - `roi`: *array* | *list>* 同 `TemplateMatch`.`roi` -- `cls_size`: *int* - 总分类数,必选。 - - `labels`: *list* 标注,即每个分类的名字。可选。 仅影响调试图片及日志等,若未填写则会填充 "Unknown"。 @@ -409,7 +406,6 @@ graph LR; ```jsonc { - "cls_size": 3, "labels": ["Cat", "Dog", "Mouse"], "expected": [0, 2] } @@ -428,9 +424,6 @@ graph LR; - `roi`: *array* | *list>* 同 `TemplateMatch`.`roi` -- `cls_size`: *int* - 总分类数,必选。 - - `labels`: *list* 标注,即每个分类的名字。可选。 仅影响调试图片及日志等,若未填写则会填充 "Unknown"。 @@ -460,7 +453,6 @@ graph LR; ```jsonc { - "cls_size": 3, "labels": ["Cat", "Dog", "Mouse"], "expected": [0, 2] } diff --git a/source/MaaFramework/Resource/PipelineResMgr.cpp b/source/MaaFramework/Resource/PipelineResMgr.cpp index 4665c9eed..7603ed225 100644 --- a/source/MaaFramework/Resource/PipelineResMgr.cpp +++ b/source/MaaFramework/Resource/PipelineResMgr.cpp @@ -858,21 +858,10 @@ bool PipelineResMgr::parse_nn_classifier_param( return false; } - if (!get_and_check_value(input, "cls_size", output.cls_size, default_value.cls_size)) { - LogError << "failed to get_and_check_value cls_size" << VAR(input); - return false; - } - if (!get_and_check_value_or_array(input, "labels", output.labels, default_value.labels)) { LogError << "failed to get_and_check_value_or_array labels" << VAR(input); return false; } - if (output.labels.size() < output.cls_size) { - LogDebug << "labels.size() < cls_size, fill 'Unknown'" << VAR(output.labels.size()) - << VAR(output.cls_size); - output.labels.resize(output.cls_size, "Unknown"); - } - if (!get_and_check_value(input, "model", output.model, default_value.model)) { LogError << "failed to get_and_check_value model" << VAR(input); return false; @@ -913,21 +902,10 @@ bool PipelineResMgr::parse_nn_detector_param( return false; } - if (!get_and_check_value(input, "cls_size", output.cls_size, default_value.cls_size)) { - LogError << "failed to get_and_check_value cls_size" << VAR(input); - return false; - } - if (!get_and_check_value_or_array(input, "labels", output.labels, default_value.labels)) { LogError << "failed to get_and_check_value_or_array labels" << VAR(input); return false; } - if (output.labels.size() < output.cls_size) { - LogDebug << "labels.size() < cls_size, fill 'Unknown'" << VAR(output.labels.size()) - << VAR(output.cls_size); - output.labels.resize(output.cls_size, "Unknown"); - } - if (!get_and_check_value(input, "model", output.model, default_value.model)) { LogError << "failed to get_and_check_value model" << VAR(input); return false; diff --git a/source/MaaFramework/Vision/NeuralNetworkClassifier.cpp b/source/MaaFramework/Vision/NeuralNetworkClassifier.cpp index 477334d4f..f11b7ff55 100644 --- a/source/MaaFramework/Vision/NeuralNetworkClassifier.cpp +++ b/source/MaaFramework/Vision/NeuralNetworkClassifier.cpp @@ -28,16 +28,6 @@ void NeuralNetworkClassifier::analyze() LogError << "OrtSession not loaded"; return; } - if (param_.cls_size == 0) { - LogError << "cls_size == 0"; - return; - } - if (param_.cls_size != param_.labels.size()) { - LogError << "cls_size != labels.size()" << VAR(param_.cls_size) - << VAR(param_.labels.size()); - return; - } - auto start_time = std::chrono::steady_clock::now(); auto results = classify_all_rois(); @@ -71,14 +61,22 @@ NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify(const cv::Rect LogError << "OrtSession not loaded"; return {}; } + // batch_size, channel, height, width + // for yolov8, input_shape is { 1, 3, 640, 640 } + const auto input_shape = session_->GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape(); + if (input_shape.size() != 4) { + LogError << "Input shape is not 4" << VAR(input_shape); + return {}; + } cv::Mat image = image_with_roi(roi); + cv::Size raw_roi_size(image.cols, image.rows); + cv::Size input_image_size(static_cast(input_shape[3]), static_cast(input_shape[2])); + cv::resize(image, image, input_image_size, 0, 0, cv::INTER_AREA); std::vector input = image_to_tensor(image); // TODO: GPU auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - constexpr int64_t kBatchSize = 1; - std::array input_shape { kBatchSize, image.channels(), image.cols, image.rows }; Ort::Value input_tensor = Ort::Value::CreateTensor( memory_info, @@ -87,16 +85,6 @@ NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify(const cv::Rect input_shape.data(), input_shape.size()); - std::vector output; - output.resize(param_.cls_size); - std::array output_shape { kBatchSize, static_cast(param_.cls_size) }; - Ort::Value output_tensor = Ort::Value::CreateTensor( - memory_info, - output.data(), - output.size(), - output_shape.data(), - output_shape.size()); - Ort::AllocatorWithDefaultOptions allocator; const std::string in_0 = session_->GetInputNameAllocated(0, allocator).get(); const std::string out_0 = session_->GetOutputNameAllocated(0, allocator).get(); @@ -104,30 +92,34 @@ NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify(const cv::Rect const std::vector output_names { out_0.c_str() }; Ort::RunOptions run_options; - session_->Run( + auto output_tensor = session_->Run( run_options, input_names.data(), &input_tensor, - 1, + input_names.size(), output_names.data(), - &output_tensor, - 1); - - Result result; - result.raw = std::move(output); - result.probs = softmax(result.raw); - result.cls_index = - std::max_element(result.probs.begin(), result.probs.end()) - result.probs.begin(); - result.score = result.probs[result.cls_index]; - result.label = param_.labels[result.cls_index]; - result.box = roi; + output_names.size()); + + const float* raw_output = output_tensor[0].GetTensorData(); + std::vector output( + raw_output, + raw_output + output_tensor[0].GetTensorTypeAndShapeInfo().GetElementCount()); + + Result res; + res.raw = std::move(output); + res.probs = softmax(res.raw); + res.cls_index = std::max_element(res.probs.begin(), res.probs.end()) - res.probs.begin(); + res.score = res.probs[res.cls_index]; + res.label = res.cls_index < param_.labels.size() ? param_.labels[res.cls_index] + : std::format("Unkonwn_{}", res.cls_index); + res.box = roi; if (debug_draw_) { - auto draw = draw_result(result); + auto draw = draw_result(res); handle_draw(draw); } - return result; + return res; } void NeuralNetworkClassifier::add_results(ResultsVec results, const std::vector& expected) @@ -154,7 +146,7 @@ cv::Mat NeuralNetworkClassifier::draw_result(const Result& res) const cv::Mat image_draw = draw_roi(res.box); cv::Point pt(res.box.x + res.box.width + 5, res.box.y + 20); - for (size_t i = 0; i != param_.cls_size; ++i) { + for (size_t i = 0; i != res.raw.size(); ++i) { const auto color = i == res.cls_index ? cv::Scalar(0, 0, 255) : cv::Scalar(255, 0, 0); std::string text = std::format( "{} {}: prob {:.3f}, raw {:.3f}", @@ -193,4 +185,4 @@ void NeuralNetworkClassifier::sort_(ResultsVec& results) const } } -MAA_VISION_NS_END \ No newline at end of file +MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/NeuralNetworkDetector.cpp b/source/MaaFramework/Vision/NeuralNetworkDetector.cpp index 8a1a96738..8159aa720 100644 --- a/source/MaaFramework/Vision/NeuralNetworkDetector.cpp +++ b/source/MaaFramework/Vision/NeuralNetworkDetector.cpp @@ -29,15 +29,6 @@ void NeuralNetworkDetector::analyze() LogError << "OrtSession not loaded"; return; } - if (param_.cls_size == 0) { - LogError << "cls_size == 0"; - return; - } - if (param_.cls_size != param_.labels.size()) { - LogError << "cls_size != labels.size()" << VAR(param_.cls_size) - << VAR(param_.labels.size()); - return; - } auto start_time = std::chrono::steady_clock::now(); @@ -73,13 +64,22 @@ NeuralNetworkDetector::ResultsVec NeuralNetworkDetector::detect(const cv::Rect& return {}; } + // batch_size, channel, height, width + // for yolov8, input_shape is { 1, 3, 640, 640 } + const auto input_shape = session_->GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape(); + if (input_shape.size() != 4) { + LogError << "Input shape is not 4" << VAR(input_shape); + return {}; + } + cv::Mat image = image_with_roi(roi); + cv::Size raw_roi_size(image.cols, image.rows); + cv::Size input_image_size(static_cast(input_shape[3]), static_cast(input_shape[2])); + cv::resize(image, image, input_image_size, 0, 0, cv::INTER_AREA); std::vector input = image_to_tensor(image); // TODO: GPU - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - constexpr int64_t kBatchSize = 1; - std::array input_shape { kBatchSize, image.channels(), image.cols, image.rows }; + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); Ort::Value input_tensor = Ort::Value::CreateTensor( memory_info, @@ -124,12 +124,9 @@ NeuralNetworkDetector::ResultsVec NeuralNetworkDetector::detect(const cv::Rect& raw_output + (i + 1) * output_shape[2]); } - ResultsVec all_nms_results; - const size_t output_size = output.back().size(); + ResultsVec raw_results; for (size_t i = 0; i < output_size; ++i) { - ResultsVec raw_results; - constexpr size_t kConfidenceIndex = 4; for (size_t j = kConfidenceIndex; j < output.size(); ++j) { float score = output[j][i]; @@ -149,25 +146,32 @@ NeuralNetworkDetector::ResultsVec NeuralNetworkDetector::detect(const cv::Rect& Result res; res.cls_index = j - kConfidenceIndex; - res.label = param_.labels[res.cls_index]; + res.label = res.cls_index < param_.labels.size() + ? param_.labels[res.cls_index] + : std::format("Unkonwn_{}", res.cls_index); res.box = box; res.score = score; raw_results.emplace_back(std::move(res)); } - auto nms_results = NMS(std::move(raw_results)); - all_nms_results.insert( - all_nms_results.end(), - std::make_move_iterator(nms_results.begin()), - std::make_move_iterator(nms_results.end())); + } + + auto nms_results = NMS(std::move(raw_results)); + + // post process + for (Result& res : nms_results) { + res.box.x = res.box.x * raw_roi_size.width / input_image_size.width + roi.x; + res.box.y = res.box.y * raw_roi_size.height / input_image_size.height + roi.y; + res.box.width = res.box.width * raw_roi_size.width / input_image_size.width; + res.box.height = res.box.height * raw_roi_size.height / input_image_size.height; } if (debug_draw_) { - auto draw = draw_result(roi, all_nms_results); + auto draw = draw_result(roi, nms_results); handle_draw(draw); } - return all_nms_results; + return nms_results; } void NeuralNetworkDetector::add_results(ResultsVec results, const std::vector& expected) @@ -244,4 +248,4 @@ void NeuralNetworkDetector::sort_(ResultsVec& results) const } } -MAA_VISION_NS_END \ No newline at end of file +MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/VisionTypes.h b/source/MaaFramework/Vision/VisionTypes.h index 3ee5a26da..78df5491d 100644 --- a/source/MaaFramework/Vision/VisionTypes.h +++ b/source/MaaFramework/Vision/VisionTypes.h @@ -70,7 +70,6 @@ struct CustomRecognizerParam struct NeuralNetworkClassifierParam { - size_t cls_size = 0; std::vector labels; // only for output and debug std::string model; @@ -90,7 +89,6 @@ struct NeuralNetworkDetectorParam inline static constexpr Net kDefaultNet = Net::YoloV8; inline static constexpr double kDefaultThreshold = 0.3; - size_t cls_size = 0; std::vector labels; // only for output and debug std::string model; Net net = kDefaultNet; diff --git a/tools/pipeline.schema.json b/tools/pipeline.schema.json index 0e91795c7..f1a913b2f 100644 --- a/tools/pipeline.schema.json +++ b/tools/pipeline.schema.json @@ -177,11 +177,6 @@ "default": false }, "model": {}, - "cls_size": { - "description": "总分类数,必选。", - "type": "integer", - "default": 2 - }, "labels": { "description": "标注,即每个分类的名字。可选。", "type": "array", @@ -409,7 +404,6 @@ } }, "required": [ - "cls_size", "model", "expected" ] @@ -447,7 +441,6 @@ } }, "required": [ - "cls_size", "model", "expected" ]