From 8529f128a70788ae25b6c659493c2bb012091007 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Fri, 27 Dec 2024 17:24:53 +0800 Subject: [PATCH] Add more control options (#1283) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add alignment for buffer schedule * exclude gather from supported list for xpu * add HierarchyKind * swap unpack and reduce --------- Co-authored-by: 郑启航 <597323109@qq.com> Co-authored-by: xhuohai --- .../Passes/Distributed/AutoDistributed.cs | 19 ++++-- .../Nncase.Modules.CPU/Passes/PassUtility.cs | 9 ++- .../Passes/Rules/CPU/FoldMatmulReduce.cs | 30 +++++++++ .../Passes/Rules/CPU/GraphPartition.cs | 4 +- .../Targets/CPUTargetOptions.cs | 7 ++ .../Targets/CPUTargetOptionsCommand.cs | 11 +++- python/_nncase.pyi | 12 ++-- python/nncase/__init__.py | 2 +- python/nncase/native/ffi.cpp | 16 +++-- src/Native/include/nncase/compiler.h | 44 ++++++++++--- src/Nncase.Compiler/Interop/CApi.cs | 44 +++++++++++-- src/Nncase.Core/DistributedType.cs | 10 ++- src/Nncase.Evaluator/Math/MatMul.cs | 6 ++ .../BufferSchedule/BufferScheduler.cs | 6 +- src/Nncase.Schedule/Schedule/GraphTiler.cs | 6 +- .../TileGraph/PrimGraphSolveResult.cs | 13 +++- tests/config.toml | 3 +- tests/test_runner.py | 9 ++- .../stackvm_gen/CApiGen/Templates/CApi.razor | 64 ++++++++++--------- .../CApiGen/Templates/Compiler.razor | 13 +++- .../CApiGen/Templates/PyBind.razor | 4 ++ tools/stackvm_gen/CApiGen/packages.lock.json | 2 +- 22 files changed, 257 insertions(+), 77 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index c9133c00be..baf55af9fb 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -38,14 +38,17 @@ public sealed partial class AutoDistributedPass : FunctionPass { private readonly CompileOptions _compileOptions; - public AutoDistributedPass(CompileOptions compileOptions) + private readonly string _moduleKind; + + public AutoDistributedPass(CompileOptions compileOptions, string moduleKind = "cpu") { _compileOptions = compileOptions; + _moduleKind = moduleKind; } protected override Task RunCoreAsync(BaseFunction input, RunPassContext context) { - var rewriter = new AutoDistributedRewriter(_compileOptions, _compileOptions.TargetOptions is CpuTargetOptions options ? options : new CpuTargetOptions()); + var rewriter = new AutoDistributedRewriter(_compileOptions, _compileOptions.TargetOptions is CpuTargetOptions options ? options : new CpuTargetOptions(), _moduleKind); return Task.FromResult(rewriter.Rewirte(input)); } } @@ -54,19 +57,23 @@ internal sealed class AutoDistributedRewriter : ExprVisitor _equalMemo = new(); - public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions targetOptions) + private readonly string _moduleKind; + + public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions targetOptions, string moduleKind = "cpu") { - Placements = targetOptions.Hierarchies.Select(h => new Placement(h, targetOptions.HierarchyNames)).ToArray(); + Placements = targetOptions.Hierarchies.Select(h => new Placement(h, targetOptions.HierarchyNames, targetOptions.HierarchyKind)).ToArray(); CompileOptions = compileOptions; TargetOptions = targetOptions; if (Path.Exists(TargetOptions.DistributedScheme) && System.Text.Json.JsonSerializer.Deserialize(File.ReadAllText(TargetOptions.DistributedScheme)) is DistributedScheme scheme) { - Scheme = scheme.Outputs.ToDictionary(n => n.Name, n => (new IRArray(n.NdSBP), new Placement(n.Hierarchy, n.HierarchyName))); + Scheme = scheme.Outputs.ToDictionary(n => n.Name, n => (new IRArray(n.NdSBP), new Placement(n.Hierarchy, n.HierarchyName, targetOptions.HierarchyKind))); } else { Scheme = new Dictionary NdSBP, Placement Placement)>(); } + + _moduleKind = moduleKind; } public IRArray Placements { get; } @@ -332,7 +339,7 @@ protected override Dictionary> VisitLeafCall(Call expr) return new Dictionary> { { expr.CheckedType, new() { expr } } }; } - var isSupported = PassUtility.IsCpuSupported(op, expr, expr.Arguments.ToArray()); + var isSupported = PassUtility.IsCpuSupported(op, expr, expr.Arguments.ToArray(), _moduleKind); foreach (var param in op.Parameters) { VisitLeafArgument(param.ParameterKind, expr.Arguments[param.Index], isSupported); diff --git a/modules/Nncase.Modules.CPU/Passes/PassUtility.cs b/modules/Nncase.Modules.CPU/Passes/PassUtility.cs index fe9b277666..b35f347d4a 100644 --- a/modules/Nncase.Modules.CPU/Passes/PassUtility.cs +++ b/modules/Nncase.Modules.CPU/Passes/PassUtility.cs @@ -22,7 +22,7 @@ public static bool IsCpuSupported(Op op) return op is IR.Math.Unary or IR.Math.Binary { BinaryOp: BinaryOp.Add or BinaryOp.Sub or BinaryOp.Mul or BinaryOp.Div } or IR.Math.MatMul or IR.NN.Conv2D { PadMode: PadMode.Constant } or IR.NN.Softmax or IR.NN.LayerNorm or IR.NN.InstanceNormalization or IR.Imaging.ResizeImage { IsTFResize: false } or IR.Tensors.Unsqueeze or IR.Tensors.Reshape or IR.Tensors.Slice or IR.Tensors.Concat or IR.Tensors.Transpose or IR.NN.Swish or IR.Tensors.Gather or IR.NN.Pad { PadMode: PadMode.Constant } or IR.Math.Reduce or IR.Math.ReduceArg or IR.Math.Clamp or IR.NN.Erf or IR.Tensors.Cast or IR.Tensors.Expand or IR.Tensors.Where or IR.Math.Compare or IR.Tensors.ScatterND; } - public static bool IsCpuSupported(Op op, Expr expr, IEnumerable arguments) + public static bool IsCpuSupported(Op op, Expr expr, IEnumerable arguments, string moduleKind = "cpu") { if (!IsCpuSupported(op)) { @@ -110,6 +110,13 @@ public static bool IsCpuSupported(Op op, Expr expr, IEnumerable arguments) return false; } + break; + case IR.Tensors.Gather gather: + if (moduleKind == "xpu") + { + return false; + } + break; default: break; diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldMatmulReduce.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldMatmulReduce.cs index b9eb6b9b3d..a800d4a6cd 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldMatmulReduce.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldMatmulReduce.cs @@ -34,3 +34,33 @@ public sealed partial class FoldPackedMatmulReduce : IRewriteRule return null; } } + +[RuleGenerator] +public sealed partial class SwapUnpackReduce : IRewriteRule +{ + public IPattern Pattern { get; } = + IsBoxing( + target_name: "boxing", + op => op.NewType is DistributedType dt && dt.NdSBP.All(s => s != SBP.P), + IsUnpack( + target_name: "unpack", + _ => true, + IsPackedMatMul( + "mm", + "call", + _ => true, + IsWildcard("lhs"), + IsWildcard("rhs")))); + + public Expr? GetReplace(Call call, Boxing boxing, Unpack unpack) + { + if (call.CheckedType is DistributedType dt && dt.NdSBP.Any(s => s == SBP.P)) + { + var newType = new DistributedType(dt.TensorType, dt.NdSBP.Select(s => s is SBPPartialSum ? SBP.B : s).ToArray(), dt.Placement); + var newBoxing = IR.F.CPU.Boxing(call, newType, boxing.IsReshape); + return IR.F.CPU.Unpack(newBoxing, [.. unpack.Lanes], [.. unpack.Axes]); + } + + return null; + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/GraphPartition.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/GraphPartition.cs index 791fdadaf4..662371143c 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/GraphPartition.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/GraphPartition.cs @@ -46,7 +46,7 @@ public CPUOutputBoxingFusion(string moduleKind) private Call? GetReplace(Call call, Op op, Boxing boxing, IReadOnlyList callParams) { - if (!PassUtility.IsCpuSupported(op, call, callParams)) + if (!PassUtility.IsCpuSupported(op, call, callParams, ModuleKind)) { return null; } @@ -142,7 +142,7 @@ public CPUSingleFusion(string moduleKind) private Call? GetReplace(Call call, Op op, IReadOnlyList callParams) { - if (!PassUtility.IsCpuSupported(op, call, callParams)) + if (!PassUtility.IsCpuSupported(op, call, callParams, ModuleKind)) { return null; } diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs index 3959a13fa6..28dd50284f 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Text; using System.Threading.Tasks; +using Nncase.IR; namespace Nncase.Targets; @@ -58,6 +59,12 @@ public class CpuTargetOptions : ICpuTargetOptions [CommandLine.FromAmong(NocArchitecture.Mesh, NocArchitecture.CrossBar)] public NocArchitecture NocArch { get; set; } = NocArchitecture.Mesh; + [DisplayName("--hierarchy-kind")] + [Description("Hierarchy Kind.")] + [DefaultValue(HierarchyKind.Parallel)] + [CommandLine.FromAmong(HierarchyKind.Parallel, HierarchyKind.SMT)] + public HierarchyKind HierarchyKind { get; set; } = HierarchyKind.Parallel; + [DisplayName("--hierarchies")] [Description("the distributed hierarchies of hardware. eg. `8,4 4,8` for dynamic cluster search or `4` for fixed hardware.")] [DefaultValue("() => new int[][] { new int[] { 1 } }")] diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs index f1b04235de..6761b93f28 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs @@ -1,6 +1,6 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. -/* This file is generated by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ +/* This file is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ using System; using System.Collections.Generic; @@ -13,6 +13,7 @@ using System.Threading.Tasks; using Nncase; using Nncase.CommandLine; +using Nncase.IR; namespace Nncase.Targets; @@ -46,6 +47,11 @@ public CpuTargetOptionsCommand(string name) description: "Noc Architecture.", getDefaultValue: () => NocArchitecture.Mesh); Add(NocArchOption); + HierarchyKindOption = new Option( + name: "--hierarchy-kind", + description: "Hierarchy Kind.", + getDefaultValue: () => HierarchyKind.Parallel); + Add(HierarchyKindOption); HierarchiesOption = new Option>( name: "--hierarchies", description: "the distributed hierarchies of hardware. eg. `8,4 4,8` for dynamic cluster search or `4` for fixed hardware.", @@ -105,6 +111,8 @@ public CpuTargetOptionsCommand(string name) public Option NocArchOption { get; } + public Option HierarchyKindOption { get; } + public Option> HierarchiesOption { get; } public Option HierarchyNamesOption { get; } @@ -138,6 +146,7 @@ public CpuTargetOptions GetBoundValue(InvocationContext context) UnifiedMemoryArch = context.ParseResult.GetValueForOption(_cmd.UnifiedMemoryArchOption)!, MemoryAccessArch = context.ParseResult.GetValueForOption(_cmd.MemoryAccessArchOption)!, NocArch = context.ParseResult.GetValueForOption(_cmd.NocArchOption)!, + HierarchyKind = context.ParseResult.GetValueForOption(_cmd.HierarchyKindOption)!, Hierarchies = context.ParseResult.GetValueForOption(_cmd.HierarchiesOption)!.ToArray(), HierarchyNames = context.ParseResult.GetValueForOption(_cmd.HierarchyNamesOption)!, HierarchySizes = context.ParseResult.GetValueForOption(_cmd.HierarchySizesOption)!.ToArray(), diff --git a/python/_nncase.pyi b/python/_nncase.pyi index 4f43670628..d72c421239 100644 --- a/python/_nncase.pyi +++ b/python/_nncase.pyi @@ -3,16 +3,19 @@ from typing import Any, List, BinaryIO, Enum import numpy -""" This block is generated by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. """ +""" This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 5:27:07 PM +08:00. """ class MemoryAccessArchitecture(Enum): UMA = 0 NUMA = 1 class NocArchitecture(Enum): Mesh = 0 CrossBar = 1 -""" end the auto generated block by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. """ +class HierarchyKind(Enum): + Parallel = 0 + SMT = 1 +""" end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 5:27:07 PM +08:00. """ -""" This block is generated by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. """ +""" This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 5:27:07 PM +08:00. """ class CpuTargetOptions: def __init__(self) -> None: ... ModelName: str @@ -20,6 +23,7 @@ class CpuTargetOptions: UnifiedMemoryArch: bool MemoryAccessArch: MemoryAccessArchitecture NocArch: NocArchitecture + HierarchyKind: HierarchyKind Hierarchies: List[List[int]] HierarchyNames: str HierarchySizes: List[int] @@ -27,7 +31,7 @@ class CpuTargetOptions: MemoryBandWidths: List[int] DistributedScheme: str CustomOpScheme: str -""" end the auto generated block by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. """ +""" end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 5:27:07 PM +08:00. """ class CompileOptions: benchmark_only: bool diff --git a/python/nncase/__init__.py b/python/nncase/__init__.py index f6aebae307..fa9f8a38ab 100644 --- a/python/nncase/__init__.py +++ b/python/nncase/__init__.py @@ -33,7 +33,7 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' import _nncase -from _nncase import RuntimeTensor, TensorDesc, Simulator, CpuTargetOptions, NocArchitecture, MemoryAccessArchitecture +from _nncase import RuntimeTensor, TensorDesc, Simulator, CpuTargetOptions, NocArchitecture, HierarchyKind, MemoryAccessArchitecture def _initialize(): diff --git a/python/nncase/native/ffi.cpp b/python/nncase/native/ffi.cpp index b689db7ff2..e099c0571d 100644 --- a/python/nncase/native/ffi.cpp +++ b/python/nncase/native/ffi.cpp @@ -235,7 +235,7 @@ PYBIND11_MODULE(_nncase, m) { &shape_bucket_options::fix_var_map)); // clang-format off - /* This block is generated by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ + /* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ py::enum_(m, "MemoryAccessArchitecture") .value("UMA", memory_access_architecture_uma) @@ -245,6 +245,10 @@ PYBIND11_MODULE(_nncase, m) { .value("Mesh", noc_architecture_mesh) .value("CrossBar", noc_architecture_cross_bar); + py::enum_(m, "HierarchyKind") + .value("Parallel", hierarchy_kind_parallel) + .value("SMT", hierarchy_kind_smt); + py::class_(m, "CpuTargetOptions") .def(py::init()) @@ -262,12 +266,16 @@ PYBIND11_MODULE(_nncase, m) { py::overload_cast(&cpu_target_options::unified_memory_arch)) .def_property( "MemoryAccessArch", - []() {}, + py::overload_cast<>(&cpu_target_options::memory_access_arch), py::overload_cast(&cpu_target_options::memory_access_arch)) .def_property( "NocArch", - []() {}, + py::overload_cast<>(&cpu_target_options::noc_arch), py::overload_cast(&cpu_target_options::noc_arch)) + .def_property( + "HierarchyKind", + py::overload_cast<>(&cpu_target_options::hierarchy_kind), + py::overload_cast(&cpu_target_options::hierarchy_kind)) .def_property( "Hierarchies", []() {}, @@ -297,7 +305,7 @@ PYBIND11_MODULE(_nncase, m) { []() {}, py::overload_cast(&cpu_target_options::custom_op_scheme)) ; - /* end the auto generated block by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ // clang-format on py::class_(m, "CalibrationDatasetProvider") diff --git a/src/Native/include/nncase/compiler.h b/src/Native/include/nncase/compiler.h index eec4da7873..79e7641974 100644 --- a/src/Native/include/nncase/compiler.h +++ b/src/Native/include/nncase/compiler.h @@ -81,16 +81,20 @@ typedef enum { } nncase_input_type_t; // clang-format off -/* This block is generated by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ -typedef enum { +/* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ +enum memory_access_architecture_t : uint8_t { memory_access_architecture_uma = 0, memory_access_architecture_numa = 1, -} memory_access_architecture_t; -typedef enum { +}; +enum noc_architecture_t : uint8_t { noc_architecture_mesh = 0, noc_architecture_cross_bar = 1, -} noc_architecture_t; -/* end the auto generated block by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ +}; +enum hierarchy_kind_t : uint8_t { + hierarchy_kind_parallel = 0, + hierarchy_kind_smt = 1, +}; +/* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ // clang-format on typedef struct { @@ -237,13 +241,17 @@ typedef struct { clr_object_handle_t shape_bucket_options, const char *fix_var_map, size_t fix_var_map_size); // clang-format off - /* This block is generated by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ + /* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ clr_object_handle_t (*cpu_target_options_create)(); void (*cpu_target_options_set_model_name)(clr_object_handle_t handle, const char* value, size_t length); void (*cpu_target_options_set_packing)(clr_object_handle_t handle, uint8_t value); void (*cpu_target_options_set_unified_memory_arch)(clr_object_handle_t handle, uint8_t value); + uint8_t (*cpu_target_options_get_memory_access_arch)(clr_object_handle_t handle); void (*cpu_target_options_set_memory_access_arch)(clr_object_handle_t handle, uint8_t value); + uint8_t (*cpu_target_options_get_noc_arch)(clr_object_handle_t handle); void (*cpu_target_options_set_noc_arch)(clr_object_handle_t handle, uint8_t value); + uint8_t (*cpu_target_options_get_hierarchy_kind)(clr_object_handle_t handle); + void (*cpu_target_options_set_hierarchy_kind)(clr_object_handle_t handle, uint8_t value); void (*cpu_target_options_set_hierarchies)(clr_object_handle_t handle, int32_t* value, size_t shape0, size_t* shape1); void (*cpu_target_options_set_hierarchy_names)(clr_object_handle_t handle, const char* value, size_t length); void (*cpu_target_options_set_hierarchy_sizes)(clr_object_handle_t handle, int32_t* value, size_t shape0); @@ -251,7 +259,7 @@ typedef struct { void (*cpu_target_options_set_memory_band_widths)(clr_object_handle_t handle, int32_t* value, size_t shape0); void (*cpu_target_options_set_distributed_scheme)(clr_object_handle_t handle, const char* value, size_t length); void (*cpu_target_options_set_custom_op_scheme)(clr_object_handle_t handle, const char* value, size_t length); - /* end the auto generated block by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ // clang-format on clr_object_handle_t (*rtvalue_from_handle)(nncase::value_node *value); @@ -495,7 +503,7 @@ class shape_bucket_options : public clr_object_base { }; // clang-format off -/* This block is generated by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ +/* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ class cpu_target_options : public clr_object_base { public: using clr_object_base::clr_object_base; @@ -516,14 +524,30 @@ class cpu_target_options : public clr_object_base { nncase_clr_api()->cpu_target_options_set_unified_memory_arch(obj_.get(), value); } + memory_access_architecture_t memory_access_arch() { + return (memory_access_architecture_t)nncase_clr_api()->cpu_target_options_get_memory_access_arch(obj_.get()); + } + void memory_access_arch(memory_access_architecture_t value) { nncase_clr_api()->cpu_target_options_set_memory_access_arch(obj_.get(), value); } + noc_architecture_t noc_arch() { + return (noc_architecture_t)nncase_clr_api()->cpu_target_options_get_noc_arch(obj_.get()); + } + void noc_arch(noc_architecture_t value) { nncase_clr_api()->cpu_target_options_set_noc_arch(obj_.get(), value); } + hierarchy_kind_t hierarchy_kind() { + return (hierarchy_kind_t)nncase_clr_api()->cpu_target_options_get_hierarchy_kind(obj_.get()); + } + + void hierarchy_kind(hierarchy_kind_t value) { + nncase_clr_api()->cpu_target_options_set_hierarchy_kind(obj_.get(), value); + } + void hierarchies(std::vector> value) { std::vector values; size_t shape0; @@ -584,7 +608,7 @@ class cpu_target_options : public clr_object_base { nncase_clr_api()->cpu_target_options_set_custom_op_scheme(obj_.get(), value.data(), value.length()); } }; -/* end the auto generated block by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ +/* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ // clang-format on class cstream : public clr_object_base { diff --git a/src/Nncase.Compiler/Interop/CApi.cs b/src/Nncase.Compiler/Interop/CApi.cs index 2b2fe0b912..a406d2ace9 100644 --- a/src/Nncase.Compiler/Interop/CApi.cs +++ b/src/Nncase.Compiler/Interop/CApi.cs @@ -96,13 +96,17 @@ public unsafe struct CApiMT public delegate* unmanaged ShapeBucketOptionsSetRangeInfoPtr; public delegate* unmanaged ShapeBucketOptionsSetSegmentsCountPtr; public delegate* unmanaged ShapeBucketOptionsSetFixVarMapPtr; - /* This block is generated by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ + /* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 5:31:31 PM +08:00. */ public delegate* unmanaged CpuTargetOptionsCreatePtr; public delegate* unmanaged CpuTargetOptionsSetModelNamePtr; public delegate* unmanaged CpuTargetOptionsSetPackingPtr; public delegate* unmanaged CpuTargetOptionsSetUnifiedMemoryArchPtr; + public delegate* unmanaged CpuTargetOptionsGetMemoryAccessArchPtr; public delegate* unmanaged CpuTargetOptionsSetMemoryAccessArchPtr; + public delegate* unmanaged CpuTargetOptionsGetNocArchPtr; public delegate* unmanaged CpuTargetOptionsSetNocArchPtr; + public delegate* unmanaged CpuTargetOptionsGetHierarchyKindPtr; + public delegate* unmanaged CpuTargetOptionsSetHierarchyKindPtr; public delegate* unmanaged CpuTargetOptionsSetHierarchiesPtr; public delegate* unmanaged CpuTargetOptionsSetHierarchyNamesPtr; public delegate* unmanaged CpuTargetOptionsSetHierarchySizesPtr; @@ -110,7 +114,7 @@ public unsafe struct CApiMT public delegate* unmanaged CpuTargetOptionsSetMemoryBandWidthsPtr; public delegate* unmanaged CpuTargetOptionsSetDistributedSchemePtr; public delegate* unmanaged CpuTargetOptionsSetCustomOpSchemePtr; - /* end the auto generated block by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 5:31:31 PM +08:00. */ public delegate* unmanaged RTValueFromHandlePtr; public delegate* unmanaged RTValueGetHandlePtr; public delegate* unmanaged StreamCreatePtr; @@ -184,13 +188,17 @@ public static void Initialize(CApiMT* mt) mt->ShapeBucketOptionsSetRangeInfoPtr = &ShapeBucketOptionsSetRangeInfo; mt->ShapeBucketOptionsSetSegmentsCountPtr = &ShapeBucketOptionsSetSegmentsCount; mt->ShapeBucketOptionsSetFixVarMapPtr = &ShapeBucketOptionsSetFixVarMap; - /* This block is generated by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ + /* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ mt->CpuTargetOptionsCreatePtr = &CpuTargetOptionsCreate; mt->CpuTargetOptionsSetModelNamePtr = &CpuTargetOptionsSetModelName; mt->CpuTargetOptionsSetPackingPtr = &CpuTargetOptionsSetPacking; mt->CpuTargetOptionsSetUnifiedMemoryArchPtr = &CpuTargetOptionsSetUnifiedMemoryArch; + mt->CpuTargetOptionsGetMemoryAccessArchPtr = &CpuTargetOptionsGetMemoryAccessArch; mt->CpuTargetOptionsSetMemoryAccessArchPtr = &CpuTargetOptionsSetMemoryAccessArch; + mt->CpuTargetOptionsGetNocArchPtr = &CpuTargetOptionsGetNocArch; mt->CpuTargetOptionsSetNocArchPtr = &CpuTargetOptionsSetNocArch; + mt->CpuTargetOptionsGetHierarchyKindPtr = &CpuTargetOptionsGetHierarchyKind; + mt->CpuTargetOptionsSetHierarchyKindPtr = &CpuTargetOptionsSetHierarchyKind; mt->CpuTargetOptionsSetHierarchiesPtr = &CpuTargetOptionsSetHierarchies; mt->CpuTargetOptionsSetHierarchyNamesPtr = &CpuTargetOptionsSetHierarchyNames; mt->CpuTargetOptionsSetHierarchySizesPtr = &CpuTargetOptionsSetHierarchySizes; @@ -198,7 +206,7 @@ public static void Initialize(CApiMT* mt) mt->CpuTargetOptionsSetMemoryBandWidthsPtr = &CpuTargetOptionsSetMemoryBandWidths; mt->CpuTargetOptionsSetDistributedSchemePtr = &CpuTargetOptionsSetDistributedScheme; mt->CpuTargetOptionsSetCustomOpSchemePtr = &CpuTargetOptionsSetCustomOpScheme; - /* end the auto generated block by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ mt->RTValueFromHandlePtr = &RTValueFromHandle; mt->RTValueGetHandlePtr = &RTValueGetHandle; mt->StreamCreatePtr = &StreamCreate; @@ -780,7 +788,7 @@ private static void ShapeBucketOptionsSetFixVarMap(IntPtr shapeBucketOptionsHand Get(shapeBucketOptionsHandle).FixVarMap = fixVarMapStruct; } - /* This block is generated by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ + /* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 5:31:31 PM +08:00. */ [UnmanagedCallersOnly] private static IntPtr CpuTargetOptionsCreate() { @@ -805,18 +813,42 @@ private static void CpuTargetOptionsSetUnifiedMemoryArch(IntPtr handle, byte val Get(handle).UnifiedMemoryArch = value != 0; } + [UnmanagedCallersOnly] + private static MemoryAccessArchitecture CpuTargetOptionsGetMemoryAccessArch(IntPtr handle) + { + return Get(handle).MemoryAccessArch; + } + [UnmanagedCallersOnly] private static void CpuTargetOptionsSetMemoryAccessArch(IntPtr handle, MemoryAccessArchitecture value) { Get(handle).MemoryAccessArch = value; } + [UnmanagedCallersOnly] + private static NocArchitecture CpuTargetOptionsGetNocArch(IntPtr handle) + { + return Get(handle).NocArch; + } + [UnmanagedCallersOnly] private static void CpuTargetOptionsSetNocArch(IntPtr handle, NocArchitecture value) { Get(handle).NocArch = value; } + [UnmanagedCallersOnly] + private static HierarchyKind CpuTargetOptionsGetHierarchyKind(IntPtr handle) + { + return Get(handle).HierarchyKind; + } + + [UnmanagedCallersOnly] + private static void CpuTargetOptionsSetHierarchyKind(IntPtr handle, HierarchyKind value) + { + Get(handle).HierarchyKind = value; + } + [UnmanagedCallersOnly] private static void CpuTargetOptionsSetHierarchies(IntPtr handle, int* value, nuint shape0, nuint* shape1) { @@ -859,7 +891,7 @@ private static void CpuTargetOptionsSetCustomOpScheme(IntPtr handle, byte* value Get(handle).CustomOpScheme = ToString(value, length); } - /* end the auto generated block by tools/stackvm_gen/CApiGen at 10/25/2024 6:12:16 PM +08:00. */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ [UnmanagedCallersOnly] private static IntPtr RTValueFromHandle(IntPtr handle) diff --git a/src/Nncase.Core/DistributedType.cs b/src/Nncase.Core/DistributedType.cs index d3269dbf85..0d3415a6f6 100644 --- a/src/Nncase.Core/DistributedType.cs +++ b/src/Nncase.Core/DistributedType.cs @@ -14,6 +14,12 @@ namespace Nncase.IR; +public enum HierarchyKind : byte +{ + Parallel = 0, + SMT = 1, +} + [JsonDerivedType(typeof(SBPSplit), "S")] [JsonDerivedType(typeof(SBPPartialSum), "P")] [JsonDerivedType(typeof(SBPBroadCast), "B")] @@ -45,8 +51,8 @@ public sealed record SBPBroadCast : SBP public override string ToString() => "B"; } -// public sealed record Placement(Placement.DeviceKind Kind, IRArray Hierarchy, string Name) -public sealed record Placement(IRArray Hierarchy, string Name) +// public sealed record Placement(Placement.DeviceKind Kind, IRArray Hierarchy, string Name, HierarchyKind HierarchyKind) +public sealed record Placement(IRArray Hierarchy, string Name, HierarchyKind HierarchyKind = HierarchyKind.Parallel) { // public enum DeviceKind : uint // { diff --git a/src/Nncase.Evaluator/Math/MatMul.cs b/src/Nncase.Evaluator/Math/MatMul.cs index c76177b4b8..601e5205a7 100644 --- a/src/Nncase.Evaluator/Math/MatMul.cs +++ b/src/Nncase.Evaluator/Math/MatMul.cs @@ -50,6 +50,12 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b, if (ax == lk && bx == rk) { // split on k + if (a.Placement.HierarchyKind == HierarchyKind.SMT && i == a.Placement.Rank - 1) + { + // not split k on threads + return invalid; + } + ndsbp[i] = SBP.P; } else if ((ax == lk && bx != rk) || (ax != lk && bx == rk) || (ax == lm && bx == rn)) diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs index 01aebbdde4..b04a6d9729 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs @@ -31,9 +31,10 @@ public ConstrainEventArgs(CpModel model, IReadOnlyDictionary buffers) { foreach (var getItem in buffers.Keys.OfType().Where(e => e is Call { Target: IR.Tensors.GetItem } && buffers[e].MemInterval.Size == 0)) @@ -117,6 +120,7 @@ public void Schedule(IReadOnlyDictionary bufferMap) } var memStartVar = model.NewIntVar(0, upbound, $"{item.Name}_{item.Number}_y_start"); + model.AddModuloEquality(0, memStartVar, Alignment); var yInterval = model.NewFixedSizeIntervalVar(memStartVar, item.MemInterval.Stop, $"{item.Name}_{item.Number}_y"); noOverlap.AddRectangle(xInterval, yInterval); yStarts.Add(memStartVar); diff --git a/src/Nncase.Schedule/Schedule/GraphTiler.cs b/src/Nncase.Schedule/Schedule/GraphTiler.cs index 21a8435b7c..4777aaec9a 100644 --- a/src/Nncase.Schedule/Schedule/GraphTiler.cs +++ b/src/Nncase.Schedule/Schedule/GraphTiler.cs @@ -72,7 +72,7 @@ public Expr Tile(Expr preExpr, string moduleKind, string itemNumber, ICpuTargetO if (!_primFuncMemo.TryGetValue(primTree, out var wrapper)) { - var result = SolvePrimGraph(primTree, primBufferGraph, targetOptions); + var result = SolvePrimGraph(primTree, primBufferGraph, targetOptions, moduleKind); (inputBids, outputBids) = (result.Inputs, result.Outputs); result.ScheduleBuffers(); var bodyBuilder = T.Sequential(); @@ -135,7 +135,7 @@ public Expr Tile(Expr preExpr, string moduleKind, string itemNumber, ICpuTargetO // return new Call(None.Default); } - private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBufferGraph, ICpuTargetOptions targetOptions) + private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBufferGraph, ICpuTargetOptions targetOptions, string moduleKind) { int[] memoryCapacities = targetOptions.MemoryCapacities; int[] memoryBandWidths = targetOptions.MemoryBandWidths; @@ -549,7 +549,7 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer DumpAssgin(primTree, new TreeSolverPythonPrinter(sol, solver, opNodeMemo, tileNodeMemo, tileableNodeMemo, targetOptions), tileVarConstraints, eachLevelStoreBufferNumsConstrains, levelBufferSizes, levelDataReads, levelDataWrites, memoryCycles, computeCycles, totalCyclesVar); } - return new TreeSolveResult(primBufferGraph, sol.ObjectiveValue(), levelBufferSizesAssgin, levelBufferLifeness, opNodeMemoAssgin, tileNodeMemoAssgin, tileableNodeMemoAssgin, targetOptions); + return new TreeSolveResult(primBufferGraph, sol.ObjectiveValue(), levelBufferSizesAssgin, levelBufferLifeness, opNodeMemoAssgin, tileNodeMemoAssgin, tileableNodeMemoAssgin, targetOptions, moduleKind); } private void DumpAssgin(ITreeNode tree, TreeSolverPythonPrinter printer, Dictionary tileVarConstraints, Dictionary lowestStoreBufferNumsConstrains, Dictionary> levelBufferSizes, IntExpr[] levelDataReads, IntExpr[] levelDataWrites, IntExpr[] memoryCycles, IntExpr computeCycles, IntVar totalCycles) diff --git a/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs b/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs index 8b33268e51..5eb7c686b1 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs @@ -18,7 +18,7 @@ public sealed class TreeSolveResult : TreeSolverBase, ITreeNodeVisitor> _subViewMemo; - public TreeSolveResult(BufferGraph primBufferGraph, long objectiveValue, Dictionary> levelNodeBufferBoxs, Dictionary>> levelTreeBufferLifeness, Dictionary> primitiveBufferInfo, Dictionary> levelBufferInfos, Dictionary> domainInfos, ICpuTargetOptions targetOptions) + public TreeSolveResult(BufferGraph primBufferGraph, long objectiveValue, Dictionary> levelNodeBufferBoxs, Dictionary>> levelTreeBufferLifeness, Dictionary> primitiveBufferInfo, Dictionary> levelBufferInfos, Dictionary> domainInfos, ICpuTargetOptions targetOptions, string moduleKind) : base(null!, primitiveBufferInfo, levelBufferInfos, domainInfos, targetOptions) { PrimBufferGraph = primBufferGraph; @@ -26,6 +26,7 @@ public TreeSolveResult(BufferGraph primBufferGraph, long objectiveValue, Diction ObjectiveValue = objectiveValue; LevelBufferSizes = levelNodeBufferBoxs; LevelBufferLifeness = levelTreeBufferLifeness; + ModuleKind = moduleKind; LevelBufferOffsets = new(); PrimBufferMemo = new(); _subViewMemo = new(); @@ -47,6 +48,8 @@ public TreeSolveResult(BufferGraph primBufferGraph, long objectiveValue, Diction public Dictionary>> LevelBufferLifeness { get; } + public string ModuleKind { get; } + public Unit Visit(TileNode value, Context context) { var (parentbuilder, partentOffsets) = context; @@ -240,7 +243,13 @@ public void ScheduleBuffers() { xstarts.Add(solver.MakeIntConst(LevelBufferLifeness[level][key].Item1)); xsizes.Add(LevelBufferLifeness[level][key].Item2 - LevelBufferLifeness[level][key].Item1); - ystarts.Add(solver.MakeIntVar(0, TargetOptions.MemoryCapacities[level] - size)); + var ystart = solver.MakeIntVar(0, TargetOptions.MemoryCapacities[level] - size); + ystarts.Add(ystart); + if (ModuleKind == "xpu") + { + solver.Add(solver.MakeEquality(solver.MakeIntConst(0), solver.MakeModulo(ystart, 128))); + } + ysizes.Add(size); validKeys.Add(key); } diff --git a/tests/config.toml b/tests/config.toml index aaee0d9c71..f25539e50f 100644 --- a/tests/config.toml +++ b/tests/config.toml @@ -155,4 +155,5 @@ HierarchySizes = [268435456, 1048576] MemoryCapacities = [262144, 67108864] MemoryBandWidths = [64, 32] UnifiedMemoryArch = false -Packing = true \ No newline at end of file +Packing = true +HierarchyKind = "nncase.HierarchyKind.SMT" \ No newline at end of file diff --git a/tests/test_runner.py b/tests/test_runner.py index c467edfbfa..9ced2915d5 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -353,7 +353,14 @@ def get_target_options(self, target: str, values: dict) -> object: if target == 'cpu' or target == 'xpu': target_options = nncase.CpuTargetOptions() for k, v in values.items(): - exec(f"target_options.{k} = {e + v + e if isinstance(v, str) else v}") + is_enum = False + try: + exec(f"target_options.{k}") + is_enum = True + except: + pass + exec( + f"target_options.{k} = { e + v + e if isinstance(v, str) and not is_enum else v}") return target_options def get_compile_options(self, target, dump_dir): diff --git a/tools/stackvm_gen/CApiGen/Templates/CApi.razor b/tools/stackvm_gen/CApiGen/Templates/CApi.razor index 9edb3cbe7b..f5a6e2f29f 100644 --- a/tools/stackvm_gen/CApiGen/Templates/CApi.razor +++ b/tools/stackvm_gen/CApiGen/Templates/CApi.razor @@ -4,49 +4,55 @@ DisableEncoding = true; } -/* This block is generated by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ -public delegate* unmanaged @(Model.OptionsType.Name)CreatePtr; + /* This block is generated by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ + public delegate* unmanaged @(Model.OptionsType.Name)CreatePtr; @foreach (var info in Model.OptionInfos) { -@* @:public delegate* unmanaged @(Model.OptionsType.Name)Get@(info.PropertyName)Ptr; *@ -@:public delegate* unmanaged @(Model.OptionsType.Name)Set@(info.PropertyName)Ptr; +if (info.PropertyType.IsEnum) { +@: public delegate* unmanaged @(Model.OptionsType.Name)Get@(info.PropertyName)Ptr; } -/* end the auto generated block by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ +@: public delegate* unmanaged @(Model.OptionsType.Name)Set@(info.PropertyName)Ptr; +} + /* end the auto generated block by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ -/* This block is generated by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ -mt->@(Model.OptionsType.Name)CreatePtr = &@(Model.OptionsType.Name)Create; + /* This block is generated by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ + mt->@(Model.OptionsType.Name)CreatePtr = &@(Model.OptionsType.Name)Create; @foreach(var info in Model.OptionInfos) { -@* @:mt->@(Model.OptionsType.Name)Get@(info.PropertyName)Ptr = &@(Model.OptionsType.Name)Get@(info.PropertyName); *@ -@:mt->@(Model.OptionsType.Name)Set@(info.PropertyName)Ptr = &@(Model.OptionsType.Name)Set@(info.PropertyName); +if (info.PropertyType.IsEnum) { +@: mt->@(Model.OptionsType.Name)Get@(info.PropertyName)Ptr = &@(Model.OptionsType.Name)Get@(info.PropertyName); +} +@: mt->@(Model.OptionsType.Name)Set@(info.PropertyName)Ptr = &@(Model.OptionsType.Name)Set@(info.PropertyName); } -/* end the auto generated block by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ -/* This block is generated by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ -[UnmanagedCallersOnly] -private static IntPtr @(Model.OptionsType.Name)Create() -{ - return GCHandle.ToIntPtr(GCHandle.Alloc(new @(Model.OptionsType.Name)())); -} + /* This block is generated by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ + [UnmanagedCallersOnly] + private static IntPtr @(Model.OptionsType.Name)Create() + { + return GCHandle.ToIntPtr(GCHandle.Alloc(new @(Model.OptionsType.Name)())); + } @foreach (var info in Model.OptionInfos) { -@* @: -@:[UnmanagedCallersOnly] -@:private static void @(Model.OptionsType.Name)Get@(info.PropertyName)(IntPtr handle) -@:{ -@: return Get<@(Model.OptionsType.Name)>(handle).@(info.PropertyName); -@:} *@ +if (info.PropertyType.IsEnum) { @: -@:[UnmanagedCallersOnly] -@:private static void @(Model.OptionsType.Name)Set@(info.PropertyName)(IntPtr handle, @(info.RenderSignature(SignMode.Type | SignMode.Param, LangMode.UnCS))) -@:{ -@: Get<@(Model.OptionsType.Name)>(handle).@(info.PropertyName) = @(info.RenderUnCSAssginValue()); -@:} +@: [UnmanagedCallersOnly] +@: private static @(info.RenderSignature(SignMode.Type, LangMode.UnCS)) @(Model.OptionsType.Name)Get@(info.PropertyName)(IntPtr handle) +@: { +@: return Get<@(Model.OptionsType.Name)>(handle).@(info.PropertyName); +@: } } - -/* end the auto generated block by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ \ No newline at end of file +@: +@: [UnmanagedCallersOnly] +@: private static void @(Model.OptionsType.Name)Set@(info.PropertyName)(IntPtr handle, @(info.RenderSignature(SignMode.Type | SignMode.Param, LangMode.UnCS))) +@: { +@: Get<@(Model.OptionsType.Name)>(handle).@(info.PropertyName) = @(info.RenderUnCSAssginValue()); +@: } + } + + /* end the auto generated block by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ \ No newline at end of file diff --git a/tools/stackvm_gen/CApiGen/Templates/Compiler.razor b/tools/stackvm_gen/CApiGen/Templates/Compiler.razor index fd0990b486..20eceb820a 100644 --- a/tools/stackvm_gen/CApiGen/Templates/Compiler.razor +++ b/tools/stackvm_gen/CApiGen/Templates/Compiler.razor @@ -13,11 +13,11 @@ if (info.PropertyType.IsEnum) { var enumName = info.PropertyType.Name.ToSnake(); -@:typedef enum { +@:enum @(enumName)_t : uint8_t { foreach(var (name, value) in info.PropertyType.RenderEnumFields(LangMode.UnCPP)){ @: @(enumName)_@(name.ToSnake()) = @(value), } -@:} @(enumName)_t; +@:}; } } /* end the auto generated block by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ @@ -29,6 +29,9 @@ foreach(var (name, value) in info.PropertyType.RenderEnumFields(LangMode.UnCPP)) clr_object_handle_t (*@(optionsName)_create)(); @foreach (var info in Model.OptionInfos) { + if (info.PropertyType.IsEnum) { + @:@(info.RenderSignature(SignMode.Type, LangMode.UnCPP)) (*@(optionsName)_get_@(info.PropertyName.ToSnake()))(clr_object_handle_t handle); + } @:void (*@(optionsName)_set_@(info.PropertyName.ToSnake()))(clr_object_handle_t handle, @(info.RenderSignature(SignMode.Type | SignMode.Param, LangMode.UnCPP))); } /* end the auto generated block by tools/stackvm_gen/CApiGen at @Raw(DateTimeOffset.Now.ToString()). */ @@ -46,6 +49,12 @@ class @(optionsName) : public clr_object_base { } @foreach( var info in Model.OptionInfos) { + if (info.PropertyType.IsEnum) { +@: +@: @(info.RenderSignature(SignMode.Type, LangMode.Pyb)) @(info.PropertyName.ToSnake())() { +@: return (@(info.PropertyType.Name.ToSnake())_t)nncase_clr_api()->@(optionsName)_get_@(info.PropertyName.ToSnake())(obj_.get()); +@: } + } @: @: void @(info.PropertyName.ToSnake())(@(info.RenderSignature(SignMode.Type | SignMode.Param, LangMode.Pyb))) { if (info.PropertyType.IsNestedArrayType(out var stacks) && stacks.Count > 1) diff --git a/tools/stackvm_gen/CApiGen/Templates/PyBind.razor b/tools/stackvm_gen/CApiGen/Templates/PyBind.razor index 641791b3af..71d3de7775 100644 --- a/tools/stackvm_gen/CApiGen/Templates/PyBind.razor +++ b/tools/stackvm_gen/CApiGen/Templates/PyBind.razor @@ -29,7 +29,11 @@ var info = Model.OptionInfos[i]; @:.def_property( @: "@(info.PropertyName)", + if (info.PropertyType.IsEnum) { + @: py::overload_cast<>(&@(Model.OptionsType.Name.ToSnake())::@(info.PropertyName.ToSnake())), + } else { @: []() {}, + } @: py::overload_cast<@(info.PropertyType.RenderType(LangMode.Pyb))>(&@(Model.OptionsType.Name.ToSnake())::@(info.PropertyName.ToSnake()))) @if (i == Model.OptionInfos.Count - 1) { @(";") } } diff --git a/tools/stackvm_gen/CApiGen/packages.lock.json b/tools/stackvm_gen/CApiGen/packages.lock.json index 2869d49843..bcc9281162 100644 --- a/tools/stackvm_gen/CApiGen/packages.lock.json +++ b/tools/stackvm_gen/CApiGen/packages.lock.json @@ -295,6 +295,7 @@ "CommunityToolkit.HighPerformance": "[8.2.2, )", "DryIoc.dll": "[5.4.3, )", "GiGraph.Dot": "[3.0.1, )", + "Google.OrTools": "[9.4.1874, )", "Microsoft.Extensions.Hosting.Abstractions": "[8.0.0, )", "Microsoft.Extensions.Logging.Abstractions": "[8.0.1, )", "Microsoft.Extensions.Options": "[8.0.2, )", @@ -369,7 +370,6 @@ "nncase.schedule": { "type": "Project", "dependencies": { - "Google.OrTools": "[9.4.1874, )", "Nncase.Core": "[1.0.0, )", "Nncase.Passes": "[1.0.0, )" }