From 625bfa1f8f5986f421426b86a13186a434413b87 Mon Sep 17 00:00:00 2001 From: sunnycase Date: Sat, 28 Dec 2024 14:26:32 +0000 Subject: [PATCH] Disable some buckets --- src/Nncase.Passes/Rules/Neutral/LiftCEInIf.cs | 69 +++++++++++++++++++ .../Rules/ShapeBucket/ShapeBucketHelper.cs | 38 +++++----- 2 files changed, 88 insertions(+), 19 deletions(-) create mode 100644 src/Nncase.Passes/Rules/Neutral/LiftCEInIf.cs diff --git a/src/Nncase.Passes/Rules/Neutral/LiftCEInIf.cs b/src/Nncase.Passes/Rules/Neutral/LiftCEInIf.cs new file mode 100644 index 000000000..586480167 --- /dev/null +++ b/src/Nncase.Passes/Rules/Neutral/LiftCEInIf.cs @@ -0,0 +1,69 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reactive; +using Nncase.IR; +using Nncase.IR.Math; +using Nncase.IR.NN; +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules.Neutral; + +[RuleGenerator] +public sealed partial class LiftCEInIf : RewriteRule +{ + public override Pattern Pattern => IsWildcard("expr", expr => expr is If); + + private Expr? GetReplace(If expr) + { + var parameters = expr.ParamList.ToList(); + var thenExprs = LiftCollector.Collect(expr.Then).ToHashSet(ReferenceEqualityComparer.Instance); + var elseExprs = LiftCollector.Collect(expr.Else).ToHashSet(ReferenceEqualityComparer.Instance); + var commonExprs = thenExprs.Intersect(elseExprs).Where(x => !(x is Var or If)).Except(parameters).Cast().ToArray(); + if (commonExprs.Any()) + { + parameters.AddRange(commonExprs); + return new If(expr.Condition, expr.Then, expr.Else, parameters.ToArray()); + } + else + { + return null; + } + } + + public sealed class LiftCollector : ExprWalker> + { + private LiftCollector() + { + } + + public static IReadOnlyList Collect(Expr expr) + { + var exprs = new List(); + new LiftCollector().Visit(expr, exprs); + return exprs; + } + + // protected override Unit VisitIf(If expr, List context) + // { + // foreach (var param in expr.ParamList) + // { + // Visit(param, context); + // } + + // Visit(expr.Condition, context); + // return default; + // } + + protected override Unit DefaultVisitLeaf(Expr expr, List context) + { + context.Add(expr); + return default; + } + } +} diff --git a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs index 1d39724ce..3bbe83561 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs @@ -28,34 +28,34 @@ public static class CallValidator typeof(Conv2D).TypeHandle, typeof(Conv2DTranspose).TypeHandle, typeof(MatMul).TypeHandle, - typeof(Transpose).TypeHandle, - typeof(Pad).TypeHandle, - typeof(Unsqueeze).TypeHandle, - typeof(Squeeze).TypeHandle, - typeof(Unary).TypeHandle, + //typeof(Transpose).TypeHandle, + //typeof(Pad).TypeHandle, + //typeof(Unsqueeze).TypeHandle, + //typeof(Squeeze).TypeHandle, + //typeof(Unary).TypeHandle, }; private static readonly HashSet MaybeDynamic = new() { - typeof(Concat).TypeHandle, - typeof(Stack).TypeHandle, - typeof(Binary).TypeHandle, - typeof(Slice).TypeHandle, - typeof(Gather).TypeHandle, - typeof(ShapeOf).TypeHandle, + // typeof(Concat).TypeHandle, + // typeof(Stack).TypeHandle, + // typeof(Binary).TypeHandle, + // typeof(Slice).TypeHandle, + // typeof(Gather).TypeHandle, + // typeof(ShapeOf).TypeHandle, typeof(Cast).TypeHandle, - typeof(Reshape).TypeHandle, - typeof(Expand).TypeHandle, - typeof(ConstantOfShape).TypeHandle, + // typeof(Reshape).TypeHandle, + // typeof(Expand).TypeHandle, + // typeof(ConstantOfShape).TypeHandle, // typeof(Where).TypeHandle, - typeof(Compare).TypeHandle, - typeof(Reduce).TypeHandle, - typeof(Clamp).TypeHandle, - typeof(Tile).TypeHandle, - typeof(CumSum).TypeHandle, + // typeof(Compare).TypeHandle, + // typeof(Reduce).TypeHandle, + // typeof(Clamp).TypeHandle, + // typeof(Tile).TypeHandle, + // typeof(CumSum).TypeHandle, }; public static bool IsMaybeDynamic(Expr target) => MaybeDynamic.Contains(target.GetType().TypeHandle);