From 4155c66ccc007e4230030063f3e96eeea8b16232 Mon Sep 17 00:00:00 2001 From: sunnycase Date: Sat, 28 Dec 2024 09:28:15 +0000 Subject: [PATCH] Fix shape bucket --- .../Rules/ShapeBucket/ShapeBucket.cs | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs index 5994aa8d3..2a4b48397 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs @@ -613,6 +613,7 @@ public FusionBucketContext(Call outerCall, BucketFusion fusion, ShapeBucketOptio Arguments = OuterCall.Arguments.ToArray(); Parameters = Fusion.Parameters.ToArray(); FixedShapeCache = new(); + ShapeInfos = shapeInfos; SliceShape = ComputeSliceShape(shapeInfos, options.RangeInfo.First().Value.Max); _index = index; } @@ -638,6 +639,8 @@ public FusionBucketContext(Call outerCall, BucketFusion fusion, ShapeBucketOptio // segIndex -> fixed shape list public Dictionary FixedShapeCache { get; } + public FusionShapeData[] ShapeInfos { get; } + public Expr FusionBody => Fusion.Body; public Dictionary DimVarValue(int i) => @@ -868,19 +871,13 @@ public FusionBucket(Dictionary list) public static Expr PreProcess(FusionBucketContext context, Var param, Dictionary inputInfo, Dictionary varValues, Dictionary fusionInputData, int segIndex, int inputIndex) { // Console.WriteLine($"seg index{segIndex}"); - if (context.FixedShapeCache.TryGetValue(segIndex, out var cachedFixedShape)) - { - // replace index by value - var shape = cachedFixedShape[inputIndex]; - if ((param.CheckedShape.IsFixed && shape.SequenceEqual(param.CheckedShape.ToValueArray())) || param.CheckedShape.IsScalar) - { - return param; - } - - return new Call(new BucketPad(), param, shape); - } - - throw new InvalidDataException("Shape Cache not found"); + var shapeBucketOptions = CompileSessionScope.Current!.CompileOptions.ShapeBucketOptions; + var varInfo = shapeBucketOptions.RangeInfo.First().Value; + var segments = ShapeBucketHelper.ComputeSegmentList(shapeBucketOptions.SegmentsCount, varInfo.Min, varInfo.Max); + var segValue = segments[segIndex] - varInfo.Min; + var inputShapeInfo = context.ShapeInfos[segValue].InputShapes[inputIndex]; + var shape = inputShapeInfo.AsTensor().Cast(); + return new Call(new BucketPad(), param, shape); } public static (Dictionary MinDict, Dictionary MaxDict) GetBoundDict(