From b649f2d719473ce8d248aadb324a2cd863b1853c Mon Sep 17 00:00:00 2001 From: Nuzhny007 Date: Sun, 9 Jun 2024 08:47:57 +0300 Subject: [PATCH] Add YOLOv10 TensorRT detector --- README.md | 2 + example/examples.h | 3 +- src/Detector/OCVDNNDetector.cpp | 5 +- src/Detector/OCVDNNDetector.h | 3 +- src/Detector/YoloTensorRTDetector.cpp | 1 + src/Detector/tensorrt_yolo/YoloONNXv10_bb.hpp | 85 +++++++++++++++++++ src/Detector/tensorrt_yolo/class_detector.cpp | 8 +- src/Detector/tensorrt_yolo/class_detector.h | 3 +- src/Detector/tensorrt_yolo/ds_image.cpp | 4 +- 9 files changed, 106 insertions(+), 8 deletions(-) create mode 100644 src/Detector/tensorrt_yolo/YoloONNXv10_bb.hpp diff --git a/README.md b/README.md index bfec439c..8707027b 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,8 @@ # Last changes +* YOLOv10 detector worked with TensorRT! Export pretrained Pytorch models [here (THU-MIG/yolov10)](https://github.com/THU-MIG/yolov10) to onnx format and run Multitarget-tracker with -e=6 example + * YOLOv9 detector worked with TensorRT! Export pretrained Pytorch models [here (WongKinYiu/yolov9)](https://github.com/WongKinYiu/yolov9) to onnx format and run Multitarget-tracker with -e=6 example * YOLOv8 instance segmentation models worked with TensorRT! Export pretrained Pytorch models [here (ultralytics/ultralytics)](https://github.com/ultralytics/ultralytics) to onnx format and run Multitarget-tracker with -e=6 example diff --git a/example/examples.h b/example/examples.h index 549c9dc2..62861622 100644 --- a/example/examples.h +++ b/example/examples.h @@ -650,7 +650,8 @@ class YoloTensorRTExample final : public VideoExample YOLOv7Mask, YOLOv8, YOLOv8Mask, - YOLOv9 + YOLOv9, + YOLOv10 }; YOLOModels usedModel = YOLOModels::YOLOv9; switch (usedModel) diff --git a/src/Detector/OCVDNNDetector.cpp b/src/Detector/OCVDNNDetector.cpp index 737b1227..0ba8dadc 100644 --- a/src/Detector/OCVDNNDetector.cpp +++ b/src/Detector/OCVDNNDetector.cpp @@ -140,6 +140,7 @@ bool OCVDNNDetector::Init(const config_t& config) dictNetType["YOLOV8"] = ModelType::YOLOV8; dictNetType["YOLOV8Mask"] = ModelType::YOLOV8Mask; dictNetType["YOLOV9"] = ModelType::YOLOV9; + dictNetType["YOLOV10"] = ModelType::YOLOV10; auto netType = dictNetType.find(net_type->second); if (netType != dictNetType.end()) @@ -346,7 +347,7 @@ void OCVDNNDetector::DetectInCrop(const cv::UMat& colorFrame, const cv::Rect& cr } else { - if (m_netType == ModelType::YOLOV8 || m_netType == ModelType::YOLOV5 || m_netType == ModelType::YOLOV9) + if (m_netType == ModelType::YOLOV8 || m_netType == ModelType::YOLOV5 || m_netType == ModelType::YOLOV9 || m_netType == ModelType::YOLOV10) { int rows = detections[0].size[1]; int dimensions = detections[0].size[2]; @@ -368,7 +369,7 @@ void OCVDNNDetector::DetectInCrop(const cv::UMat& colorFrame, const cv::Rect& cr for (int i = 0; i < rows; ++i) { - if (m_netType == ModelType::YOLOV8 || m_netType == ModelType::YOLOV9) + if (m_netType == ModelType::YOLOV8 || m_netType == ModelType::YOLOV9 || m_netType == ModelType::YOLOV10) { float* classes_scores = data + 4; diff --git a/src/Detector/OCVDNNDetector.h b/src/Detector/OCVDNNDetector.h index 6a014379..ee9331ee 100644 --- a/src/Detector/OCVDNNDetector.h +++ b/src/Detector/OCVDNNDetector.h @@ -40,7 +40,8 @@ class OCVDNNDetector final : public BaseDetector YOLOV7Mask, YOLOV8, YOLOV8Mask, - YOLOV9 + YOLOV9, + YOLOV10 }; cv::dnn::Net m_net; diff --git a/src/Detector/YoloTensorRTDetector.cpp b/src/Detector/YoloTensorRTDetector.cpp index 399772d3..43aab47e 100644 --- a/src/Detector/YoloTensorRTDetector.cpp +++ b/src/Detector/YoloTensorRTDetector.cpp @@ -103,6 +103,7 @@ bool YoloTensorRTDetector::Init(const config_t& config) dictNetType["YOLOV8"] = tensor_rt::YOLOV8; dictNetType["YOLOV8Mask"] = tensor_rt::YOLOV8Mask; dictNetType["YOLOV9"] = tensor_rt::YOLOV9; + dictNetType["YOLOV10"] = tensor_rt::YOLOV10; auto netType = dictNetType.find(net_type->second); if (netType != dictNetType.end()) diff --git a/src/Detector/tensorrt_yolo/YoloONNXv10_bb.hpp b/src/Detector/tensorrt_yolo/YoloONNXv10_bb.hpp new file mode 100644 index 00000000..1cf4bba5 --- /dev/null +++ b/src/Detector/tensorrt_yolo/YoloONNXv10_bb.hpp @@ -0,0 +1,85 @@ +#pragma once + +#include "YoloONNX.hpp" + +/// +/// \brief The YOLOv10_bb_onnx class +/// +class YOLOv10_bb_onnx : public YoloONNX +{ +protected: + /// + /// \brief GetResult + /// \param output + /// \return + /// + std::vector YoloONNX::GetResult(size_t imgIdx, int /*keep_topk*/, const std::vector& outputs, cv::Size frameSize) + { + std::vector resBoxes; + + //0: name: images, size: 1x3x640x640 + //1: name: output0, size: 1x300x6 + + const float fw = static_cast(frameSize.width) / static_cast(m_inputDims.d[3]); + const float fh = static_cast(frameSize.height) / static_cast(m_inputDims.d[2]); + + auto output = outputs[0]; + + size_t ncInd = 2; + size_t lenInd = 1; + size_t len = static_cast(m_outpuDims[0].d[lenInd]) / m_params.explicitBatchSize; + //auto Volume = [](const nvinfer1::Dims& d) + //{ + // return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies()); + //}; + auto volume = len * m_outpuDims[0].d[ncInd]; // Volume(m_outpuDims[0]); + output += volume * imgIdx; + //std::cout << "len = " << len << ", nc = " << nc << ", m_params.confThreshold = " << m_params.confThreshold << ", volume = " << volume << std::endl; + + std::vector classIds; + std::vector confidences; + std::vector rectBoxes; + classIds.reserve(len); + confidences.reserve(len); + rectBoxes.reserve(len); + + for (size_t i = 0; i < len; ++i) + { + // Box + size_t k = i * 6; + + //if (i == 0) + // std::cout << i << ": " << output[k + 0] << " " << output[k + 1] << " " << output[k + 2] << " " << output[k + 3] << " " << output[k + 4] << " " << output[k + 5] << std::endl; + + float x = fw * output[k + 0]; + float y = fh * output[k + 1]; + float width = fw * (output[k + 2] - output[k + 0]); + float height = fh * (output[k + 3] - output[k + 1]); + float objectConf = output[k + 4]; + int classId = cvRound(output[k + 5]); + //if (i == 0) + // std::cout << i << ": object_conf = " << objectConf << ", classId = " << classId << ", rect = " << cv::Rect(cvRound(x), cvRound(y), cvRound(width), cvRound(height)) << std::endl; + + if (objectConf >= m_params.confThreshold) + { + classIds.push_back(classId); + confidences.push_back(objectConf); + + // (center x, center y, width, height) to (x, y, w, h) + rectBoxes.emplace_back(cvRound(x), cvRound(y), cvRound(width), cvRound(height)); + } + } + + // Non-maximum suppression to eliminate redudant overlapping boxes + std::vector indices; + cv::dnn::NMSBoxes(rectBoxes, confidences, m_params.confThreshold, m_params.nmsThreshold, indices); + resBoxes.reserve(indices.size()); + + for (size_t bi = 0; bi < indices.size(); ++bi) + { + resBoxes.emplace_back(classIds[indices[bi]], confidences[indices[bi]], rectBoxes[indices[bi]]); + } + + return resBoxes; + } +}; diff --git a/src/Detector/tensorrt_yolo/class_detector.cpp b/src/Detector/tensorrt_yolo/class_detector.cpp index d64795b2..d8d58b18 100644 --- a/src/Detector/tensorrt_yolo/class_detector.cpp +++ b/src/Detector/tensorrt_yolo/class_detector.cpp @@ -8,6 +8,7 @@ #include "YoloONNXv8_bb.hpp" #include "YoloONNXv8_instance.hpp" #include "YoloONNXv9_bb.hpp" +#include "YoloONNXv10_bb.hpp" namespace tensor_rt { @@ -98,6 +99,11 @@ namespace tensor_rt m_params.outputTensorNames.push_back("output0"); m_detector = std::make_unique(); break; + case ModelType::YOLOV10: + m_params.inputTensorNames.push_back("images"); + m_params.outputTensorNames.push_back("output0"); + m_detector = std::make_unique(); + break; } // Threshold values @@ -181,7 +187,7 @@ namespace tensor_rt if (config.net_type == ModelType::YOLOV6 || config.net_type == ModelType::YOLOV7 || config.net_type == ModelType::YOLOV7Mask || config.net_type == ModelType::YOLOV8 || config.net_type == ModelType::YOLOV8Mask || - config.net_type == ModelType::YOLOV9) + config.net_type == ModelType::YOLOV9 || config.net_type == ModelType::YOLOV10) m_impl = new YoloONNXImpl(); else m_impl = new YoloDectectorImpl(); diff --git a/src/Detector/tensorrt_yolo/class_detector.h b/src/Detector/tensorrt_yolo/class_detector.h index a6e869b6..4c8e2911 100644 --- a/src/Detector/tensorrt_yolo/class_detector.h +++ b/src/Detector/tensorrt_yolo/class_detector.h @@ -52,7 +52,8 @@ namespace tensor_rt YOLOV7Mask, YOLOV8, YOLOV8Mask, - YOLOV9 + YOLOV9, + YOLOV10 }; /// diff --git a/src/Detector/tensorrt_yolo/ds_image.cpp b/src/Detector/tensorrt_yolo/ds_image.cpp index 1d1c4d9a..82c69b2a 100644 --- a/src/Detector/tensorrt_yolo/ds_image.cpp +++ b/src/Detector/tensorrt_yolo/ds_image.cpp @@ -50,7 +50,7 @@ DsImage::DsImage(const cv::Mat& mat_image_, tensor_rt::ModelType net_type, const if (tensor_rt::ModelType::YOLOV5 == net_type || tensor_rt::ModelType::YOLOV6 == net_type || tensor_rt::ModelType::YOLOV7 == net_type || tensor_rt::ModelType::YOLOV7Mask == net_type || tensor_rt::ModelType::YOLOV8 == net_type || tensor_rt::ModelType::YOLOV8Mask == net_type || - tensor_rt::ModelType::YOLOV9 == net_type) + tensor_rt::ModelType::YOLOV9 == net_type || tensor_rt::ModelType::YOLOV10 == net_type) { // resize the DsImage with scale float r = std::min(static_cast(inputH) / static_cast(m_Height), static_cast(inputW) / static_cast(m_Width)); @@ -101,7 +101,7 @@ DsImage::DsImage(const std::string& path, tensor_rt::ModelType net_type, const i if (tensor_rt::ModelType::YOLOV5 == net_type || tensor_rt::ModelType::YOLOV6 == net_type || tensor_rt::ModelType::YOLOV7 == net_type || tensor_rt::ModelType::YOLOV7Mask == net_type || tensor_rt::ModelType::YOLOV8 == net_type || tensor_rt::ModelType::YOLOV8Mask == net_type || - tensor_rt::ModelType::YOLOV9 == net_type) + tensor_rt::ModelType::YOLOV9 == net_type || tensor_rt::ModelType::YOLOV10 == net_type) { // resize the DsImage with scale float dim = std::max(m_Height, m_Width);