Add TopK layer support for pnnx ONNX export path#6558
Open
vlordier wants to merge 22 commits intoTencent:masterfrom
Open
Add TopK layer support for pnnx ONNX export path#6558vlordier wants to merge 22 commits intoTencent:masterfrom
vlordier wants to merge 22 commits intoTencent:masterfrom
Conversation
Member
Contributor
There was a problem hiding this comment.
Pull request overview
This pull request adds TopK layer support for the pnnx ONNX export path in ncnn, addressing issue #6377 which reports failures in ONNX export pipelines for models like YOLOv10 through ultralytics.
Changes:
- Implements native TopK layer in ncnn core with support for different axes, largest/smallest selection, and sorted/unsorted output
- Adds pnnx pass to convert ONNX TopK nodes to ncnn TopK layer with proper batch axis handling
- Adds comprehensive test coverage for 1D, 2D, 3D, and 4D tensors with various axis configurations
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| src/layer/topk.h | Header file defining TopK layer class with parameters for axis, largest, sorted, and k |
| src/layer/topk.cpp | Implementation of TopK layer supporting 1-4D tensors with efficient sorting using std::partial_sort and std::nth_element |
| src/CMakeLists.txt | Registers TopK layer in build system (correctly ordered alphabetically) |
| tools/pnnx/src/pass_ncnn/TopK.cpp | PNNX pass to convert ONNX TopK operations to ncnn format, handling batch axis removal and parameter mapping |
| tools/pnnx/src/CMakeLists.txt | Adds TopK pass to pnnx build (correctly ordered alphabetically) |
| tests/test_topk.cpp | Test cases covering 1D through 4D tensors with various axis, k, and largest parameters |
| tests/CMakeLists.txt | Registers TopK test (minor ordering issue - should come after Tile) |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
Please enable github action in YOUR FORKED REPO to make code-format workflow work |
- Generate TopK class definition in pnnx.py output with forward() method - Instantiate TopK modules in Model.__init__() with proper parameters - Update forward() method to call self.topk_name() instead of direct TopK() calls - Fixes pnnx inference to properly execute TopK operations using torch.topk() - Test confirms TopK ONNX→pnnx conversion and inference working correctly
- Fix IR pattern syntax to use explicit parameter names (axis=%, largest=%, sorted=%) - Replace incorrect parameter lookup from 'op_0.axis' to 'axis' to match captured names - TopK pass now properly fires during ONNX→pnnx→ncnn conversion - All TopK parameters (axis, largest, sorted) correctly captured and set in ncnn layers - End-to-end test confirms ONNX→pnnx→ncnn conversion with TopK working correctly
use c++03-style topk comparator and keep deterministic nan/inf ordering remove redundant constructor param initialization fix tests cmakelists alphabetical order (Tile before TopK) expand torch_topk onnx tests (k=0/k=1, negative dim, sorted=false cases) drop generated topk onnx/pnnx/ncnn sidecar artifacts from repo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
TopKlayer implementation in ncnn coreTopKnodes in pnnx ncnn pass to the newTopKlayertest_topkcoverageWhy
Issue #6377 reports ONNX export pipelines (for example YOLOv10 via ultralytics) failing due to unsupported
TopKin ncnn conversion/runtime path.Validation
test_topklocally./build/tests/test_topksuccessfullyFixes #6377