diff --git a/src/Native/src/kernels/stackvm/reference/pad.cpp b/src/Native/src/kernels/stackvm/reference/pad.cpp index d0fc488b38..ec2d0a6e13 100644 --- a/src/Native/src/kernels/stackvm/reference/pad.cpp +++ b/src/Native/src/kernels/stackvm/reference/pad.cpp @@ -116,31 +116,35 @@ template void set_data_v(T *dst, int len, T value) { } template -void pad_data2(T *in, T *out, int cl, int dl, int hl, int wl, int ch, int dh, - int hh, int wh, T value) { +void pad_data2(T *in, T *out, int nl, int cl, int dl, int hl, int wl, int nh, + int ch, int dh, int hh, int wh, T value) { (void)ch; int blocks_in = wl; int blocks_out = wh; - for (int c = 0; c < cl; ++c) { - for (int d = 0; d < dl; ++d) { - for (int h = 0; h < hl; ++h) { - int index_out = h + d * hh + c * dh * hh; - int index_in = c * hl * dl + d * hl + h; - T *inptr = in + index_in * blocks_in; - T *outptr = out + index_out * blocks_out; - copy_data_v(inptr, outptr, blocks_in, blocks_out, value); + for (int n = 0; n < nl; ++n) { + for (int c = 0; c < cl; ++c) { + for (int d = 0; d < dl; ++d) { + for (int h = 0; h < hl; ++h) { + int index_out = h + d * hh + c * dh * hh + n * ch * dh * hh; + int index_in = n * hl * dl * cl + c * hl * dl + d * hl + h; + T *inptr = in + index_in * blocks_in; + T *outptr = out + index_out * blocks_out; + copy_data_v(inptr, outptr, blocks_in, blocks_out, value); + } } } } - for (int c = 0; c < ch; ++c) { - for (int d = 0; d < dh; ++d) { - for (int h = 0; h < hh; ++h) { - int index = h + d * hh + c * dh * hh; - T *outptr = out + index * blocks_out; - if (h >= hl || d >= dl || c >= cl) { - set_data_v(outptr, blocks_out, value); + for (int n = 0; n < nh; ++n) { + for (int c = 0; c < ch; ++c) { + for (int d = 0; d < dh; ++d) { + for (int h = 0; h < hh; ++h) { + int index = h + d * hh + c * dh * hh + n * ch * dh * hh; + T *outptr = out + index * blocks_out; + if (h >= hl || d >= dl || c >= cl) { + set_data_v(outptr, blocks_out, value); + } } } } @@ -150,48 +154,67 @@ void pad_data2(T *in, T *out, int cl, int dl, int hl, int wl, int ch, int dh, template void padding_impl_opt(T *in, T *out, gsl::span in_shape, gsl::span out_shape, T value) { - int cl, dl, hl, wl; - int ch, dh, hh, wh; + int nl, cl, dl, hl, wl; + int nh, ch, dh, hh, wh; if (in_shape.size() == 3 || (in_shape.size() == 4 && in_shape[in_shape.size() - 1] == 1)) { + nl = 1; cl = 1; dl = in_shape[0]; hl = in_shape[1]; wl = in_shape[2]; + nh = 1; ch = 1; dh = out_shape[0]; hh = out_shape[1]; wh = out_shape[2]; + } else if (in_shape.size() == 5) { + nl = in_shape[0]; + cl = in_shape[1]; + dl = in_shape[2]; + hl = in_shape[3]; + wl = in_shape[4]; + nh = out_shape[0]; + ch = out_shape[1]; + dh = out_shape[2]; + hh = out_shape[3]; + wh = out_shape[4]; } else if (in_shape.size() == 4) { + nl = 1; cl = in_shape[0]; dl = in_shape[1]; hl = in_shape[2]; wl = in_shape[3]; + nh = 1; ch = out_shape[0]; dh = out_shape[1]; hh = out_shape[2]; wh = out_shape[3]; } else if (in_shape.size() == 2) { + nl = 1; cl = 1; dl = 1; hl = in_shape[0]; wl = in_shape[1]; + nh = 1; ch = 1; dh = 1; hh = out_shape[0]; wh = out_shape[1]; } else { + nl = 1; cl = 1; dl = 1; hl = 1; wl = in_shape[0]; + nh = 1; ch = 1; dh = 1; hh = 1; wh = out_shape[1]; } - pad_data2(in, out, cl, dl, hl, wl, ch, dh, hh, wh, value); + pad_data2(in, out, nl, cl, dl, hl, wl, nh, ch, dh, hh, wh, value); } template @@ -239,9 +262,10 @@ result nncase::kernels::stackvm::reference::pad( if (std::all_of(paddings.begin(), paddings.end(), [](const padding &p) { return p.interior == 0; })) { auto out_shape = get_padded_shape(in_shape, paddings); + auto can_opt = out_shape.size() < 6; switch (unit) { case 1: - if (padding_before_is_zero) { + if (padding_before_is_zero && can_opt) { padding_impl_opt((int8_t *)input, (int8_t *)output, in_shape, out_shape, *(int8_t *)pad_value); } else { @@ -252,7 +276,7 @@ result nncase::kernels::stackvm::reference::pad( } break; case 2: - if (padding_before_is_zero) { + if (padding_before_is_zero && can_opt) { padding_impl_opt((int16_t *)input, (int16_t *)output, in_shape, out_shape, *(int16_t *)pad_value); } else { @@ -263,7 +287,7 @@ result nncase::kernels::stackvm::reference::pad( } break; case 4: - if (padding_before_is_zero) { + if (padding_before_is_zero && can_opt) { padding_impl_opt((int32_t *)input, (int32_t *)output, in_shape, out_shape, *(int32_t *)pad_value); } else { @@ -274,7 +298,7 @@ result nncase::kernels::stackvm::reference::pad( } break; case 8: - if (padding_before_is_zero) { + if (padding_before_is_zero && can_opt) { padding_impl_opt((int64_t *)input, (int64_t *)output, in_shape, out_shape, *(int64_t *)pad_value); } else { diff --git a/src/Nncase.Core/IR/PrimFunctionWrapper.cs b/src/Nncase.Core/IR/PrimFunctionWrapper.cs index d7f0964ad0..ce517e1f77 100644 --- a/src/Nncase.Core/IR/PrimFunctionWrapper.cs +++ b/src/Nncase.Core/IR/PrimFunctionWrapper.cs @@ -37,7 +37,7 @@ public PrimFunctionWrapper(string name, PrimFunction target, int parametersCount /// Target. /// Arguments count. public PrimFunctionWrapper(PrimFunction target, int parametersCount) - : this($"func_{_globalFuncIndex++}", target, parametersCount) + : this($"primfunc_wrapper_{_globalFuncIndex++}", target, parametersCount) { } diff --git a/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs b/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs index 2d6486374f..7139dcb3a1 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs @@ -22,56 +22,21 @@ namespace Nncase.Passes.Rules.ShapeBucket; public class FusionShapeData { - public FusionShapeData(IValue outshape, IValue[] inputShapes) + public FusionShapeData(IValue outshape, IValue[] inputShapes, IValue?[] inputValues, bool[] inputFromShapes) { Outshape = outshape; InputShapes = inputShapes; + InputValues = inputValues; + InputFromShapes = inputFromShapes; } public IValue Outshape { get; } public IValue[] InputShapes { get; } -} - -public class FusionShapeUpdater : ExprVisitor -{ - private readonly Dictionary _memo; - - public FusionShapeUpdater(Dictionary memo) - { - _memo = memo; - } - public Dictionary FusionShape { get; } = new(); + public IValue?[] InputValues { get; } - protected override Expr DefaultVisitLeaf(Expr expr) => expr; - - protected override Expr VisitLeafCall(Call expr) - { - if (expr.Target is BucketFusion f) - { - var argShape = expr.Arguments.ToArray().Select(arg => - { - var exp = arg is Marker m ? m.Target : arg; - return GetShape(_memo[exp]); - }).ToArray(); - var shape = GetShape(_memo[expr]); - FusionShape[f] = new FusionShapeData(shape, argShape); - } - - return expr; - } - - private IValue GetShape(IValue value) - { - var shapes = value.AsTensors().Select(x => x.Shape.ToValueArray()).ToArray(); - if (shapes.Length == 1) - { - return Value.FromTensor(shapes[0]); - } - - return new TupleValue(shapes.Select(x => Value.FromTensor(x)).ToArray()); - } + public bool[] InputFromShapes { get; } } public class FusionShapeUpdater2 : ExprVisitor @@ -91,13 +56,14 @@ protected override Expr VisitLeafCall(Call expr) { if (expr.Target is BucketFusion f) { - var argShape = expr.Arguments.ToArray().Select(arg => + var argData = expr.Arguments.ToArray().Select(arg => { var exp = arg is Marker m ? m.Target : arg; - return GetValueOfShape(_memo[exp].IRType!); + var valueOrShape = _memo[exp]; + return (Shape: GetValueOfShape(valueOrShape.IRType!), Value: valueOrShape.Value, FromShape: valueOrShape.FromShape); }).ToArray(); var shape = GetValueOfShape(_memo[expr].IRType!); - FusionShape[f] = new FusionShapeData(shape, argShape); + FusionShape[f] = new FusionShapeData(shape, argData.Select(x => x.Shape).ToArray(), argData.Select(x => x.Value).ToArray(), argData.Select(x => x.FromShape).ToArray()); } return expr; @@ -239,13 +205,13 @@ protected override Task RunCoreAsync(BaseFunction main, RunPassCon var f = new FusionShapeUpdater(ConcatDictionary(memo, exprValues)); #else var input = MakeDummyInputType(varMap, varValues); - var eval = new PartialShapeEvaluator(input.ToDictionary(p => p.Key, p => new ValueOrShape(p.Value, null)), varValues); + var eval = new PartialShapeEvaluator(input.ToDictionary(p => p.Key, p => new ValueOrShape(p.Value, null, false)), varValues); eval.Visit(body); var memo = eval.ExprMemo; foreach (var (k, v) in exprValues) { var x = v.AsTensor(); - memo.Add(k, new(new TensorType(x.ElementType, x.Shape), v)); + memo.Add(k, new(new TensorType(x.ElementType, x.Shape), v, true)); } var f = new FusionShapeUpdater2(memo); @@ -270,7 +236,7 @@ public record ValueOrShape { private IValue? _concreteValue; - public ValueOrShape(IRType? irType, IValue? value) + public ValueOrShape(IRType? irType, IValue? value, bool fromShape) { if (irType is InvalidType) { @@ -280,6 +246,7 @@ public ValueOrShape(IRType? irType, IValue? value) IRType = irType; Value = value; _concreteValue = null; + FromShape = fromShape; } public IRType? IRType { get; } @@ -288,6 +255,8 @@ public ValueOrShape(IRType? irType, IValue? value) public bool HasValue => Value != null; + public bool FromShape { get; } + public IValue Concrete() { if (_concreteValue != null) @@ -325,9 +294,9 @@ public PartialShapeEvaluator(Dictionary inputDict, Dictionary protected override ValueOrShape VisitLeafMarker(Marker expr) => Visit(expr.Target); - protected override ValueOrShape VisitLeafBaseFunction(BaseFunction expr) => new(expr.CheckedType, null); + protected override ValueOrShape VisitLeafBaseFunction(BaseFunction expr) => new(expr.CheckedType, null, false); - protected override ValueOrShape VisitLeafOp(Op expr) => new(expr.CheckedType, null); + protected override ValueOrShape VisitLeafOp(Op expr) => new(expr.CheckedType, null, false); protected override ValueOrShape VisitLeafVar(Var expr) { @@ -337,7 +306,7 @@ protected override ValueOrShape VisitLeafVar(Var expr) } else if (DimDict.TryGetValue(expr, out var dimValue) && dimValue is TensorValue dimtv) { - return new(dimtv.Type, dimtv); + return new(dimtv.Type, dimtv, true); } else { @@ -349,8 +318,11 @@ protected override ValueOrShape VisitLeafVar(Var expr) protected override ValueOrShape VisitLeafTuple(IR.Tuple expr) { - var value = Value.FromTensors(expr.Fields.AsValueEnumerable().Select(Visit).Select(vs => vs.Concrete().AsTensor()).ToArray()); - return new(value.Type, value); + var valueOrShapes = expr.Fields.AsValueEnumerable().Select(Visit).ToArray(); + var value = Value.FromTensors(valueOrShapes.Select(vs => vs.Concrete().AsTensor()).ToArray()); + + // FIX ME: TupleType's from shape is not correct + return new(value.Type, value, valueOrShapes.Any(x => x.FromShape)); } protected override ValueOrShape VisitLeafCall(Call expr) @@ -363,7 +335,7 @@ protected override ValueOrShape VisitLeafCall(Call expr) { var shapeArr = ((TensorType)args[0].IRType!).Shape.Select(x => (long)x.FixedValue).ToArray(); var value = Value.FromTensor(Tensor.From(shapeArr)); - result = new(value.Type, value); + result = new(value.Type, value, true); } break; @@ -371,18 +343,19 @@ protected override ValueOrShape VisitLeafCall(Call expr) { if (args.All(x => x is { HasValue: true })) { + var fromShape = args.All(x => x.FromShape); var tmpCall = new Call(op, args.Select(a => Const.FromValue(a.Concrete())).ToArray()); var ctx = new EvaluateContext(args) { CurrentCall = tmpCall, }; var value = CompilerServices.EvaluateOp(op, ctx); - result = new(value.Type, value); + result = new(value.Type, value, fromShape); } else { var ctx = new TypeInferenceContext(args); - result = new(CompilerServices.InferenceOp(op, ctx, new()), null); + result = new(CompilerServices.InferenceOp(op, ctx, new()), null, false); } } @@ -408,7 +381,7 @@ protected override ValueOrShape VisitLeafCall(Call expr) return result; } - protected override ValueOrShape VisitLeafConst(Const expr) => new(expr.CheckedType, Value.FromConst(expr)); + protected override ValueOrShape VisitLeafConst(Const expr) => new(expr.CheckedType, Value.FromConst(expr), true); } internal sealed class EvaluateContext : IEvaluateContext diff --git a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs index 20825b1d7c..1ed3de41ca 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs @@ -943,15 +943,22 @@ public static Expr PreProcess(FusionBucketContext context, Var param, Dictionary var varInfo = shapeBucketOptions.RangeInfo.First().Value; var segments = ShapeBucketHelper.ComputeSegmentList(shapeBucketOptions.SegmentsCount, varInfo.Min, varInfo.Max); var segValue = segments[segIndex] - varInfo.Min; - var inputShapeInfo = context.ShapeInfos[segValue].InputShapes[inputIndex]; - var shape = inputShapeInfo.AsTensor().Cast(); - if (param.CheckedShape.IsFixed) + var inputShapeInfo = context.ShapeInfos[segValue]; + if (inputShapeInfo.InputFromShapes[inputIndex]) { - return param; + return inputShapeInfo.InputValues[inputIndex]!.AsTensor(); } else { - return new Call(new BucketPad(), param, shape); + var shape = inputShapeInfo.InputShapes[inputIndex].AsTensor().Cast(); + if (param.CheckedShape.IsFixed) + { + return param; + } + else + { + return new Call(new BucketPad(), param, shape); + } } } @@ -1033,7 +1040,25 @@ public static Function MakeSplitEntry(FusionBucketContext context, Dictionary varInfo, Var[] newVars, int segIndex) @@ -1255,27 +1280,7 @@ private static Expr MakeSlice(FusionBucketContext context, Expr call, Expr origi private static Expr MakeSliceForTensor(Expr sliceShape, Expr call, FusionBucketContext context) { - var slice = MakeSliceImpl(call, sliceShape); - var simplifySlice = CompilerServices.Rewrite( - slice, - new IRewriteRule[] - { - new FoldStackGetItem(), - new FoldShapeOf(), - new FoldTwoReshapes(), - new FoldTwoCasts(), - new FoldTwoSlices(), - new FoldNopBinary(), - new FoldNopCast(), - new Neutral.FoldConstCall(), - new FoldNopReshape(), - new FoldNopSlice(), - new FoldIf(), - new FoldBroadcastShape(), - }, - new()); - - return simplifySlice; + return MakeSliceImpl(call, sliceShape); } private static bool IsFixed(int totalCount, int[][] minFixedShapeList, int[][] maxFixedShapeList) @@ -1550,22 +1555,6 @@ protected override Task RunCoreAsync(BaseFunction input, RunPassCo return Task.FromResult((BaseFunction)main.With(body: newBody)); } - private static FusionShapeData[] MakeShapeData((Var Key, int Value)[][] list, ShapeBucketOptions options) => - list.Select(seg => - { - var varValues = seg.ToDictionary(pair => pair.Key, pair => (IValue)Value.FromTensor(pair.Value)); - var inShape = options.VarMap.Select(pair => - { - var shapeExpr = pair.Key.CheckedShape.IsScalar - ? (Expr)Array.Empty() - : Stack(new IR.Tuple(pair.Value.Select(x => Cast(x, DataTypes.Int64)).ToArray()), 0); - - var shape = shapeExpr.Evaluate(varValues).AsTensor(); - return shape; - }).ToArray(); - return new FusionShapeData(Value.None, inShape.Select(Value.FromTensor).ToArray()); - }).ToArray(); - private static (Var Key, int Value)[][] InputConfList(Dictionary dimVarValues) => Enumerable.Range(0, dimVarValues.First().Value.Length).Select(i => { diff --git a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs index aa953733da..fe2cfd31ff 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs @@ -31,8 +31,8 @@ public static class CallValidator typeof(Transpose).TypeHandle, typeof(Pad).TypeHandle, - // typeof(Unsqueeze).TypeHandle, - // typeof(Squeeze).TypeHandle, + typeof(Unsqueeze).TypeHandle, + typeof(Squeeze).TypeHandle, typeof(Unary).TypeHandle, }; @@ -47,7 +47,7 @@ public static class CallValidator typeof(Cast).TypeHandle, - // typeof(Reshape).TypeHandle, + typeof(Reshape).TypeHandle, typeof(Expand).TypeHandle, typeof(ConstantOfShape).TypeHandle,