Skip to content

Commit

Permalink
fix tiling
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Jan 14, 2025
1 parent 997c8eb commit 43f9b24
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 17 deletions.
14 changes: 8 additions & 6 deletions src/Nncase.Schedule/Schedule/GraphTiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,30 @@ public sealed class GraphTiler

private int _useCached;

public Expr Tile(Expr preExpr, string moduleKind, string itemNumber, ICpuTargetOptions targetOptions)
public int DeviceFuncionCount { get; set; }

public Expr Tile(Expr preExpr, string moduleKind, ICpuTargetOptions targetOptions)
{
var topLevel = targetOptions.MemoryCapacities.Length;
var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo);
if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling))
{
rootGraph.Dump($"device_func{itemNumber}_original");
rootGraph.Dump($"tile_graph");
}

// bufferize root graph.
var bufferGraphMemo = rootGraph.Bufferize();
if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling))
{
bufferGraphMemo[rootGraph].Dump($"device_func{itemNumber}_original_buffer");
bufferGraphMemo[rootGraph].Dump($"tile_buffer_graph");
}

#if true
// condense the root graph.
var condensedGraph = rootGraph.Condense();
if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling))
{
using (var file = Diagnostics.DumpScope.Current.OpenFile($"device_func{itemNumber}_condensed.dot"))
using (var file = Diagnostics.DumpScope.Current.OpenFile($"condensed_tile_graph.dot"))
{
using var writer = new StreamWriter(file);
writer.Write(condensedGraph.ToGraphviz(init =>
Expand All @@ -64,7 +66,7 @@ public Expr Tile(Expr preExpr, string moduleKind, string itemNumber, ICpuTargetO
var resultMemo = new Dictionary<TieredTileGraph, Expr>();
foreach (var (primGraph, i) in condensedGraph.TopologicalSort().Select((s, i) => (s, i)))
{
using var subscope = new Diagnostics.DumpScope($"device_func{itemNumber}_{i}", Diagnostics.DumpFlags.Tiling);
using var subSubScope = new Diagnostics.DumpScope($"device_func_{DeviceFuncionCount}", Diagnostics.DumpFlags.Tiling);
var primTree = treeGraphMemo[primGraph];
var primBufferGraph = bufferGraphMemo[primGraph];
HashSet<BufferIdentity> inputBids;
Expand All @@ -78,7 +80,7 @@ public Expr Tile(Expr preExpr, string moduleKind, string itemNumber, ICpuTargetO
var bodyBuilder = T.Sequential();
result.Visit(primTree, new(bodyBuilder, Array.Empty<Expr>()));
var parameters = inputBids.Concat(outputBids).Select(k => result.PrimBufferMemo[k]).ToArray();
var funcBuilder = T.PrimFunc($"device_func{itemNumber}_{i}", moduleKind, parameters).Body(bodyBuilder);
var funcBuilder = T.PrimFunc($"device_func_{DeviceFuncionCount++}", moduleKind, parameters).Body(bodyBuilder);
var primFunc = funcBuilder.Build();
wrapper = new PrimFunctionWrapper(primFunc, inputBids.Count, inputBids.Concat(outputBids).Select(bid => bid.Node.Grid.GetArgument(bid.Index).CheckedType).ToArray());
_primFuncMemo.Add(primTree, wrapper);
Expand Down
68 changes: 58 additions & 10 deletions src/Nncase.Schedule/Transforms/AutoTilePass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ protected override Task<IRModule> RunCoreAsync(IRModule input, RunPassContext co
var funcNums = input.Functions.Count;
for (int i = 0; i < funcNums; i++)
{
var post = Rewrite(input.Functions[i], i, tiler);
var pre = input.Functions[i];
using var scope = new Diagnostics.DumpScope(pre.Name);
var post = Rewrite(pre, tiler);
input.Replace(i, post);
}

return Task.FromResult(input);
}

private BaseFunction Rewrite(BaseFunction pre, int funcNumber, GraphTiler tiler)
private BaseFunction Rewrite(BaseFunction pre, GraphTiler tiler)
{
if (!(pre is IR.Fusion fusion && fusion.ModuleKind == ModuleKind))
{
Expand Down Expand Up @@ -82,8 +84,8 @@ private BaseFunction Rewrite(BaseFunction pre, int funcNumber, GraphTiler tiler)

if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Rewrite))
{
condenseAlgo.CondensedGraph.Dump($"{funcName}Condensed", init => { });
condenseAlgo.ClusteredGraph.Dump($"{funcName}Cluster", algo =>
condenseAlgo.CondensedGraph.Dump($"Condensed", init => { });
condenseAlgo.ClusteredGraph.Dump($"Cluster", algo =>
{
algo.FormatVertex += (s, arg) =>
{
Expand All @@ -93,7 +95,7 @@ private BaseFunction Rewrite(BaseFunction pre, int funcNumber, GraphTiler tiler)
}

// 3. reconstruction
var constructor = new AutoTileReConstructor(tiler, funcNumber, ModuleKind, CompileOptions, condenseAlgo);
var constructor = new AutoTileReConstructor(tiler, ModuleKind, CompileOptions, condenseAlgo);
var post = constructor.Construct();
return fusion.With(fusion.Name, fusion.ModuleKind, post, fusion.Parameters.ToArray());
}
Expand Down Expand Up @@ -129,25 +131,71 @@ protected override ExprVertex VisitLeafGrid(Grid expr, IMutableVertexAndEdgeList

internal sealed class AutoTileReConstructor : ExprReConstructor<ExprVertex, ExprEdge>
{
public AutoTileReConstructor(GraphTiler tiler, int funcNumber, string moduleKind, CompileOptions compileOptions, CondensationGraphAlgorithm<ExprVertex, ExprEdge> algo)
public AutoTileReConstructor(GraphTiler tiler, string moduleKind, CompileOptions compileOptions, CondensationGraphAlgorithm<ExprVertex, ExprEdge> algo)
: base(algo)
{
Tiler = tiler;
FuncNumber = funcNumber;
ModuleKind = moduleKind;
CompileOptions = compileOptions;
}

public GraphTiler Tiler { get; }

public int FuncNumber { get; }

public string ModuleKind { get; }

public CompileOptions CompileOptions { get; }

/// <summary>
/// Reconstructs a cluster by cloning the expression of its first vertex.
/// A <see cref="Grid"/> expression is additionally passed through the tiler:
/// its non-constant arguments are swapped for fresh <see cref="Var"/> placeholders
/// before tiling and substituted back afterwards.
/// </summary>
/// <param name="cluster">The cluster to reconstruct; only its first vertex is read here.</param>
/// <param name="sortIndex">Index used to name the dump scope (<c>cluster_{sortIndex}</c>).</param>
/// <returns>The reconstructed expression.</returns>
protected override Expr OnAtomCluster(ClusteredBidirectionalGraph<ExprVertex, ExprEdge> cluster, int sortIndex)
{
    // The scope is disposed when the method returns, so it spans every call below.
    using var subscope = new Diagnostics.DumpScope($"cluster_{sortIndex}", Diagnostics.DumpFlags.Tiling);
    var pairs = GetClusterArgumentPairs(cluster);
    var expr = cluster.Vertices.First().Expr;

    // Anything that is not a grid is cloned with its arguments mapped pre -> post directly.
    if (!(expr is Grid))
    {
        var directCloner = new ExprClusterCloner(pairs.ToDictionary(p => p.Pre, p => p.Post, new ReferenceEqualityComparer<Expr>()));
        return directCloner.Clone(expr, default);
    }

    // placeholderByPre: original argument -> placeholder var used while tiling.
    // argumentByPlaceholder: placeholder var -> reconstructed argument to put back afterwards.
    var placeholderByPre = new Dictionary<Expr, Expr>(ReferenceEqualityComparer.Instance);
    var argumentByPlaceholder = new Dictionary<Var, Expr>(ReferenceEqualityComparer.Instance);
    foreach (var (pre, post) in pairs)
    {
        // Constants are not extracted into placeholders.
        if (pre is Const)
        {
            continue;
        }

        // The placeholder is created before the lookup, exactly as before, even when
        // the same argument was already registered and the new var ends up unused.
        var placeholder = new Var(pre.CheckedType);
        if (placeholderByPre.TryAdd(pre, placeholder))
        {
            argumentByPlaceholder.Add(placeholder, post);
        }
    }

    var gridCloner = new ExprClusterCloner(placeholderByPre);
    Expr cloned = gridCloner.Clone(expr, default);
    var tiled = Tiler.Tile(cloned, ModuleKind, (ICpuTargetOptions)CompileOptions.TargetOptions);

    // Put the reconstructed arguments back in place of the placeholders; returning null
    // from the callback is what the original did for every other expression.
    var substitutor = new Mutators.Substitutor(e => (e is Var v && argumentByPlaceholder.TryGetValue(v, out var arg)) ? arg : null);
    return substitutor.Rewrite(tiled, default);
}

protected override Expr OnComplexCluster(ClusteredBidirectionalGraph<ExprVertex, ExprEdge> cluster, int sortIndex)
{
using var subscope = new Diagnostics.DumpScope($"cluster_{sortIndex}", Diagnostics.DumpFlags.Tiling);
var pairs = GetClusterArgumentPairs(cluster);
var extractDict = new Dictionary<Expr, Expr>(ReferenceEqualityComparer.Instance);
var argumentDict = new Dictionary<Var, Expr>(ReferenceEqualityComparer.Instance);
Expand Down Expand Up @@ -175,7 +223,7 @@ protected override Expr OnComplexCluster(ClusteredBidirectionalGraph<ExprVertex,
}

Expr cloned = clones.Count == 1 ? clones[0] : new IR.Tuple(clones.ToArray());
var tiled = Tiler.Tile(cloned, ModuleKind, $"{FuncNumber}_{sortIndex}", (ICpuTargetOptions)CompileOptions.TargetOptions);
var tiled = Tiler.Tile(cloned, ModuleKind, (ICpuTargetOptions)CompileOptions.TargetOptions);

var substitutor = new Mutators.Substitutor(e =>
{
Expand Down
3 changes: 2 additions & 1 deletion src/Nncase.Tests/Affine/UnitTestTileGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ public void TestSolveTileGraph(Func<Function> functor, Action<Expr> action, int
var post = CompilerServices.Rewrite(func, new IRewriteRule[] { new Passes.Rules.CPU.Affine.LowerPack(), new Passes.Rules.CPU.Affine.LowerUnary(), new Passes.Rules.CPU.Affine.LowerMatmul(), new Passes.Rules.CPU.Affine.LowerBinary() }, new());

var tiler = new Schedule.GraphTiler();
var result = tiler.Tile(post, Nncase.Targets.CPUTarget.Kind, count.ToString(), (ICpuTargetOptions)CompileOptions.TargetOptions);
using var scope = new Diagnostics.DumpScope($"{count}");
var result = tiler.Tile(post, Nncase.Targets.CPUTarget.Kind, (ICpuTargetOptions)CompileOptions.TargetOptions);
action(result);
}

Expand Down

0 comments on commit 43f9b24

Please sign in to comment.