diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml index 64bf23c36..2270dabec 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml @@ -8,10 +8,10 @@ #include #include "topo_aware_runtime.h" -#include "../device.h" -@foreach(var (s,i) in Model.Options.MemoryCapacities.Select((s,i) => (s,i)).Skip(1).SkipLast(1)){ -@:uint8_t L@(i)Data[@(s)]; +@foreach(var (s,i) in Model.Options.MemoryCapacities.Select((s,i) => (s,i)).SkipLast(1)){ +@:uint8_t L@(i+1)Data[@(s)]; } +#include "../device.h" #include "kernel.h" //alignas(@(Model.Alignment)) static thread_local uint8_t local_data[@(Model.DataSize)]; diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Unary.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Unary.cs index cd7e9bd44..6e8f329f0 100644 --- a/modules/Nncase.Modules.CPU/TIR/CPU/Unary.cs +++ b/modules/Nncase.Modules.CPU/TIR/CPU/Unary.cs @@ -4,6 +4,7 @@ namespace Nncase.TIR.CPU; +[ParameterInPlace(0, 1)] public sealed partial class Unary : CPUKernelOp { public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input"); diff --git a/src/Nncase.Core/IR/Op.cs b/src/Nncase.Core/IR/Op.cs index 2af2e39fd..92a18a87a 100644 --- a/src/Nncase.Core/IR/Op.cs +++ b/src/Nncase.Core/IR/Op.cs @@ -18,6 +18,20 @@ public enum ParameterKind : int Attribute, } +[AttributeUsage(System.AttributeTargets.Class, Inherited = false, AllowMultiple = true)] +public sealed class ParameterInPlaceAttribute : System.Attribute +{ + public ParameterInPlaceAttribute(int sourceIndex, int destIndex) + { + SourceIndex = sourceIndex; + DestIndex = destIndex; + } + + public int SourceIndex { get; } + + public int DestIndex { get; } +} + /// /// Parameter information. /// diff --git a/src/Nncase.Core/TIR/TIRExtensions.cs b/src/Nncase.Core/TIR/TIRExtensions.cs index 03e0a2cf6..5bec0e126 100644 --- a/src/Nncase.Core/TIR/TIRExtensions.cs +++ b/src/Nncase.Core/TIR/TIRExtensions.cs @@ -16,6 +16,15 @@ namespace Nncase.TIR; /// public static class TIRExtensions { + /// + /// Get the tir op buffer allocation reuse information. + /// + /// map dest index to source index. + public static Dictionary GetInPlaceMemo(this Op op) + { + return op.GetType().GetCustomAttributes(typeof(ParameterInPlaceAttribute), true).OfType().ToDictionary(a => a.DestIndex, a => a.SourceIndex); + } + /// /// convert IEnumerable to tir Sequential. /// diff --git a/src/Nncase.Graph/Graphs/CondensationTieredGraphAlgorithm.cs b/src/Nncase.Graph/Graphs/CondensationTieredGraphAlgorithm.cs index 64aac9850..f13ef58e4 100644 --- a/src/Nncase.Graph/Graphs/CondensationTieredGraphAlgorithm.cs +++ b/src/Nncase.Graph/Graphs/CondensationTieredGraphAlgorithm.cs @@ -52,6 +52,7 @@ protected override void InternalCompute() var dfs = new DepthFirstSearchAlgorithm(this, VisitedGraph, new Dictionary(VisitedGraph.VertexCount)); dfs.TreeEdge += TreeEdge; + dfs.ForwardOrCrossEdge += TreeEdge; dfs.Compute(); } diff --git a/src/Nncase.Schedule/Schedule/GraphTiler.cs b/src/Nncase.Schedule/Schedule/GraphTiler.cs index c87784529..4978852e4 100644 --- a/src/Nncase.Schedule/Schedule/GraphTiler.cs +++ b/src/Nncase.Schedule/Schedule/GraphTiler.cs @@ -14,15 +14,38 @@ namespace Nncase.Schedule; -public sealed class GraphTiler +public static class GraphTiler { - private readonly Dictionary _primFuncMemo = new(new ITreeNodeComparer()); + public static Expr MCTSTiling(Expr preExpr, string moduleKind, string prefix, Dictionary solveMemo, ICpuTargetOptions targetOptions) + { + var topLevel = targetOptions.MemoryCapacities.Length; + var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo); + var rootState = new MCTState(rootGraph, moduleKind, prefix, "0", solveMemo, targetOptions); + var rootNode = new MCTNode(rootState); + var searcher = new MCTSearcher(); + searcher.Search(rootNode); + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) + { + rootNode.Dump("SearchTree"); + } + + var bestState = (MCTState)searcher.BestMCTNode!.State; + var replaces = new Dictionary(); + foreach (var (oldExpr, v) in exprMemo) + { + if (bestState.Results.TryGetValue(v, out var newExpr)) + { + replaces.Add(oldExpr, newExpr); + } + } - private int _useCached; + var cloner = new ReplacingExprCloner(replaces); + return cloner.Clone(preExpr, default); + } public int DeviceFuncionCount { get; set; } - public Expr Tile(Expr preExpr, string moduleKind, ICpuTargetOptions targetOptions) + public Expr Tile(Expr preExpr, string moduleKind, Dictionary solveMemo, ICpuTargetOptions targetOptions) { var topLevel = targetOptions.MemoryCapacities.Length; var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo); @@ -31,6 +54,13 @@ public Expr Tile(Expr preExpr, string moduleKind, ICpuTargetOptions targetOption rootGraph.Dump($"tile_graph"); } + var (resultMemo, _) = SolveRootGraph(rootGraph, moduleKind, prefix, solveMemo, targetOptions); + var cloner = new ReplacingExprCloner(exprMemo.ToDictionary(kv => (Expr)kv.Key, kv => resultMemo[kv.Value])); + return cloner.Clone(preExpr, default); + } + + public static (Dictionary ResultMemo, long ObjectValue) SolveRootGraph(TieredTileGraph rootGraph, string moduleKind, string prefix, Dictionary solveMemo, ICpuTargetOptions targetOptions) + { // bufferize root graph. var bufferGraphMemo = rootGraph.Bufferize(); if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) @@ -38,7 +68,6 @@ public Expr Tile(Expr preExpr, string moduleKind, ICpuTargetOptions targetOption bufferGraphMemo[rootGraph].Dump($"tile_buffer_graph"); } -#if true // condense the root graph. var condensedGraph = rootGraph.Condense(); if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) @@ -64,17 +93,17 @@ public Expr Tile(Expr preExpr, string moduleKind, ICpuTargetOptions targetOption var argumentsMemo = bufferGraphMemo[rootGraph].GetInputsOutputs().Inputs.ToDictionary(k => k, k => k.Node.Grid.GetArgument(k.Index)); var resultMemo = new Dictionary(); + long objectValue = 0; foreach (var (primGraph, i) in condensedGraph.TopologicalSort().Select((s, i) => (s, i))) { using var subSubScope = new Diagnostics.DumpScope($"device_func_{DeviceFuncionCount}", Diagnostics.DumpFlags.Tiling); var primTree = treeGraphMemo[primGraph]; - var primBufferGraph = bufferGraphMemo[primGraph]; HashSet inputBids; HashSet outputBids; - if (!_primFuncMemo.TryGetValue(primTree, out var wrapper)) + if (!solveMemo.TryGetValue(primTree, out var memo)) { - var result = SolvePrimGraph(primTree, primBufferGraph, targetOptions, moduleKind); + var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions, moduleKind); (inputBids, outputBids) = (result.Inputs, result.Outputs); result.ScheduleBuffers(); var bodyBuilder = T.Sequential(); @@ -82,16 +111,16 @@ public Expr Tile(Expr preExpr, string moduleKind, ICpuTargetOptions targetOption var parameters = inputBids.Concat(outputBids).Select(k => result.PrimBufferMemo[k]).ToArray(); var funcBuilder = T.PrimFunc($"device_func_{DeviceFuncionCount++}", moduleKind, parameters).Body(bodyBuilder); var primFunc = funcBuilder.Build(); - wrapper = new PrimFunctionWrapper(primFunc, inputBids.Count, inputBids.Concat(outputBids).Select(bid => bid.Node.Grid.GetArgument(bid.Index).CheckedType).ToArray()); - _primFuncMemo.Add(primTree, wrapper); + memo = new(new PrimFunctionWrapper(primFunc, inputBids.Count, inputBids.Concat(outputBids).Select(bid => bid.Node.Grid.GetArgument(bid.Index).CheckedType).ToArray()), result.ObjectiveValue); + solveMemo.Add(primTree, memo); } else { - (inputBids, outputBids) = primBufferGraph.GetInputsOutputs(); - _useCached++; + (inputBids, outputBids) = bufferGraphMemo[primGraph].GetInputsOutputs(); } - var finalCall = new Call(wrapper, inputBids.Select(bid => argumentsMemo[bid]).ToArray()); + objectValue += memo.ObjectValue; + var finalCall = new Call(memo.Func, inputBids.Select(bid => argumentsMemo[bid]).ToArray()); resultMemo.Add(primGraph, finalCall); // save the output. @@ -109,79 +138,34 @@ public Expr Tile(Expr preExpr, string moduleKind, ICpuTargetOptions targetOption } } - var cloner = new ReplacingExprCloner(exprMemo.ToDictionary(kv => (Expr)kv.Key, kv => resultMemo[kv.Value])); - return cloner.Clone(preExpr, default); - -#else - PrimGraphSolveResult? bestConstructor = null; - foreach (var chunk in EnumerateAll(root, totalLevel, new()).Chunk(System.Math.Max(System.Environment.ProcessorCount - 2, 1))) - { - foreach (var resultConstructor in chunk.AsParallel().Select(isoTree => Solve(isoTree.Root, targetOptions)).OfType()) - { - bestConstructor = (bestConstructor?.ObjectiveValue <= resultConstructor.ObjectiveValue ? bestConstructor : resultConstructor) ?? resultConstructor; - } - } -#endif - - // if (bestConstructor is null) - // { - // throw new InvalidOperationException("can't solver!"); - // } - - // if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) - // { - // bestConstructor.Tree.Dump($"device_func{itemNumber}_best"); - // } - - // return bestConstructor.ConstructResult(moduleKind, itemNumber); - // return new Call(None.Default); + return (resultMemo, objectValue); } - private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBufferGraph, ICpuTargetOptions targetOptions, string moduleKind) + public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary bufferGraphMemo, ICpuTargetOptions targetOptions, string moduleKind) { int[] memoryCapacities = targetOptions.MemoryCapacities; int[] memoryBandWidths = targetOptions.MemoryBandWidths; var topLevel = memoryCapacities.Length; - TreeSolverInitializer.Init(primTree, topLevel, targetOptions, out var solver, out var opNodeMemo, out var tileNodeMemo, out var tileableNodeMemo); + TreeSolverInitializer.Init(primTree, bufferGraphMemo, topLevel, targetOptions, out var solver, out var opNodeMemo, out var tileNodeMemo, out var tileableNodeMemo); // 0. the top level already store a buffer at outter most. var toplevelStoreBufferConstraints = new List(); + var (inputBids, outputBids) = bufferGraphMemo[primTree.Wrapped].GetInputsOutputs(); foreach (var (bid, binfo) in tileNodeMemo[primTree].BufferInfoMap) { - var cons = solver.MakeEquality(binfo.Places[0][^1], 1); - cons.SetName($"{bid}StoreAtOutMost"); - solver.Add(cons); - toplevelStoreBufferConstraints.Add(cons); - } - - // 0.1 parent node's inner place is equal to child's outter place. - var duplictePlaceConstranits = new List(); - primTree.Walk( - (treeNode) => + if (inputBids.Contains(bid) || outputBids.Contains(bid)) { - if (treeNode is not TileNode tileNode || tileNode.Level == 1) - { - return; - } - - foreach (var (bid, binfo) in tileNodeMemo[tileNode].BufferInfoMap) - { - foreach (var child in tileNode.Children.ToArray().OfType()) - { - var cbinfo = tileNodeMemo[child].BufferInfoMap[tileNodeMemo[child].GetCacheBid(bid)]; - for (int sl = 0; sl < tileNode.Level - 1; sl++) - { - var cons = solver.MakeLessOrEqual(binfo.Places[^1][sl] + cbinfo.Places[0][sl], 1); - duplictePlaceConstranits.Add(cons); - solver.Add(cons); - } - } - } - }); + var cons = solver.MakeEquality(binfo.Places[0][^1], 1); + cons.SetName($"{bid}StoreAtOutMost"); + solver.Add(cons); + toplevelStoreBufferConstraints.Add(cons); + } + } // 1. must have one buffer at lowest store level. // Beside the top-level node, from bottom to top count each tile node's buffer numbers which are stored at the lowest level. - var eachLevelStoreBufferNums = new Dictionary>>(); + var tileNodeStoreAtLevelPlaces = new Dictionary>>>(); + var reusedBuffers = new HashSet(); primTree.Walk( treeNode => { @@ -192,50 +176,144 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer var tileNodeInfo = tileNodeMemo[tileNode]; - if (!eachLevelStoreBufferNums.TryGetValue(tileNode, out var nodeStoreBufferNums)) + if (!tileNodeStoreAtLevelPlaces.TryGetValue(tileNode, out var curNodeStoreAtLevelePlaces)) { - nodeStoreBufferNums = new Dictionary>(); + curNodeStoreAtLevelePlaces = new Dictionary>>(); + reusedBuffers.UnionWith(tileNodeInfo.DefUseMap.Keys.Select(b => new NodeWithBuffer(tileNode, b))); foreach (var (bid, bufferInfo) in tileNodeInfo.BufferInfoMap) { - var levelStoreNums = new Dictionary(); + var levelPlaces = new Dictionary>(); + + // collect current node‘s placements. for (int sl = 0; sl < tileNode.Level; sl++) { - levelStoreNums[sl] = solver.MakeSum(bufferInfo.Places.Select(p => p[sl].Var()).ToArray()); - } + if (!levelPlaces.TryGetValue(sl, out var places)) + { + places = new List(); + levelPlaces.Add(sl, places); + } - nodeStoreBufferNums.Add(bid, levelStoreNums); - } + foreach (var place in bufferInfo.Places) + { + if (sl < place.Length) + { + places.Add(place[sl]); + } + } + } - foreach (var child in tileNode.Children.ToArray().OfType()) - { - foreach (var (cbid, cbufferInfo) in tileNodeMemo[child].BufferInfoMap) + // collect child node's placement + foreach (var childNode in tileNode.Children.ToArray().OfType()) { - var pbid = tileNodeInfo.GetCacheBid(cbid); - for (int sl = 0; sl < child.Level; sl++) + var childNodeStoreAtLevelePlaces = tileNodeStoreAtLevelPlaces[childNode]; + if (tileNodeInfo.DefUseMap.ContainsKey(bid) || tileNodeInfo.DefUseMap.ContainsValue(bid)) + { + continue; + } + + // collect the child buffer's placement which has not been reused. + if (childNodeStoreAtLevelePlaces.TryGetValue(bid, out var childLevelPlaces)) { - nodeStoreBufferNums[pbid][sl] = nodeStoreBufferNums[pbid][sl] + eachLevelStoreBufferNums[child][cbid][sl]; + for (int sl = 0; sl < childNode.Level; sl++) + { + levelPlaces[sl].AddRange(childLevelPlaces[sl]); + } } } + + curNodeStoreAtLevelePlaces.Add(bid, levelPlaces); } - eachLevelStoreBufferNums.Add(tileNode, nodeStoreBufferNums); + tileNodeStoreAtLevelPlaces.Add(tileNode, curNodeStoreAtLevelePlaces); } }, true); + // sum(places[cl,bid,ci,sl], (cl, ci)) == 1 var eachLevelStoreBufferNumsConstrains = new Dictionary(); foreach (var (bid, bufferInfo) in tileNodeMemo[primTree].BufferInfoMap) { + if (reusedBuffers.Contains(new NodeWithBuffer(primTree, bid))) + { + continue; + } + + var levelPlaces = tileNodeStoreAtLevelPlaces[primTree][bid]; var cons = new Constraint[primTree.Level]; eachLevelStoreBufferNumsConstrains[bid] = cons; for (int sl = 0; sl < primTree.Level; sl++) { - cons[sl] = solver.MakeEquality(eachLevelStoreBufferNums[primTree][bid][sl], 1); - cons[sl].SetName($"store[{bid}, sl{sl}]"); - solver.Add(cons[sl]); + if (levelPlaces.TryGetValue(sl, out var places)) + { + cons[sl] = solver.MakeEquality(solver.MakeSum(places.Select(e => e.Var()).ToArray()), 1); + cons[sl].SetName($"store[{bid}, sl{sl}]"); + solver.Add(cons[sl]); + } } } + var eachLevelStoreReusedBufferNumsConstrains = new Dictionary(); + foreach (var (tileNode, bid) in reusedBuffers) + { + var fusedLevel = tileNode.Level - 1; + + // child's places + var producerSubPlaces = new List(); + var consumerSubPlaces = new List(); + var nodeInfo = tileNodeMemo[tileNode]; + var sourceId = bid; + var sinkId = nodeInfo.DefUseMap[sourceId]; + foreach (var childNode in tileNode.Children.ToArray().OfType()) + { + var childNodeInfo = tileNodeMemo[childNode]; + foreach (var (cbid, cbidInfo) in childNodeInfo.BufferInfoMap) + { + if (cbid == sourceId) + { + producerSubPlaces.AddRange(tileNodeStoreAtLevelPlaces[childNode][cbid][fusedLevel - 1]); + } + else if (cbid == sinkId) + { + consumerSubPlaces.AddRange(tileNodeStoreAtLevelPlaces[childNode][cbid][fusedLevel - 1]); + } + } + } + + // 1. child consumer sub places == child producer sub places == 0 + var producerChildStoreNums = solver.MakeSum(producerSubPlaces.Select(e => e.Var()).ToArray()); + var consumerChildStoreNums = solver.MakeSum(consumerSubPlaces.Select(e => e.Var()).ToArray()); + var pcons = solver.MakeEquality(producerChildStoreNums, 0); + pcons.SetName($"producer_store[{bid}, sl{fusedLevel}]"); + solver.Add(pcons); + var ccons = solver.MakeEquality(consumerChildStoreNums, 0); + ccons.SetName($"consumer_store[{bid}, sl{fusedLevel}]"); + solver.Add(ccons); + + // 2. all parent places == 0 + var parentPlaces = new List(); + var nextNode = tileNode; + while (nextNode.Parent is TileNode nextParent) + { + for (int sl = tileNode.Level - 1; sl < primTree.Level; sl++) + { + parentPlaces.AddRange(tileNodeStoreAtLevelPlaces[nextNode][bid][sl]); + } + + nextNode = nextParent; + } + + var parentCons = solver.MakeEquality(solver.MakeSum(parentPlaces.Select(e => e.Var()).ToArray()), 0); + parentCons.SetName($"fused_parent_store[{bid}]"); + solver.Add(parentCons); + + // 3. fused places == 1 + var fusedStoreNums = solver.MakeSum(tileNodeStoreAtLevelPlaces[tileNode][bid][fusedLevel - 1].Select(e => e.Var()).ToArray()); + var fcons = solver.MakeEquality(fusedStoreNums, 1); + fcons.SetName($"fused_store[{bid}, sl{fusedLevel}]"); + solver.Add(fcons); + eachLevelStoreReusedBufferNumsConstrains.Add(new NodeWithBuffer(tileNode, bid), new[] { pcons, ccons, parentCons, fcons }); + } + // 4. tile var constraints var tileVarConstraints = new Dictionary(); foreach (var opNode in opNodeMemo.Keys) @@ -253,55 +331,49 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer } // 5. add the memory schedule constraints, each level has own memory plan schedule. + // 5.1. sum(place[cl,b,ci,sl]*size[cl,b,ci], sl), sl = [0,toplevel) var levelBufferSizes = new Dictionary>(); var levelBufferLifeness = new Dictionary>>(); + var levelBufferLifenessConstraints = new Dictionary(); for (int sl = 0; sl < topLevel - 1; sl++) { // note currently there is a only one root var nodeBufferSizes = levelBufferSizes[sl] = new(); var nodeBufferLiveness = levelBufferLifeness[sl] = new(); - var rootNodeInfo = tileNodeMemo[primTree]; var beginTime = int.MaxValue; var endTime = int.MinValue; - foreach (var (bid, bufferInfo) in rootNodeInfo.BufferInfoMap) - { - var extents = bufferInfo.Places.Select(p => p[sl]).Zip(bufferInfo.SizeVars).Select(p => p.First * p.Second).ToArray(); - nodeBufferSizes[new(primTree, bid)] = extents.Skip(1).Aggregate(extents[0], solver.MakeSum); - nodeBufferLiveness[new(primTree, bid)] = bufferInfo.Liveness; - beginTime = Math.Min(beginTime, bufferInfo.Liveness.Item1); - endTime = Math.Max(endTime, bufferInfo.Liveness.Item2); - } - - primTree.Walk(current => + foreach (var (tileNode, nodeInfo) in tileNodeMemo) { - if (ReferenceEquals(current, primTree)) + foreach (var (bid, bufferInfo) in nodeInfo.BufferInfoMap) { - return; - } + var nodeBuffer = new NodeWithBuffer(tileNode, bid); + nodeBufferLiveness[nodeBuffer] = bufferInfo.Liveness; + beginTime = Math.Min(beginTime, bufferInfo.Liveness.Item1); + endTime = Math.Max(endTime, bufferInfo.Liveness.Item2); + var extents = new List(); + for (int ci = 0; ci < bufferInfo.Places.Length; ci++) + { + if (sl >= bufferInfo.Places[ci].Length) + { + continue; + } - if (current is not TileNode childNode || childNode.Level <= sl) - { - return; - } + extents.Add(solver.MakeProd(bufferInfo.Places[ci][sl], bufferInfo.SizeVars[ci])); + } - foreach (var (cbid, childBufferInfo) in tileNodeMemo[childNode].BufferInfoMap) - { - // accumulate the extents - var extents = childBufferInfo.Places.Select(p => p[sl]).Zip(childBufferInfo.SizeVars).Select(p => p.First * p.Second).ToArray(); - nodeBufferSizes[new(childNode, cbid)] = extents.Skip(1).Aggregate(extents[0], solver.MakeSum); - nodeBufferLiveness[new(childNode, cbid)] = childBufferInfo.Liveness; - beginTime = Math.Min(beginTime, childBufferInfo.Liveness.Item1); - endTime = Math.Max(endTime, childBufferInfo.Liveness.Item2); + nodeBufferSizes[nodeBuffer] = extents.Skip(1).Aggregate(extents[0], solver.MakeSum); } - }); + } // Add constraints according to liveness. -#if false - DumpGantt(nodeBufferSizes, nodeBufferLiveness, primTree, sl); -#endif + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) + { + DumpGantt(nodeBufferSizes, nodeBufferLiveness, primTree, sl); + } var lastTimeStamp = new HashSet(); + var constraints = new List(); for (int i = beginTime; i <= endTime; i++) { var curTimeStamp = new HashSet(); @@ -318,11 +390,15 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer var bufs = curTimeStamp.Select(key => nodeBufferSizes[key]).ToArray(); var size = bufs.Skip(1).Aggregate(bufs.First(), solver.MakeSum); var cons = solver.MakeLessOrEqual(size, memoryCapacities[sl]); + cons.SetName($"capacity[sl{sl}, t{i}]"); solver.Add(cons); + constraints.Add(cons); lastTimeStamp.Clear(); // update last stamp. lastTimeStamp.UnionWith(curTimeStamp); } } + + levelBufferLifenessConstraints.Add(sl, constraints.ToArray()); } // compute the cycles as objective @@ -348,7 +424,7 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer // computeCycles = solver.MakeSum(computeCycles, noContiguous.Aggregate(opCycles, solver.MakeSum) * loopTrip); } - // from top to down. + // Because of the placement as a control of data movement, there is no need to pick the buffer carefully. var levelDataReads = Enumerable.Range(0, topLevel).Select(i => (IntExpr)solver.MakeIntConst(0)).ToArray(); var levelDataWrites = Enumerable.Range(0, topLevel).Select(i => (IntExpr)solver.MakeIntConst(0)).ToArray(); foreach (var (tileNode, nodeInfo) in tileNodeMemo) @@ -356,15 +432,21 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer var createLevel = tileNode.Level; var nodeWrites = Enumerable.Range(0, topLevel).Select(_ => new List()).ToArray(); var nodeReads = Enumerable.Range(0, topLevel).Select(_ => new List()).ToArray(); - MicroKernelBufferInfo[]? binfo = null; foreach (var (bid, bufferInfo) in nodeInfo.BufferInfoMap) { - binfo ??= bid.Node.GetKernelInfo(targetOptions).BufferInfos; - for (int storeLevel = 0; storeLevel < bufferInfo.Places[0].Length; storeLevel++) + var binfo = bid.Node.GetKernelInfo(targetOptions).BufferInfos; + var reused = nodeInfo.DefUseMap.ContainsKey(bid); + for (int storeLevel = 0; storeLevel < Math.Min(tileNode.Level, topLevel - 1); storeLevel++) { - var volumes = new IntExpr[bufferInfo.Places.Length]; + // skip the buffer which store at top level + var volumes = Enumerable.Repeat((IntExpr)solver.MakeIntConst(1), bufferInfo.Places.Length).ToArray(); for (int i = 0; i < bufferInfo.Places.Length; i++) { + if (storeLevel >= bufferInfo.Places[i].Length) + { + continue; + } + IntExpr factor = solver.MakeIntConst(1); // if (storeLevel == 0 && !bid.IsOutput) @@ -382,26 +464,26 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer if (binfo[bid.Index].State.HasFlag(MicroKernelBufferInfo.BufferState.Read)) { - if (storeLevel < topLevel - 1) + if (storeLevel < topLevel) { nodeWrites[storeLevel].Add(dataMoves); // write to store level. } - if (storeLevel < topLevel - 1) + if (storeLevel + 1 < topLevel && !reused) { - nodeReads[storeLevel + 1].Add(dataMoves); // read from create level. + nodeReads[storeLevel + 1].Add(dataMoves); // read from higher level. } } // todo the intermediate buffer should be read write. if (binfo[bid.Index].State.HasFlag(MicroKernelBufferInfo.BufferState.Write)) { - if (storeLevel < topLevel - 1) + if (storeLevel + 1 < topLevel && !reused) { nodeWrites[storeLevel + 1].Add(dataMoves); } - if (storeLevel < topLevel - 1) + if (storeLevel < topLevel) { nodeReads[storeLevel].Add(dataMoves); } @@ -426,8 +508,7 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer var memoryCycles = new IntExpr[topLevel]; for (int i = 0; i < topLevel; i++) { - // memoryCycles[i] = (levelDataWrites[i] + levelDataReads[i]).CeilDiv(memoryBandWidths[i]); - memoryCycles[i] = levelDataWrites[i].CeilDiv(memoryBandWidths[i]); + memoryCycles[i] = (levelDataWrites[i] + levelDataReads[i]).CeilDiv(memoryBandWidths[i]); } IntExpr totalCycles = computeCycles; @@ -464,15 +545,16 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer foreach (var (node, info) in tileNodeMemo) { collector.Add(info.TripCounts.Select(i => i.Var()).ToArray()); + collector.Add(info.BackWardExtents.Select(i => i.Select(j => j.Var())).SelectMany(i => i).ToArray()); foreach (var (bid, bufferInfo) in info.BufferInfoMap) { var placeVars = bufferInfo.Places.SelectMany(i => i).ToArray(); searchAbleVars.AddRange(placeVars.Select(i => i.Var())); collector.Add(placeVars.Select(i => i.Var()).ToArray()); collector.Add(bufferInfo.Shapes.SelectMany(i => i).Select(i => i.Var()).ToArray()); - collector.Add(bufferInfo.SizeVars.Select(i => i.Var()).ToArray()); - collector.Add(bufferInfo.SizeExprs.Select(i => i.Var()).ToArray()); - collector.Add(bufferInfo.Trips.Select(i => i.Var()).ToArray()); + collector.Add(bufferInfo.SizeVars.Where(v => v is not null).Select(i => i.Var()).ToArray()); + collector.Add(bufferInfo.SizeExprs.Where(v => v is not null).Select(i => i.Var()).ToArray()); + collector.Add(bufferInfo.Trips.Where(v => v is not null).Select(i => i.Var()).ToArray()); } } @@ -484,6 +566,14 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer } } + foreach (var (_, v) in levelBufferLifenessConstraints) + { + foreach (var item in v) + { + collector.Add(item.Var()); + } + } + var defaultPhaseParameters = new DefaultPhaseParameters(); if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) { @@ -507,7 +597,7 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer } } - var solve_max_solutions = 10; + var solve_max_solutions = 15; if (System.Environment.GetEnvironmentVariable("TILING_SOLVE_MAX_SOLUTIONS") is string s_solve_max_solutions) { try @@ -551,10 +641,10 @@ private TreeSolveResult SolvePrimGraph(TileNode primTree, BufferGraph primBuffer DumpAssgin(primTree, new TreeSolverPythonPrinter(sol, solver, opNodeMemo, tileNodeMemo, tileableNodeMemo, targetOptions), tileVarConstraints, eachLevelStoreBufferNumsConstrains, levelBufferSizes, levelDataReads, levelDataWrites, memoryCycles, computeCycles, totalCyclesVar); } - return new TreeSolveResult(primBufferGraph, sol.ObjectiveValue(), levelBufferSizesAssgin, levelBufferLifeness, opNodeMemoAssgin, tileNodeMemoAssgin, tileableNodeMemoAssgin, targetOptions, moduleKind); + return new TreeSolveResult(bufferGraphMemo[primTree.Wrapped], sol.ObjectiveValue(), levelBufferSizesAssgin, levelBufferLifeness, opNodeMemoAssgin, tileNodeMemoAssgin, tileableNodeMemoAssgin, targetOptions, moduleKind); } - private void DumpAssgin(ITreeNode tree, TreeSolverPythonPrinter printer, Dictionary tileVarConstraints, Dictionary lowestStoreBufferNumsConstrains, Dictionary> levelBufferSizes, IntExpr[] levelDataReads, IntExpr[] levelDataWrites, IntExpr[] memoryCycles, IntExpr computeCycles, IntVar totalCycles) + public static void DumpAssgin(ITreeNode tree, TreeSolverPythonPrinter printer, Dictionary tileVarConstraints, Dictionary lowestStoreBufferNumsConstrains, Dictionary> levelBufferSizes, IntExpr[] levelDataReads, IntExpr[] levelDataWrites, IntExpr[] memoryCycles, IntExpr computeCycles, IntVar totalCycles) { using (var stream = Diagnostics.DumpScope.Current.OpenFile($"modeling.py")) { @@ -564,7 +654,7 @@ private void DumpAssgin(ITreeNode tree, TreeSolverPythonPrinter printer, Diction } } - private void DumpAssgin(ITreeNode tree, TreeSolverPrinter printer, Dictionary tileVarConstraints, Dictionary eachLevelStoreBufferNumsConstrains, Dictionary> levelBufferSizes, IntExpr[] levelDataReads, IntExpr[] levelDataWrites, IntExpr[] memoryCycles, IntExpr computeCycles, IntVar totalCycles) + public static void DumpAssgin(ITreeNode tree, TreeSolverPrinter printer, Dictionary tileVarConstraints, Dictionary eachLevelStoreBufferNumsConstrains, Dictionary> levelBufferSizes, IntExpr[] levelDataReads, IntExpr[] levelDataWrites, IntExpr[] memoryCycles, IntExpr computeCycles, IntVar totalCycles) { using (var stream = Diagnostics.DumpScope.Current.OpenFile($"modeling.yaml")) { @@ -615,7 +705,7 @@ private void DumpAssgin(ITreeNode tree, TreeSolverPrinter printer, Dictionary nodeBufferSizes, Dictionary> nodeBufferLiveness, TileNode primTree, int storeLevel) + private static void DumpGantt(Dictionary nodeBufferSizes, Dictionary> nodeBufferLiveness, TileNode primTree, int storeLevel) { string GetStartStr(string name, int start) => $"[{name}] starts D+{start}"; string GetDurationStr(string name, int duration) => $"[{name}] requires {duration} days"; @@ -637,4 +727,8 @@ private void DumpGantt(Dictionary nodeBufferSizes, Dict writer.WriteLine("```"); } } + + public sealed record TiledFunc(PrimFunctionWrapper Func, long ObjectValue) + { + } } diff --git a/src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/IEnvironmentState.cs b/src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/IEnvironmentState.cs new file mode 100644 index 000000000..8690196e2 --- /dev/null +++ b/src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/IEnvironmentState.cs @@ -0,0 +1,19 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; + +namespace Nncase.Schedule.MonteCarloTreeSearch; + +public interface IEnvironmentState + where TAction : class +{ + int LegalActions(); + + TAction GetNextAction(int index); + + IEnvironmentState? PerformAction(TAction action); + + double RollOut(); +} diff --git a/src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/SearchNode.cs b/src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/SearchNode.cs new file mode 100644 index 000000000..6770704ef --- /dev/null +++ b/src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/SearchNode.cs @@ -0,0 +1,44 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; + +namespace Nncase.Schedule.MonteCarloTreeSearch; + +public abstract class SearchNode + where T : class +{ + public SearchNode(IEnvironmentState state) + { + Parent = null; + Children = new List>(); + VisitTimes = 0; + QualityValue = 0.0; + State = state; + } + + public SearchNode(SearchNode parent, IEnvironmentState state) + { + Parent = parent; + Children = new List>(); + VisitTimes = 0; + QualityValue = 0.0; + State = state; + Parent.Children.Add(this); + } + + public SearchNode? Parent { get; } + + public List> Children { get; } + + public int VisitTimes { get; set; } + + public double QualityValue { get; set; } + + public IEnvironmentState State { get; } + + public bool IsRootNode => Parent is null; + + public abstract void Update(double reward); +} diff --git a/src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/Searcher.cs b/src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/Searcher.cs new file mode 100644 index 000000000..382d0bd27 --- /dev/null +++ b/src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/Searcher.cs @@ -0,0 +1,47 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; + +namespace Nncase.Schedule.MonteCarloTreeSearch; + +public abstract class Searcher + where T : class +{ + public Searcher(int searchTimes = 20) + { + SearchTimes = searchTimes; + } + + public int SearchTimes { get; } + + public void Search(SearchNode rootNode) + { + for (int i = 0; i < SearchTimes; i++) + { + if (!Selection(rootNode, out var node)) + { + return; + } + + var expanded = Expand(node); + if (expanded is not null) + { + BackPropagate(expanded, Simulation(expanded)); + } + else + { + BackPropagate(node, double.PositiveInfinity); + } + } + } + + public abstract bool Selection(SearchNode node, out SearchNode selected); + + public abstract SearchNode? Expand(SearchNode node); + + public abstract double Simulation(SearchNode node); + + public abstract void BackPropagate(SearchNode node, double reward); +} diff --git a/src/Nncase.Schedule/Schedule/OrToolsExtensions.cs b/src/Nncase.Schedule/Schedule/OrToolsExtensions.cs index 43451a504..90aca73f1 100644 --- a/src/Nncase.Schedule/Schedule/OrToolsExtensions.cs +++ b/src/Nncase.Schedule/Schedule/OrToolsExtensions.cs @@ -52,7 +52,10 @@ public static long[] Value(this Assignment sol, IntExpr[] inputs) var vec = new long[inputs.Length]; for (int i = 0; i < inputs.Length; i++) { - vec[i] = sol.Value(inputs[i].Var()); + if (inputs[i] is not null) + { + vec[i] = sol.Value(inputs[i].Var()); + } } return vec; diff --git a/src/Nncase.Schedule/Schedule/TileGraph/BufferizationAlgorithm.cs b/src/Nncase.Schedule/Schedule/TileGraph/BufferizationAlgorithm.cs index 65ad4d9a7..6bdb599e3 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/BufferizationAlgorithm.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/BufferizationAlgorithm.cs @@ -63,9 +63,9 @@ private void Visit(TieredTileGraph rootGraph) { if (!BufferGraphMemo.TryGetValue(rootGraph, out _)) { - var wrappedGraph = new AdjacencyGraph>(); + var wrappedGraph = new AdjacencyGraph>(allowParallelEdges: false); var rootBufferGraph = new BufferGraph(rootGraph.Level, wrappedGraph); - Visit(rootGraph, rootBufferGraph); + Visit(rootGraph, rootBufferGraph, rootGraph); foreach (var edge in rootGraph.Edges) { var source = new BufferIdentity(edge.Source, edge.Source.ReadAccesses.Length); @@ -77,12 +77,14 @@ private void Visit(TieredTileGraph rootGraph) } } - private void Visit(TieredTileGraph graph, BufferGraph bufferGraph) + private HashSet Visit(TieredTileGraph graph, BufferGraph bufferGraph, TieredTileGraph rootGraph) { + var opnodes = new HashSet(); if (graph.ClustersCount == 0) { foreach (var item in graph.Vertices) { + opnodes.Add(item); var outBid = new BufferIdentity(item, item.ReadAccesses.Length); for (int i = 0; i < item.ReadAccesses.Length; i++) { @@ -97,10 +99,22 @@ private void Visit(TieredTileGraph graph, BufferGraph bufferGraph) if (!BufferGraphMemo.TryGetValue(graph, out _)) { var childBufferGraph = bufferGraph.CreateCluster(childGraph.Level, childGraph.OpId); - Visit(childGraph, childBufferGraph); + opnodes.UnionWith(Visit(childGraph, childBufferGraph, rootGraph)); BufferGraphMemo.Add(childGraph, childBufferGraph); } } + + foreach (var edge in rootGraph.Edges) + { + if (opnodes.Contains(edge.Source) && opnodes.Contains(edge.Target)) + { + var source = new BufferIdentity(edge.Source, edge.Source.ReadAccesses.Length); + var target = new BufferIdentity(edge.Target, edge.Tag); + bufferGraph.AddEdge(new(source, target, BufferEdgeKind.Outer)); + } + } } + + return opnodes; } } diff --git a/src/Nncase.Schedule/Schedule/TileGraph/GraphExtensions.cs b/src/Nncase.Schedule/Schedule/TileGraph/GraphExtensions.cs index 23cc3bc66..f27f9211c 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/GraphExtensions.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/GraphExtensions.cs @@ -116,6 +116,37 @@ public static bool Merge(this TieredTileGraph graph, MergePoint mergePoint) return merger.Visit(graph); } + public static List GetMergePoints(this TieredTileGraph graph) + { + var mergePoints = new List(); + if (graph.Level != -1) + { + throw new InvalidOperationException("only can merge at top level!"); + } + + var children = graph.Clusters.OfType().ToArray(); + foreach (var producer in children) + { + foreach (var comsumer in children) + { + if (ReferenceEquals(producer, comsumer)) + { + continue; + } + + foreach (var edge in graph.Edges) + { + if (comsumer.ContainsVertex(edge.Source) && producer.ContainsVertex(edge.Target)) + { + mergePoints.Add(new(edge.Target, edge.Source, producer.Level)); + } + } + } + } + + return mergePoints; + } + public static void Walk(this TieredTileGraph graph, Action func, bool postOrder = false) { if (!postOrder) diff --git a/src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs b/src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs new file mode 100644 index 000000000..856e743fa --- /dev/null +++ b/src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs @@ -0,0 +1,247 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Collections; +using System.Diagnostics.CodeAnalysis; +using System.Reactive; +using NetFabric.Hyperlinq; +using Nncase.IR; +using Nncase.IR.Affine; +using Nncase.Schedule.MonteCarloTreeSearch; +using QuikGraph; +using QuikGraph.Algorithms; +using QuikGraph.Algorithms.ShortestPath; +using QuikGraph.Graphviz; + +namespace Nncase.Schedule.TileGraph; + +public sealed class MCTState : IEnvironmentState +{ + private readonly string _path = string.Empty; + + private readonly List _mergePoints = new(); + + private readonly List _legalIndex = new(); + + private readonly Dictionary _solveMemo; + + private readonly string _moduleKind; + + private readonly string _prefix; + + private readonly ICpuTargetOptions _targetOptions; + + private readonly TieredTileGraph _graph; + + private int _permformCount; + + public MCTState(TieredTileGraph graph, string moduleKind, string prefix, string searchPath, Dictionary solveMemo, ICpuTargetOptions targetOptions) + { + _graph = graph; + _moduleKind = moduleKind; + _prefix = prefix; + _solveMemo = solveMemo; + _targetOptions = targetOptions; + _mergePoints.AddRange(graph.GetMergePoints()); + _legalIndex.AddRange(Enumerable.Range(0, _mergePoints.Count)); + _path = searchPath; + Results = new(new LeafTileGraphComparer()); + } + + public long ObjectValue { get; private set; } + + public Dictionary Results { get; } + + public MergePoint GetNextAction(int index) + { + var legalIndex = _legalIndex[index]; + _legalIndex.Remove(legalIndex); + _permformCount++; + return _mergePoints[legalIndex]; + } + + public int LegalActions() + { + return _legalIndex.Count; + } + + public IEnvironmentState? PerformAction(MergePoint mergePoint) + { + var newGraph = _graph.Clone(); + if (newGraph.Merge(mergePoint)) + { + return new MCTState(newGraph, _moduleKind, _prefix, $"{_path}.{_permformCount}", _solveMemo, _targetOptions); + } + + return null; + } + + public double RollOut() + { + if (ObjectValue == 0) + { + using var scope = new Diagnostics.DumpScope($"RollOut{_path}"); + try + { + var res = GraphTiler.SolveRootGraph(_graph, _moduleKind, _prefix, _solveMemo, _targetOptions); + ObjectValue = res.ObjectValue; + foreach (var item in res.ResultMemo) + { + Results.Add(item.Key, item.Value); + } + } + catch (System.Exception) + { + ObjectValue = long.MaxValue; + return ObjectValue; + } + } + + return ObjectValue; + } + + private sealed class LeafTileGraphComparer : IEqualityComparer + { + public bool Equals(TieredTileGraph? x, TieredTileGraph? y) => (x, y) switch + { + (null, null) => true, + (TieredTileGraph a, TieredTileGraph b) => a.OpId.Equals(b.OpId) && a.Level.Equals(b.Level), + _ => false, + }; + + public int GetHashCode([DisallowNull] TieredTileGraph obj) => HashCode.Combine(obj.OpId, obj.Level); + } +} + +public sealed class MCTNode : SearchNode +{ + public MCTNode(IEnvironmentState state) + : base(state) + { + Action = null; + QualityValue = double.PositiveInfinity; + } + + public MCTNode(SearchNode parent, IEnvironmentState state, MergePoint action) + : base(parent, state) + { + QualityValue = double.PositiveInfinity; + Action = action; + } + + public MergePoint? Action { get; } + + public override void Update(double reward) + { + if (QualityValue > reward) + { + QualityValue = reward; + } + + VisitTimes += 1; + + if (Parent is not null) + { + Parent.Update(reward); + } + } + + public void Dump(string name) + { + using (var file = Diagnostics.DumpScope.Current.OpenFile($"{name}.yaml")) + { + using var baseWriter = new StreamWriter(file); + using var writer = new System.CodeDom.Compiler.IndentedTextWriter(baseWriter, " "); + Dump(writer); + } + } + + public void Dump(System.CodeDom.Compiler.IndentedTextWriter writer) + { + writer.WriteLine($"- name: {this}"); + writer.WriteLine($" Action: {Action}"); + writer.WriteLine($" QualityValue: {QualityValue}"); + writer.WriteLine($" VisitTimes: {VisitTimes}"); + writer.WriteLine($" Children:"); + writer.Indent += 1; + foreach (var item in Children.OfType()) + { + item.Dump(writer); + } + + writer.Indent -= 1; + } +} + +public sealed class MCTSearcher : Searcher +{ + private readonly Random _random = new Random(1010); + + public MCTSearcher() + { + BestObjectValue = double.PositiveInfinity; + BestMCTNode = null; + } + + public double BestObjectValue { get; private set; } + + public MCTNode? BestMCTNode { get; private set; } + + public SearchNode UCBSelectChild(SearchNode node) + { + double coef = Math.Sqrt(2); + double temp = 0.5; + var ucbs = node.Children.Select(c => (-c.QualityValue / BestObjectValue) + (coef * Math.Sqrt(Math.Log(node.VisitTimes) / c.VisitTimes))).ToArray(); + var ucbs_exp = ucbs.Select(ucb => Math.Exp(ucb / temp)).ToArray(); + var sum = ucbs_exp.Sum(); + var probs = ucbs_exp.Select(e => (int)(e / sum * 30)).ToArray(); // conver ucb as prob + var candidates = probs.Select((p, i) => Enumerable.Repeat(i, p).ToArray()).SelectMany(i => i).ToArray(); + return node.Children[candidates[_random.Next(candidates.Length)]]; + } + + public override bool Selection(SearchNode node, out SearchNode selected) + { + while (node.State.LegalActions() == 0 && node.Children.Count > 0) + { + node = UCBSelectChild(node); + } + + selected = node; + return true; + } + + public override SearchNode? Expand(SearchNode node) + { + if (node.VisitTimes != 0 && node.State.LegalActions() > 0) + { + var index = _random.Next(node.State.LegalActions()); + var action = node.State.GetNextAction(index); + var state = node.State.PerformAction(action); + if (state is not null) + { + return new MCTNode(node, state, action); + } + + return null; + } + + return node; + } + + public override double Simulation(SearchNode node) + { + double value = node.State.RollOut(); + if (value < BestObjectValue) + { + BestObjectValue = value; + BestMCTNode = (MCTNode)node; + } + + return value; + } + + public override void BackPropagate(SearchNode node, double reward) + { + node.Update(reward); + } +} diff --git a/src/Nncase.Schedule/Schedule/TileGraph/GraphSolverTypes.cs b/src/Nncase.Schedule/Schedule/TileGraph/GraphSolverTypes.cs index 0edde9d30..3b9f81776 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/GraphSolverTypes.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/GraphSolverTypes.cs @@ -26,6 +26,7 @@ public sealed record BufferIdentity(TileGrid Node, int Index) /// /// Places[create loop][store level]: /// create loop in [0, domain rank] , 0 means out all, 1 means out loop0, domain rank means in loopN. +/// note only the nodes which store at top level have valid Places[0], else the Places[0] is empty. /// store level in [0, create level == top level ? create level : top level - 1), 0 means level 1, 1 means level 2. /// the buffer shape according to the placement. /// the buffer size according to the placement. @@ -47,33 +48,19 @@ public sealed record TileNodeBufferInfo(Tuple Liveness, AffineMap M /// buffer info memo. public sealed record TileNodeInfo(T[] TripCounts, T[][] BackWardExtents, Dictionary DefUseMap, Dictionary> BufferInfoMap) { - public BufferIdentity GetCacheBid(BufferIdentity bid) + public BufferIdentity GetByChildBuffer(BufferIdentity cbid) { - if (DefUseMap.TryGetValue(bid, out var sinkId)) + if (DefUseMap.Values.Contains(cbid)) { - return sinkId; + return DefUseMap.Where(kv => kv.Value == cbid).First().Key; } - if (!BufferInfoMap.ContainsKey(bid)) + if (!BufferInfoMap.ContainsKey(cbid)) { - throw new KeyNotFoundException(bid.ToString()); + throw new KeyNotFoundException(cbid.ToString()); } - return bid; - } - - public bool TryGetBufferInfo(BufferIdentity bid, [MaybeNullWhen(false)] out TileNodeBufferInfo info) - { - if (DefUseMap.TryGetValue(bid, out var sinkId)) - { - BufferInfoMap.TryGetValue(sinkId, out info); - } - else - { - BufferInfoMap.TryGetValue(bid, out info); - } - - return info is not null; + return cbid; } } diff --git a/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs b/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs index 5eb7c686b..8d7bcabeb 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs @@ -2,11 +2,12 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System.Reactive; -using Google.OrTools.ConstraintSolver; +using Google.OrTools.Sat; using Nncase.IR; using Nncase.IR.Affine; using Nncase.TIR; using Nncase.TIR.Builders; +using static Nncase.TIR.TIRExtensions; namespace Nncase.Schedule.TileGraph; @@ -101,7 +102,7 @@ public Unit Visit(TileNode value, Context context) if (!(value.Level == PrimBufferGraph.Level && i == 0 && sl == (PrimBufferGraph.Level - 1)) && place[sl] == 1) { var kernelInfo = bid.Node.GetKernelInfo(TargetOptions); - var viewInfo = GetParentSubViewInfo(sl + 1, value, bid, bufferInfo.Map, forwardOffsets[i], bufferInfo.Shapes[i]); + var viewInfo = GetParentSubViewInfo(sl, value, bid, bufferInfo.Map, forwardOffsets[i], bufferInfo.Shapes[i]); Expr subView; if (viewInfo.InnerAllocated) { @@ -230,65 +231,65 @@ public void ScheduleBuffers() foreach (var (level, nodeBufferSizes) in LevelBufferSizes) { var nodeBufferOffsets = LevelBufferOffsets[level] = new(); - var solver = new Solver("buffer scheduler"); - var xstarts = new List(); - var xsizes = new List(); - var ystarts = new List(); - var ysizes = new List(); - var validKeys = new List(); - + var model = new CpModel(); + var rectangles = new Dictionary(); + int count = 0; + var cons = model.AddNoOverlap2D(); foreach (var (key, size) in nodeBufferSizes) { if (size > 0) { - xstarts.Add(solver.MakeIntConst(LevelBufferLifeness[level][key].Item1)); - xsizes.Add(LevelBufferLifeness[level][key].Item2 - LevelBufferLifeness[level][key].Item1); - var ystart = solver.MakeIntVar(0, TargetOptions.MemoryCapacities[level] - size); - ystarts.Add(ystart); + var x = model.NewFixedSizeIntervalVar(LevelBufferLifeness[level][key].Item1, LevelBufferLifeness[level][key].Item2 - LevelBufferLifeness[level][key].Item1, $"x{count}"); + var ystart = model.NewIntVar(0, TargetOptions.MemoryCapacities[level] - size, $"ystart{count}"); if (ModuleKind == "xpu") { - solver.Add(solver.MakeEquality(solver.MakeIntConst(0), solver.MakeModulo(ystart, 128))); + model.AddModuloEquality(0, ystart, 128); } - ysizes.Add(size); - validKeys.Add(key); + var y = model.NewFixedSizeIntervalVar(ystart, size, $"y{count}"); + cons.AddRectangle(x, y); + rectangles.Add(key, (x, y)); + count++; } } - solver.Add(solver.MakeNonOverlappingBoxesConstraint(xstarts.ToArray(), ystarts.ToArray(), xsizes.ToArray(), ysizes.ToArray())); - var collector = solver.MakeFirstSolutionCollector(); - foreach (var item in ystarts) +#if false + // process inplace buffer. + foreach (var (k, (x, y)) in rectangles) { - collector.Add(item); - } + var inplaceMemo = k.Id.Node.Op.GetInPlaceMemo(); + if (!inplaceMemo.TryGetValue(k.Id.Index, out var sourceIndex)) + { + continue; + } - var defaultPhaseParameters = new DefaultPhaseParameters(); - if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) - { - defaultPhaseParameters.display_level = DefaultPhaseParameters.NORMAL; - } - else - { - defaultPhaseParameters.display_level = DefaultPhaseParameters.NONE; - } + // 1. when source buffer is isolated. we can find it in rectangles. + foreach (var sourceKey in rectangles.Keys.Where(n => ReferenceEquals(n.Node, k.Node) && n.Id.Index == sourceIndex)) + { + model.Add(rectangles[sourceKey].YInterval.StartExpr() == y.StartExpr()); + } - var decisionBuilder = solver.MakeDefaultPhase(ystarts.ToArray(), defaultPhaseParameters); - var monitors = new List() { collector, solver.MakeSolutionsLimit(1), }; - if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) - { - monitors.Add(solver.MakeSearchLog(10000)); + // 2. source buffer has been reused. we need find it in defuseMap firstly. + foreach (var (defId, _) in TileNodeMemo[k.Node].DefUseMap.Where(kv => ReferenceEquals(kv.Value.Node, k.Id.Node) && kv.Value.Index == sourceIndex)) + { + foreach (var defKey in rectangles.Keys.Where(n => ReferenceEquals(n.Node, k.Node) && n.Id == defId)) + { + model.Add(rectangles[defKey].YInterval.StartExpr() == y.StartExpr()); + } + } } +#endif - var status = solver.Solve(decisionBuilder, monitors.ToArray()); - if (!status) + var solver = new CpSolver(); + var status = solver.Solve(model); + if (status is not CpSolverStatus.Optimal) { throw new InvalidOperationException("can't schedule buffers!"); } - var sol = collector.Solution(0); - for (int i = 0; i < ystarts.Count; i++) + foreach (var (k, (_, y)) in rectangles) { - nodeBufferOffsets[validKeys[i]] = (ulong)sol.Value(ystarts[i]); + nodeBufferOffsets[k] = (ulong)solver.Value(y.StartExpr()); } } } @@ -328,9 +329,9 @@ private DistributedType GetBufferDistributedType(Expr expr) } /// - /// declare the input/output buffer. + /// get declare of the input/output buffer which was stored on top level. /// - private TIR.Buffer GetParentDeclareBuffer(int storeLevel, ITileable node, BufferIdentity bid) + private TIR.Buffer GetTopLevelDeclareBuffer(BufferIdentity bid) { var expr = bid.Node.Grid.Buffers[bid.Index]; var tensorType = GetBufferTensorType(expr); @@ -361,7 +362,7 @@ private bool TryGetParerntBuffer(ITreeNode node, BufferIdentity bid, out Expr pa var parentNode = node.Parent; while (parentNode is TileNode parentTileNode && parentTileNode.OpId != -1) { - var pbid = TileNodeMemo[parentTileNode].GetCacheBid(cbid); + var pbid = TileNodeMemo[parentTileNode].GetByChildBuffer(cbid); if (_subViewMemo.TryGetValue(parentTileNode, out var subViewMap) && subViewMap.TryGetValue(pbid, out var subViewInfo)) { parentBuffer = subViewInfo.Buffer; @@ -401,15 +402,15 @@ private ParentSubViewInfo GetParentSubViewInfo(int storeLevel, ITreeNode node, B var (outputs, inputs) = PrimBufferGraph.GetInputsOutputs(); if (outputs.Contains(bid)) { - parentBuffer = GetParentDeclareBuffer(storeLevel, node, bid); + parentBuffer = GetTopLevelDeclareBuffer(bid); } else if (inputs.Contains(bid)) { - parentBuffer = GetParentDeclareBuffer(storeLevel, node, bid); + parentBuffer = GetTopLevelDeclareBuffer(bid); } else if (node is TileNode tileNode) { - parentBuffer = GetParentAllocateBuffer(storeLevel, tileNode, bid, shape, out innerAllocated); + parentBuffer = GetInnerAllocateBuffer(storeLevel, tileNode, bid, shape, out innerAllocated); } } @@ -417,9 +418,9 @@ private ParentSubViewInfo GetParentSubViewInfo(int storeLevel, ITreeNode node, B } /// - /// Get the local allocate buffer. + /// Allocate a buffer which store at inner level. /// - private TIR.Buffer GetParentAllocateBuffer(int storeLevel, TileNode node, BufferIdentity bid, int[] shape, out bool innerAllocated) + private TIR.Buffer GetInnerAllocateBuffer(int storeLevel, TileNode node, BufferIdentity bid, int[] shape, out bool innerAllocated) { var expr = bid.Node.Grid.Buffers[bid.Index]; var tensorType = GetBufferTensorType(expr); diff --git a/src/Nncase.Schedule/Schedule/TileGraph/TileTreeTypes.cs b/src/Nncase.Schedule/Schedule/TileGraph/TileTreeTypes.cs index 53ba7d6b5..ffd6deeaa 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/TileTreeTypes.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/TileTreeTypes.cs @@ -12,6 +12,7 @@ using Nncase.IR; using Nncase.IR.Affine; using QuikGraph; +using QuikGraph.Algorithms; using QuikGraph.Collections; namespace Nncase.Schedule.TileGraph; @@ -105,7 +106,7 @@ private TileNode(ITreeNode? parent, TieredTileGraph wrapped, int childCount) public static TileNode FromTileGraph(TieredTileGraph rootGraph, out Dictionary memo) { memo = new(); - return ConvertToTree(null, rootGraph, memo); + return ConvertToTree(null, rootGraph, rootGraph, memo); } TReturn ITreeNode.Accept(ITreeNodeVisitor visitor, TArg1 arg1) => visitor.Visit(this, arg1); @@ -115,26 +116,66 @@ public override string ToString() return _wrapped.ToString(); } - private static TileNode ConvertToTree(ITreeNode? parent, TieredTileGraph tileGraph, Dictionary memo) + private static TileNode ConvertToTree(ITreeNode? parent, TieredTileGraph tileGraph, TieredTileGraph rootGraph, Dictionary memo) { if (!memo.TryGetValue(tileGraph, out var tnode)) { if (tileGraph.ClustersCount == 0) { + // sort + var tempGraph = new AdjacencyGraph>(allowParallelEdges: false); + var childVertices = tileGraph.Vertices.ToArray(); + tempGraph.AddVertexRange(childVertices); + foreach (var edge in rootGraph.Edges) + { + var producers = childVertices.Where(c => c.Equals(edge.Source)).ToArray(); + var consumers = childVertices.Where(c => c.Equals(edge.Target)).ToArray(); + foreach (var producer in producers) + { + foreach (var consumer in consumers) + { + if (!ReferenceEquals(producer, consumer)) + { + tempGraph.AddEdge(new(producer, consumer)); + } + } + } + } + tnode = new TileNode(parent, tileGraph, tileGraph.VertexCount); int count = 0; - foreach (var item in tileGraph.Vertices) + foreach (var item in tempGraph.TopologicalSort()) { tnode._children[count++] = new OpNode(tnode, item); } } else { + // sort child clusters + var tempGraph = new AdjacencyGraph>(allowParallelEdges: false); + var childClusters = tileGraph.Clusters.OfType().ToArray(); + tempGraph.AddVertexRange(childClusters); + foreach (var edge in rootGraph.Edges) + { + var producers = childClusters.Where(c => c.ContainsVertex(edge.Source)).ToArray(); + var consumers = childClusters.Where(c => c.ContainsVertex(edge.Target)).ToArray(); + foreach (var producer in producers) + { + foreach (var consumer in consumers) + { + if (!ReferenceEquals(producer, consumer)) + { + tempGraph.AddEdge(new(producer, consumer)); + } + } + } + } + tnode = new TileNode(parent, tileGraph, tileGraph.ClustersCount); int count = 0; - foreach (var item in tileGraph.Clusters.OfType()) + foreach (var item in tempGraph.TopologicalSort()) { - tnode._children[count++] = ConvertToTree(tnode, item, memo); + tnode._children[count++] = ConvertToTree(tnode, item, rootGraph, memo); } } diff --git a/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverInitializer.cs b/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverInitializer.cs index 2cfddd03f..6cf1febf1 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverInitializer.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverInitializer.cs @@ -6,63 +6,35 @@ using Nncase.IR.Affine; using QuikGraph; using QuikGraph.Graphviz; +using static Nncase.TIR.TIRExtensions; namespace Nncase.Schedule.TileGraph; public sealed class TreeSolverInitializer : TreeSolverBase, ITreeNodeVisitor { - public TreeSolverInitializer(int topLevel, Solver solver, Dictionary> primitiveBufferInfo, Dictionary> levelBufferInfos, Dictionary> domainDimInfos, ICpuTargetOptions targetOptions) + public TreeSolverInitializer(Dictionary bufferGraphMemo, int topLevel, Solver solver, Dictionary> primitiveBufferInfo, Dictionary> levelBufferInfos, Dictionary> domainDimInfos, ICpuTargetOptions targetOptions) : base(solver, primitiveBufferInfo, levelBufferInfos, domainDimInfos, targetOptions) { + BufferGraphMemo = bufferGraphMemo; TopLevel = topLevel; } public int TimeStamp { get; private set; } + public IReadOnlyDictionary BufferGraphMemo { get; } + public int TopLevel { get; } - public static void Init(TileNode tree, int topLevel, ICpuTargetOptions options, out Solver solver, out Dictionary> opNodeMemo, out Dictionary> tileNodeMemo, out Dictionary> tileableNodeMemo) + public static void Init(TileNode tree, Dictionary bufferGraphMemo, int topLevel, ICpuTargetOptions options, out Solver solver, out Dictionary> opNodeMemo, out Dictionary> tileNodeMemo, out Dictionary> tileableNodeMemo) { solver = new Solver("GraphSolver"); opNodeMemo = new Dictionary>(); tileNodeMemo = new Dictionary>(); tileableNodeMemo = new Dictionary>(); - var initializer = new TreeSolverInitializer(topLevel, solver, opNodeMemo, tileNodeMemo, tileableNodeMemo, options); + var initializer = new TreeSolverInitializer(bufferGraphMemo, topLevel, solver, opNodeMemo, tileNodeMemo, tileableNodeMemo, options); initializer.Visit(tree, Context.Default); } - /// - /// source id => sink id. - /// - public static Dictionary GetBufferDefUseMap(TileNode tilenode, BufferResult[] bufferResults) - { - while (tilenode.Parent is TileNode parent) - { - tilenode = parent; - } - - var map = new Dictionary(); - for (int i = 0; i < bufferResults.Length; i++) - { - var sourceId = bufferResults[i].Bid; - if (tilenode.Wrapped.TryGetOutEdges(sourceId.Node, out var outEdges)) - { - foreach (var outEdge in outEdges) - { - foreach (var target in bufferResults.Where(r => r.Bid.Node == outEdge.Target && r.Bid.Index == outEdge.Tag)) - { - if (!map.ContainsKey(sourceId)) - { - map.Add(sourceId, target.Bid); - } - } - } - } - } - - return map; - } - public InitResult Visit(TileNode value, Context context) { var (pid, pvars, ptrips) = context; @@ -124,34 +96,37 @@ public InitResult Visit(TileNode value, Context context) var backWardExtents = GetBackWardExtents(tileVars, childResult.DimsMaps, childResult.BackWardExtents); - var defUseMap = GetBufferDefUseMap(value, childResult.BufferResults); + // {source id : target id} + var defUseMap = BufferGraphMemo[value.Wrapped].Edges.Where(e => e.Tag == BufferEdgeKind.Outer).ToDictionary(e => e.Source, e => e.Target); var bufferResults = new List(); // each tile node have buffer place vars. if (!TileNodeMemo.TryGetValue(value, out var info)) { var bufferInfoMap = new Dictionary>(); + var reusedIds = new HashSet(childResult.BufferResults.Where(r => defUseMap.ContainsKey(r.Bid)).Select(r => defUseMap[r.Bid])); for (int i = 0; i < childResult.BufferResults.Length; i++) { var result = childResult.BufferResults[i]; - BufferIdentity currentId; + var curId = result.Bid; + if (reusedIds.Contains(curId)) + { + continue; + } + AffineMap currentAccessMap = result.AccessMap; Tuple currentLifeness = result.Lifeness; - if (defUseMap.TryGetValue(result.Bid, out currentId!)) + if (defUseMap.TryGetValue(curId, out var sinkId)) { - var sinkIndex = Array.FindIndex(childResult.BufferResults, r => r.Bid == currentId); + var sinkIndex = Array.FindIndex(childResult.BufferResults, r => r.Bid == sinkId); currentAccessMap = childResult.BufferResults[sinkIndex].AccessMap; currentLifeness = new(Math.Min(result.Lifeness.Item1, childResult.BufferResults[sinkIndex].Lifeness.Item1), Math.Max(result.Lifeness.Item2, childResult.BufferResults[sinkIndex].Lifeness.Item2)); } - else - { - currentId = result.Bid; - } - if (!bufferInfoMap.TryGetValue(currentId, out var bufferInfo)) + if (!bufferInfoMap.TryGetValue(curId, out var bufferInfo)) { - bufferInfoMap.Add(currentId, GetBufferInfo(value, currentId, currentAccessMap, currentLifeness, forwardExtents, backWardExtents)); - bufferResults.Add(new(currentId, currentLifeness, value.DomainRelation.Map * currentAccessMap)); + bufferInfoMap.Add(curId, GetBufferInfo(value, curId, currentAccessMap, currentLifeness, forwardExtents, backWardExtents)); + bufferResults.Add(new(curId, currentLifeness, value.DomainRelation.Map * currentAccessMap)); } } @@ -285,20 +260,20 @@ bool ProductExtent(IntExpr[] extents, int i) private TileNodeBufferInfo GetBufferInfo(TileNode tileNode, BufferIdentity bid, AffineMap accessMap, Tuple lifeness, IntExpr[] forwardExtents, IntExpr[][] backWardExtents) { var rank = tileNode.DomainRelation.Map.Results.Length + 1; - var bufferPlaces = new IntExpr[rank][]; - var bufferShapes = new IntExpr[rank][]; + var bufferPlaces = Enumerable.Range(0, rank).Select(i => Array.Empty()).ToArray(); + var bufferShapes = Enumerable.Range(0, rank).Select(i => Array.Empty()).ToArray(); var bufferSizes = new IntExpr[rank]; var bufferSizeVars = new IntExpr[rank]; var bufferTrips = new IntExpr[rank]; var bufferMasks = new LoopMask[rank]; var resultStr = accessMap.ToString().Split("->")[1]; - for (int i = 0; i < rank; i++) + for (int i = tileNode.Level == TopLevel ? 0 : 1; i < rank; i++) { var subLevelPlace = bufferPlaces[i] = new IntVar[tileNode.Level]; for (int sl = 0; sl < subLevelPlace.Length; sl++) { - subLevelPlace[sl] = Solver.MakeBoolVar($"p[cl{tileNode.Level}, op{tileNode.OpId}, b{bid.Index}, ci{i}, sl{sl}]"); + subLevelPlace[sl] = Solver.MakeBoolVar($"p[cl{tileNode.Level}, op{bid.Node.OpId}, b{bid.Index}, ci{i}, sl{sl}]"); } var subDomainShapes = bufferShapes[i] = new IntExpr[accessMap.Results.Length]; @@ -309,7 +284,7 @@ private TileNodeBufferInfo GetBufferInfo(TileNode tileNode, BufferIdent } bufferSizes[i] = subDomainShapes.Aggregate((IntExpr)Solver.MakeIntConst(bid.Node.Grid.Buffers[bid.Index].CheckedDataType.SizeInBytes), Solver.MakeProd); - bufferSizeVars[i] = Solver.MakeIntVar(1, int.MaxValue, $"size[cl{tileNode.Level}, op{tileNode.OpId}, b{bid.Index}, ci{i}]"); + bufferSizeVars[i] = Solver.MakeIntVar(1, int.MaxValue, $"size[cl{tileNode.Level}, op{bid.Node.OpId}, b{bid.Index}, ci{i}]"); Solver.Add(Solver.MakeEquality(bufferSizeVars[i], bufferSizes[i])); var mask = 0U; diff --git a/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPrinter.cs b/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPrinter.cs index 82a772804..621b708a5 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPrinter.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPrinter.cs @@ -34,6 +34,11 @@ public static void WriteIntExprVector(IndentedTextWriter writer, string prefix, writer.Indent++; for (int i = 0; i < intExprs.Length; i++) { + if (intExprs[i] is null) + { + continue; + } + string value = string.Empty; if (solution is Assignment assignment && intExprs[i] is IntExpr expr) { @@ -77,6 +82,7 @@ public Unit Visit(TileNode value, IndentedTextWriter writer) writer.WriteLine($"{bid}:"); { writer.Indent++; + WriteIntExprMatrix(writer, "Places", info.Places, Solution); WriteIntExprMatrix(writer, "Shapes", info.Shapes, Solution); WriteIntExprVector(writer, "SizeVars", info.SizeVars, Solution); WriteIntExprVector(writer, "SizeExprs", info.SizeExprs, Solution); diff --git a/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPythonPrinter.cs b/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPythonPrinter.cs index 44075107d..e629589c1 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPythonPrinter.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPythonPrinter.cs @@ -45,7 +45,20 @@ public Unit Visit(TileNode value, (ITreeNode? Parent, IndentedTextWriter Writer) } else if (parent is TileNode parentTile) { - parentBounds = (int)_bounds[parentTile][i]; + if (i < _bounds[parentTile].Count) + { + parentBounds = (int)_bounds[parentTile][i]; + } + else + { + value.Walk(child => + { + if (child is OpNode opNode && opNode.OpId == value.OpId) + { + parentBounds = opNode.DomainBounds[i]; + } + }); + } } var tile = Solution.Value(domainInfo.TileVars[i].Var()); @@ -97,7 +110,7 @@ public Unit Visit(OpNode value, (ITreeNode? Parent, IndentedTextWriter Writer) c { var (parent, writer) = context; var opinfo = OpNodeMemo[value]; - var shapes = string.Join(", ", opinfo.Shapes.Select((sp, i) => $"buf{i}[" + string.Join(',', sp.Select(s => Solution.Value(s.Var()))) + "]")); + var shapes = string.Join(", ", opinfo.Shapes.Select((sp, i) => $"{new BufferIdentity(value.Wrapped, i)}[" + string.Join(',', sp.Select(s => Solution.Value(s.Var()))) + "]")); var size = string.Join(", ", opinfo.Sizes.Select(s => Solution.Value(s.Var()))); writer.WriteLine($"{value.Op.GetType()}({value.Op.DisplayProperty()}, {shapes}) # size: {size}"); return default; diff --git a/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverWritesInitializer.cs b/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverWritesInitializer.cs index 2c438a3c1..536742996 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverWritesInitializer.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverWritesInitializer.cs @@ -38,7 +38,7 @@ public Unit Visit(TileNode value, Dictionary bufferTrip // 1. child domain map to parent domain. foreach (var (bid, bufferInfo) in TileNodeMemo[value].BufferInfoMap) { - var parentTripCounts = partentTileInfo is null ? Solver.MakeIntConst(1) : bufferTripCounts[partentTileInfo.GetCacheBid(bid)]; + var parentTripCounts = partentTileInfo is null ? Solver.MakeIntConst(1) : bufferTripCounts[partentTileInfo.GetByChildBuffer(bid)]; for (int i = 0; i < domainInfo.TileVars.Length + 1; i++) { diff --git a/src/Nncase.Schedule/Transforms/AutoTilePass.cs b/src/Nncase.Schedule/Transforms/AutoTilePass.cs index 120667a65..da7cbc230 100644 --- a/src/Nncase.Schedule/Transforms/AutoTilePass.cs +++ b/src/Nncase.Schedule/Transforms/AutoTilePass.cs @@ -34,7 +34,7 @@ public AutoTilePass(string moduleKind, CompileOptions compileOptions) protected override Task RunCoreAsync(IRModule input, RunPassContext context) { - var tiler = new GraphTiler(); + var memo = new Dictionary(); var funcNums = input.Functions.Count; for (int i = 0; i < funcNums; i++) { @@ -224,6 +224,11 @@ protected override Expr OnComplexCluster(ClusteredBidirectionalGraph { diff --git a/src/Nncase.Tests/Affine/FunctionSamples.cs b/src/Nncase.Tests/Affine/FunctionSamples.cs index 2fa586185..0d8ff6f3d 100644 --- a/src/Nncase.Tests/Affine/FunctionSamples.cs +++ b/src/Nncase.Tests/Affine/FunctionSamples.cs @@ -36,6 +36,31 @@ public static Function Get1() return func; } + public static Function Get1Matmul() + { + Function func; + { + var a = new Var(new TensorType(DataTypes.Float32, new[] { 128, 256 })); + var b = new Var(new TensorType(DataTypes.Float32, new[] { 256, 384 })); + var c = IR.F.Tensors.MatMul(a, b); + func = new(c, a, b); + } + + return func; + } + + public static Function Get1Exp() + { + Function func; + { + var a = new Var(new TensorType(DataTypes.Float32, new[] { 128, 384 })); + var d = IR.F.Math.Exp(a); + func = new(d, a); + } + + return func; + } + /// /// Tileflow default case with pack M. /// @@ -127,4 +152,21 @@ public static Function Get5() return func; } + + /// + /// get single op for mcts. + /// + public static Function Get6() + { + Function func; + { + var shape = new[] { 1, 12, 14, 14 }; + var a = new IR.Var("a", new IR.TensorType(DataTypes.Float32, shape)); + var b = new IR.Var("b", new IR.TensorType(DataTypes.Float32, shape)); + var c = IR.F.Math.Binary(BinaryOp.Mul, a, b); + func = new IR.Function("main", c, a, b); + } + + return func; + } } diff --git a/src/Nncase.Tests/Affine/UnitTestModeling.cs b/src/Nncase.Tests/Affine/UnitTestModeling.cs index 12fdcdeb6..0098690c5 100644 --- a/src/Nncase.Tests/Affine/UnitTestModeling.cs +++ b/src/Nncase.Tests/Affine/UnitTestModeling.cs @@ -543,6 +543,23 @@ public void TestSolveNoOverlapping() System.Console.WriteLine(sol.Value(aplace)); System.Console.WriteLine(sol.Value(cplace)); } + + [Fact] + public void TestSolveZeroLiveness() + { + // note checked [0,3] is not overlapping with [3,3] + var model = new Google.OrTools.Sat.CpModel(); + var cons = model.AddNoOverlap2D(); + var x1 = model.NewFixedSizeIntervalVar(0, 3, "x"); + var y1 = model.NewFixedSizeIntervalVar(model.NewIntVar(0, 16384, "ystart"), 49152, "y"); + var x2 = model.NewFixedSizeIntervalVar(3, 0, "x2"); + var y2 = model.NewFixedSizeIntervalVar(model.NewIntVar(0, 16384, "ystart2"), 49152, "y2"); + cons.AddRectangle(x1, y1); + cons.AddRectangle(x2, y2); + var solver = new Google.OrTools.Sat.CpSolver(); + var status = solver.Solve(model); + Assert.Equal(Google.OrTools.Sat.CpSolverStatus.Optimal, status); + } } internal sealed record VTensor(string Name, char[] Dims) diff --git a/src/Nncase.Tests/Affine/UnitTestTileGraph.cs b/src/Nncase.Tests/Affine/UnitTestTileGraph.cs index 5188b2257..0050ad490 100644 --- a/src/Nncase.Tests/Affine/UnitTestTileGraph.cs +++ b/src/Nncase.Tests/Affine/UnitTestTileGraph.cs @@ -33,14 +33,37 @@ public sealed class UnitTestTileGraph : TestClassBase { FunctionSamples.Get1PackMN, new (IntMergePoint, bool)[] { (new(2, 0, 2), true), (new(2, 1, 2), true), (new(2, 0, 1), true), (new(2, 1, 1), true), (new(3, 2, 2), true), (new(5, 4, 2), true) }, MergeTileGraphChecker2, 2 }, }; - public static readonly TheoryData, Action, int> SolveTileGraphDatas = new() + public static readonly TheoryData, IntMergePoint[], Action, int> SolveTileGraphDatas = new() { - { FunctionSamples.Get5, SolveTileGraphChecker0, 3 }, + { FunctionSamples.Get5, [], SolveTileGraphChecker0, 0 }, + { FunctionSamples.Get1, [new(1, 0, 2)], (_) => { }, 1 }, + { FunctionSamples.Get1, [new(2, 1, 2)], (_) => { }, 2 }, + { FunctionSamples.Get1, [new(1, 0, 2), new(2, 1, 2)], (_) => { }, 3 }, + { FunctionSamples.Get4, [new(2, 0, 2)], (_) => { }, 4 }, + + // just for check single op tiling results + { FunctionSamples.Get1Matmul, [], (_) => { }, 5 }, + { FunctionSamples.Get1Exp, [], (_) => { }, 6 }, + }; + + public static readonly TheoryData, int> MCTSDatas = new() + { + { FunctionSamples.Get1, 0 }, + { FunctionSamples.Get4, 1 }, + { FunctionSamples.Get6, 2 }, + }; + + public static readonly TheoryData, IntMergePoint[], Action, int> BufferizeTileGraphDatas = new() + { + { FunctionSamples.Get1, [new(1, 0, 2)], (bufGraph) => { Assert.Equal(4, bufGraph.Clusters.OfType().First().Edges.Count()); }, 0 }, }; public UnitTestTileGraph() { - CompileOptions.TargetOptions = new Nncase.Targets.CpuTargetOptions(); + CompileOptions.TargetOptions = new Targets.CpuTargetOptions(); +#if DEBUG + CompileOptions.DumpFlags = Diagnostics.DumpFlags.Tiling; +#endif } [Fact] @@ -285,12 +308,12 @@ public void TestMergeTileGraph(Func functor, (IntMergePoint, bool)[] m { var (point, excepted) = mergePoints[i]; Assert.Equal(excepted, tileGraph.Merge(new(tileGraph.Vertices.Skip(point.Consumer).First(), tileGraph.Vertices.Skip(point.Producer).First(), point.Level))); -#if DEBUG if (excepted) { +#if DEBUG tileGraph.Dump($"g{count}_m{i}"); - } #endif + } } checker(tileGraph); @@ -298,10 +321,25 @@ public void TestMergeTileGraph(Func functor, (IntMergePoint, bool)[] m [Theory] [MemberData(nameof(SolveTileGraphDatas))] - public void TestSolveTileGraph(Func functor, Action action, int count) + public void TestSolveTileGraph(Func functor, IntMergePoint[] mergePoints, Action action, int count) { + var targetOptions = (ICpuTargetOptions)CompileOptions.TargetOptions; var func = functor(); - var post = CompilerServices.Rewrite(func, new IRewriteRule[] { new Passes.Rules.CPU.Affine.LowerPack(), new Passes.Rules.CPU.Affine.LowerUnary(), new Passes.Rules.CPU.Affine.LowerMatmul(), new Passes.Rules.CPU.Affine.LowerBinary() }, new()); + var post = CompilerServices.Rewrite(func, [new Passes.Rules.CPU.Affine.LowerPack(), new Passes.Rules.CPU.Affine.LowerUnary(), new Passes.Rules.CPU.Affine.LowerMatmul(), new Passes.Rules.CPU.Affine.LowerBinary()], new()); + + using var dumpScope = new Diagnostics.DumpScope(count.ToString()); + var builder = new GraphBuilder(targetOptions.MemoryBandWidths.Length); + builder.Visit(post); + var tileGraph = builder.RootGraph; + + for (int i = 0; i < mergePoints.Length; i++) + { + var point = mergePoints[i]; + tileGraph.Merge(new(tileGraph.Vertices.Skip(point.Consumer).First(), tileGraph.Vertices.Skip(point.Producer).First(), point.Level)); + } +#if DEBUG + tileGraph.Dump($"g{count}_m"); +#endif var tiler = new Schedule.GraphTiler(); using var scope = new Diagnostics.DumpScope($"{count}"); @@ -309,6 +347,52 @@ public void TestSolveTileGraph(Func functor, Action action, int action(result); } + [Theory] + [MemberData(nameof(MCTSDatas))] + public void TestMCTS(Func functor, int count) + { + var targetOptions = (ICpuTargetOptions)CompileOptions.TargetOptions; + var func = functor(); + var post = CompilerServices.Rewrite(func, [new Passes.Rules.CPU.Affine.LowerPack(), new Passes.Rules.CPU.Affine.LowerUnary(), new Passes.Rules.CPU.Affine.LowerMatmul(), new Passes.Rules.CPU.Affine.LowerBinary()], new()); + + using var dumpScope = new Diagnostics.DumpScope(count.ToString()); + var builder = new GraphBuilder(targetOptions.MemoryBandWidths.Length); + builder.Visit(post); + var tileGraph = builder.RootGraph; + + var memo = new Dictionary(new ITreeNodeComparer()); + var state = new MCTState(tileGraph, "cpu", count.ToString(), string.Empty, memo, targetOptions); + var rootNode = new MCTNode(state); + var searcher = new MCTSearcher(); + searcher.Search(rootNode); + rootNode.Dump("mct"); + } + + [Theory] + [MemberData(nameof(BufferizeTileGraphDatas))] + public void TestBufferizeTileGraph(Func functor, IntMergePoint[] mergePoints, Action action, int count) + { + var targetOptions = (ICpuTargetOptions)CompileOptions.TargetOptions; + var func = functor(); + var post = CompilerServices.Rewrite(func, [new Passes.Rules.CPU.Affine.LowerPack(), new Passes.Rules.CPU.Affine.LowerUnary(), new Passes.Rules.CPU.Affine.LowerMatmul(), new Passes.Rules.CPU.Affine.LowerBinary()], new()); + + using var dumpScope = new Diagnostics.DumpScope(count.ToString()); + var builder = new GraphBuilder(targetOptions.MemoryBandWidths.Length); + builder.Visit(post); + var tileGraph = builder.RootGraph; + + for (int i = 0; i < mergePoints.Length; i++) + { + var point = mergePoints[i]; + tileGraph.Merge(new(tileGraph.Vertices.Skip(point.Consumer).First(), tileGraph.Vertices.Skip(point.Producer).First(), point.Level)); + } +#if DEBUG + tileGraph.Dump($"g{count}_m"); +#endif + var bufferGraph = tileGraph.Bufferize(); + action(bufferGraph[tileGraph]); + } + [Fact] public void TestPrimTreeEqualityComparer() {