Skip to content

Commit

Permalink
Fix shape bucket
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Dec 28, 2024
1 parent e473a9c commit 4155c66
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ public FusionBucketContext(Call outerCall, BucketFusion fusion, ShapeBucketOptio
Arguments = OuterCall.Arguments.ToArray();
Parameters = Fusion.Parameters.ToArray();
FixedShapeCache = new();
ShapeInfos = shapeInfos;
SliceShape = ComputeSliceShape(shapeInfos, options.RangeInfo.First().Value.Max);
_index = index;
}
Expand All @@ -638,6 +639,8 @@ public FusionBucketContext(Call outerCall, BucketFusion fusion, ShapeBucketOptio
// segIndex -> fixed shape list
public Dictionary<int, int[][]> FixedShapeCache { get; }

public FusionShapeData[] ShapeInfos { get; }

public Expr FusionBody => Fusion.Body;

public Dictionary<Var, IValue> DimVarValue(int i) =>
Expand Down Expand Up @@ -868,19 +871,13 @@ public FusionBucket(Dictionary<BucketFusion, FusionShapeData[]> list)
public static Expr PreProcess(FusionBucketContext context, Var param, Dictionary<Var, Expr[]> inputInfo, Dictionary<Var, IValue> varValues, Dictionary<Var, Expr[]> fusionInputData, int segIndex, int inputIndex)
{
// Console.WriteLine($"seg index{segIndex}");
if (context.FixedShapeCache.TryGetValue(segIndex, out var cachedFixedShape))
{
// replace index by value
var shape = cachedFixedShape[inputIndex];
if ((param.CheckedShape.IsFixed && shape.SequenceEqual(param.CheckedShape.ToValueArray())) || param.CheckedShape.IsScalar)
{
return param;
}

return new Call(new BucketPad(), param, shape);
}

throw new InvalidDataException("Shape Cache not found");
var shapeBucketOptions = CompileSessionScope.Current!.CompileOptions.ShapeBucketOptions;
var varInfo = shapeBucketOptions.RangeInfo.First().Value;
var segments = ShapeBucketHelper.ComputeSegmentList(shapeBucketOptions.SegmentsCount, varInfo.Min, varInfo.Max);
var segValue = segments[segIndex] - varInfo.Min;
var inputShapeInfo = context.ShapeInfos[segValue].InputShapes[inputIndex];
var shape = inputShapeInfo.AsTensor().Cast<long>();
return new Call(new BucketPad(), param, shape);
}

public static (Dictionary<Var, IValue> MinDict, Dictionary<Var, IValue> MaxDict) GetBoundDict(
Expand Down

0 comments on commit 4155c66

Please sign in to comment.