diff --git a/src/Nncase.Schedule/Schedule/GraphTiler.cs b/src/Nncase.Schedule/Schedule/GraphTiler.cs index 4777aaec9..c87784529 100644 --- a/src/Nncase.Schedule/Schedule/GraphTiler.cs +++ b/src/Nncase.Schedule/Schedule/GraphTiler.cs @@ -20,20 +20,22 @@ public sealed class GraphTiler private int _useCached; - public Expr Tile(Expr preExpr, string moduleKind, string itemNumber, ICpuTargetOptions targetOptions) + public int DeviceFuncionCount { get; set; } + + public Expr Tile(Expr preExpr, string moduleKind, ICpuTargetOptions targetOptions) { var topLevel = targetOptions.MemoryCapacities.Length; var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo); if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) { - rootGraph.Dump($"device_func{itemNumber}_original"); + rootGraph.Dump($"tile_graph"); } // bufferize root graph. var bufferGraphMemo = rootGraph.Bufferize(); if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) { - bufferGraphMemo[rootGraph].Dump($"device_func{itemNumber}_original_buffer"); + bufferGraphMemo[rootGraph].Dump($"tile_buffer_graph"); } #if true @@ -41,7 +43,7 @@ public Expr Tile(Expr preExpr, string moduleKind, string itemNumber, ICpuTargetO var condensedGraph = rootGraph.Condense(); if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) { - using (var file = Diagnostics.DumpScope.Current.OpenFile($"device_func{itemNumber}_condensed.dot")) + using (var file = Diagnostics.DumpScope.Current.OpenFile($"condensed_tile_graph.dot")) { using var writer = new StreamWriter(file); writer.Write(condensedGraph.ToGraphviz(init => @@ -64,7 +66,7 @@ public Expr Tile(Expr preExpr, string moduleKind, string itemNumber, ICpuTargetO var resultMemo = new Dictionary(); foreach (var (primGraph, i) in condensedGraph.TopologicalSort().Select((s, i) => (s, i))) { - using var subscope = new Diagnostics.DumpScope($"device_func{itemNumber}_{i}", Diagnostics.DumpFlags.Tiling); + using var subSubScope = new Diagnostics.DumpScope($"device_func_{DeviceFuncionCount}", Diagnostics.DumpFlags.Tiling); var primTree = treeGraphMemo[primGraph]; var primBufferGraph = bufferGraphMemo[primGraph]; HashSet inputBids; @@ -78,7 +80,7 @@ public Expr Tile(Expr preExpr, string moduleKind, string itemNumber, ICpuTargetO var bodyBuilder = T.Sequential(); result.Visit(primTree, new(bodyBuilder, Array.Empty())); var parameters = inputBids.Concat(outputBids).Select(k => result.PrimBufferMemo[k]).ToArray(); - var funcBuilder = T.PrimFunc($"device_func{itemNumber}_{i}", moduleKind, parameters).Body(bodyBuilder); + var funcBuilder = T.PrimFunc($"device_func_{DeviceFuncionCount++}", moduleKind, parameters).Body(bodyBuilder); var primFunc = funcBuilder.Build(); wrapper = new PrimFunctionWrapper(primFunc, inputBids.Count, inputBids.Concat(outputBids).Select(bid => bid.Node.Grid.GetArgument(bid.Index).CheckedType).ToArray()); _primFuncMemo.Add(primTree, wrapper); diff --git a/src/Nncase.Schedule/Transforms/AutoTilePass.cs b/src/Nncase.Schedule/Transforms/AutoTilePass.cs index 9c51ca04f..120667a65 100644 --- a/src/Nncase.Schedule/Transforms/AutoTilePass.cs +++ b/src/Nncase.Schedule/Transforms/AutoTilePass.cs @@ -38,14 +38,16 @@ protected override Task RunCoreAsync(IRModule input, RunPassContext co var funcNums = input.Functions.Count; for (int i = 0; i < funcNums; i++) { - var post = Rewrite(input.Functions[i], i, tiler); + var pre = input.Functions[i]; + using var scope = new Diagnostics.DumpScope(pre.Name); + var post = Rewrite(pre, tiler); input.Replace(i, post); } return Task.FromResult(input); } - private BaseFunction Rewrite(BaseFunction pre, int funcNumber, GraphTiler tiler) + private BaseFunction Rewrite(BaseFunction pre, GraphTiler tiler) { if (!(pre is IR.Fusion fusion && fusion.ModuleKind == ModuleKind)) { @@ -82,8 +84,8 @@ private BaseFunction Rewrite(BaseFunction pre, int funcNumber, GraphTiler tiler) if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Rewrite)) { - condenseAlgo.CondensedGraph.Dump($"{funcName}Condensed", init => { }); - condenseAlgo.ClusteredGraph.Dump($"{funcName}Cluster", algo => + condenseAlgo.CondensedGraph.Dump($"Condensed", init => { }); + condenseAlgo.ClusteredGraph.Dump($"Cluster", algo => { algo.FormatVertex += (s, arg) => { @@ -93,7 +95,7 @@ private BaseFunction Rewrite(BaseFunction pre, int funcNumber, GraphTiler tiler) } // 3. reconstruction - var constructor = new AutoTileReConstructor(tiler, funcNumber, ModuleKind, CompileOptions, condenseAlgo); + var constructor = new AutoTileReConstructor(tiler, ModuleKind, CompileOptions, condenseAlgo); var post = constructor.Construct(); return fusion.With(fusion.Name, fusion.ModuleKind, post, fusion.Parameters.ToArray()); } @@ -129,25 +131,71 @@ protected override ExprVertex VisitLeafGrid(Grid expr, IMutableVertexAndEdgeList internal sealed class AutoTileReConstructor : ExprReConstructor { - public AutoTileReConstructor(GraphTiler tiler, int funcNumber, string moduleKind, CompileOptions compileOptions, CondensationGraphAlgorithm algo) + public AutoTileReConstructor(GraphTiler tiler, string moduleKind, CompileOptions compileOptions, CondensationGraphAlgorithm algo) : base(algo) { Tiler = tiler; - FuncNumber = funcNumber; ModuleKind = moduleKind; CompileOptions = compileOptions; } public GraphTiler Tiler { get; } - public int FuncNumber { get; } - public string ModuleKind { get; } public CompileOptions CompileOptions { get; } + protected override Expr OnAtomCluster(ClusteredBidirectionalGraph cluster, int sortIndex) + { + using var subscope = new Diagnostics.DumpScope($"cluster_{sortIndex}", Diagnostics.DumpFlags.Tiling); + var pairs = GetClusterArgumentPairs(cluster); + var vertex = cluster.Vertices.First(); + var expr = vertex.Expr; + if (expr is Grid) + { + var extractDict = new Dictionary(ReferenceEqualityComparer.Instance); + var argumentDict = new Dictionary(ReferenceEqualityComparer.Instance); + foreach (var (pre, post) in pairs) + { + if (pre is Const) + { + continue; + } + + var @var = new Var(pre.CheckedType); + var added = extractDict.TryAdd(pre, @var); + if (added) + { + argumentDict.Add(@var, post); + } + } + + var cloner = new ExprClusterCloner(extractDict); + Expr cloned = cloner.Clone(expr, default); + var tiled = Tiler.Tile(cloned, ModuleKind, (ICpuTargetOptions)CompileOptions.TargetOptions); + var substitutor = new Mutators.Substitutor(e => + { + if (e is Var v && argumentDict.TryGetValue(v, out var arg)) + { + return arg; + } + + return null; + }); + + var substited = substitutor.Rewrite(tiled, default); + return substited; + } + else + { + var cloner = new ExprClusterCloner(pairs.ToDictionary(p => p.Pre, p => p.Post, new ReferenceEqualityComparer())); + return cloner.Clone(expr, default); + } + } + protected override Expr OnComplexCluster(ClusteredBidirectionalGraph cluster, int sortIndex) { + using var subscope = new Diagnostics.DumpScope($"cluster_{sortIndex}", Diagnostics.DumpFlags.Tiling); var pairs = GetClusterArgumentPairs(cluster); var extractDict = new Dictionary(ReferenceEqualityComparer.Instance); var argumentDict = new Dictionary(ReferenceEqualityComparer.Instance); @@ -175,7 +223,7 @@ protected override Expr OnComplexCluster(ClusteredBidirectionalGraph { diff --git a/src/Nncase.Tests/Affine/UnitTestTileGraph.cs b/src/Nncase.Tests/Affine/UnitTestTileGraph.cs index 830712338..5188b2257 100644 --- a/src/Nncase.Tests/Affine/UnitTestTileGraph.cs +++ b/src/Nncase.Tests/Affine/UnitTestTileGraph.cs @@ -304,7 +304,8 @@ public void TestSolveTileGraph(Func functor, Action action, int var post = CompilerServices.Rewrite(func, new IRewriteRule[] { new Passes.Rules.CPU.Affine.LowerPack(), new Passes.Rules.CPU.Affine.LowerUnary(), new Passes.Rules.CPU.Affine.LowerMatmul(), new Passes.Rules.CPU.Affine.LowerBinary() }, new()); var tiler = new Schedule.GraphTiler(); - var result = tiler.Tile(post, Nncase.Targets.CPUTarget.Kind, count.ToString(), (ICpuTargetOptions)CompileOptions.TargetOptions); + using var scope = new Diagnostics.DumpScope($"{count}"); + var result = tiler.Tile(post, Nncase.Targets.CPUTarget.Kind, (ICpuTargetOptions)CompileOptions.TargetOptions); action(result); }