-
Notifications
You must be signed in to change notification settings - Fork 185
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Copyright (c) Canaan Inc. All rights reserved. | ||
// Licensed under the Apache license. See LICENSE file in the project root for full license information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Reactive; | ||
using Nncase.IR; | ||
using Nncase.IR.Math; | ||
using Nncase.IR.NN; | ||
using Nncase.PatternMatch; | ||
using static Nncase.IR.TypePatternUtility; | ||
using static Nncase.PatternMatch.Utility; | ||
|
||
namespace Nncase.Passes.Rules.Neutral; | ||
|
||
[RuleGenerator] | ||
public sealed partial class LiftCEInIf : RewriteRule<Pattern> | ||
{ | ||
public override Pattern Pattern => IsWildcard("expr", expr => expr is If); | ||
|
||
private Expr? GetReplace(If expr) | ||
{ | ||
var parameters = expr.ParamList.ToList(); | ||
var thenExprs = LiftCollector.Collect(expr.Then).ToHashSet(ReferenceEqualityComparer.Instance); | ||
var elseExprs = LiftCollector.Collect(expr.Else).ToHashSet(ReferenceEqualityComparer.Instance); | ||
var commonExprs = thenExprs.Intersect(elseExprs).Where(x => !(x is Var or If)).Except(parameters).Cast<Expr>().ToArray(); | ||
if (commonExprs.Any()) | ||
{ | ||
parameters.AddRange(commonExprs); | ||
return new If(expr.Condition, expr.Then, expr.Else, parameters.ToArray()); | ||
} | ||
else | ||
{ | ||
return null; | ||
} | ||
} | ||
|
||
public sealed class LiftCollector : ExprWalker<List<Expr>> | ||
{ | ||
private LiftCollector() | ||
{ | ||
} | ||
|
||
public static IReadOnlyList<Expr> Collect(Expr expr) | ||
{ | ||
var exprs = new List<Expr>(); | ||
new LiftCollector().Visit(expr, exprs); | ||
return exprs; | ||
} | ||
|
||
// protected override Unit VisitIf(If expr, List<Expr> context) | ||
// { | ||
// foreach (var param in expr.ParamList) | ||
// { | ||
// Visit(param, context); | ||
// } | ||
|
||
// Visit(expr.Condition, context); | ||
Check failure on line 59 in src/Nncase.Passes/Rules/Neutral/LiftCEInIf.cs GitHub Actions / build-x86_64-linux
Check failure on line 59 in src/Nncase.Passes/Rules/Neutral/LiftCEInIf.cs GitHub Actions / build-aarch64-macos
Check failure on line 59 in src/Nncase.Passes/Rules/Neutral/LiftCEInIf.cs GitHub Actions / build-x86_64-linux
|
||
// return default; | ||
// } | ||
Check failure on line 61 in src/Nncase.Passes/Rules/Neutral/LiftCEInIf.cs GitHub Actions / build-x86_64-linux
Check failure on line 61 in src/Nncase.Passes/Rules/Neutral/LiftCEInIf.cs GitHub Actions / build-aarch64-macos
Check failure on line 61 in src/Nncase.Passes/Rules/Neutral/LiftCEInIf.cs GitHub Actions / build-x86_64-linux
|
||
|
||
protected override Unit DefaultVisitLeaf(Expr expr, List<Expr> context) | ||
{ | ||
context.Add(expr); | ||
return default; | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,34 +28,34 @@ public static class CallValidator | |
typeof(Conv2D).TypeHandle, | ||
typeof(Conv2DTranspose).TypeHandle, | ||
typeof(MatMul).TypeHandle, | ||
typeof(Transpose).TypeHandle, | ||
typeof(Pad).TypeHandle, | ||
typeof(Unsqueeze).TypeHandle, | ||
typeof(Squeeze).TypeHandle, | ||
typeof(Unary).TypeHandle, | ||
//typeof(Transpose).TypeHandle, | ||
Check failure on line 31 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-x86_64-linux
Check failure on line 31 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-x86_64-linux
Check failure on line 31 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-aarch64-macos
Check failure on line 31 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-aarch64-macos
Check failure on line 31 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-x86_64-linux
Check failure on line 31 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-x86_64-linux
|
||
//typeof(Pad).TypeHandle, | ||
Check failure on line 32 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-x86_64-linux
Check failure on line 32 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-aarch64-macos
Check failure on line 32 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-x86_64-linux
|
||
//typeof(Unsqueeze).TypeHandle, | ||
Check failure on line 33 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-x86_64-linux
Check failure on line 33 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-aarch64-macos
Check failure on line 33 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-x86_64-linux
|
||
//typeof(Squeeze).TypeHandle, | ||
//typeof(Unary).TypeHandle, | ||
}; | ||
|
||
private static readonly HashSet<RuntimeTypeHandle> MaybeDynamic = new() | ||
{ | ||
typeof(Concat).TypeHandle, | ||
typeof(Stack).TypeHandle, | ||
typeof(Binary).TypeHandle, | ||
typeof(Slice).TypeHandle, | ||
typeof(Gather).TypeHandle, | ||
typeof(ShapeOf).TypeHandle, | ||
// typeof(Concat).TypeHandle, | ||
// typeof(Stack).TypeHandle, | ||
// typeof(Binary).TypeHandle, | ||
// typeof(Slice).TypeHandle, | ||
// typeof(Gather).TypeHandle, | ||
// typeof(ShapeOf).TypeHandle, | ||
Check failure on line 45 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-x86_64-linux
Check failure on line 45 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-aarch64-macos
Check failure on line 45 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs GitHub Actions / build-x86_64-linux
|
||
|
||
typeof(Cast).TypeHandle, | ||
|
||
typeof(Reshape).TypeHandle, | ||
typeof(Expand).TypeHandle, | ||
typeof(ConstantOfShape).TypeHandle, | ||
// typeof(Reshape).TypeHandle, | ||
// typeof(Expand).TypeHandle, | ||
// typeof(ConstantOfShape).TypeHandle, | ||
|
||
// typeof(Where).TypeHandle, | ||
typeof(Compare).TypeHandle, | ||
typeof(Reduce).TypeHandle, | ||
typeof(Clamp).TypeHandle, | ||
typeof(Tile).TypeHandle, | ||
typeof(CumSum).TypeHandle, | ||
// typeof(Compare).TypeHandle, | ||
// typeof(Reduce).TypeHandle, | ||
// typeof(Clamp).TypeHandle, | ||
// typeof(Tile).TypeHandle, | ||
// typeof(CumSum).TypeHandle, | ||
}; | ||
|
||
public static bool IsMaybeDynamic(Expr target) => MaybeDynamic.Contains(target.GetType().TypeHandle); | ||
|