From 2b31ad57afccb4a164c9a24bb734156e8232df95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 31 Dec 2024 15:38:21 +0800 Subject: [PATCH] fix cycle --- .../Passes/Distributed/AutoDistributed.cs | 62 +++++++------------ .../Passes/Rules/CPU/FoldMatmulReduce.cs | 8 +-- .../Targets/UnitTestCPUKernels.cs | 2 +- 3 files changed, 27 insertions(+), 45 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index 862624996..23e3452ad 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -73,15 +73,18 @@ protected override Task RunCoreAsync(BaseFunction input, RunPassCo internal sealed class SearchableNode { - public SearchableNode(Expr expr, IRType type) + public SearchableNode(Expr expr, IRType type, bool isBidirect = false) { Expr = expr; IRType = type; + IsBidirect = isBidirect; } public Expr Expr { get; } public IRType IRType { get; } + + public bool IsBidirect { get; } } internal sealed record CrossEdge : IEdge @@ -343,7 +346,7 @@ protected override Unit VisitLeafCall(Call expr) { if (Evaluator.IR.CPU.BoxingEvaluator.VisitType(lType, rType) is not InvalidType) { - var rnode = new SearchableNode(new Boxing(rType), rType); + var rnode = new SearchableNode(new Boxing(rType), rType, true); rBucket.AddVertex(rnode); callCluster.AddEdge(new(rnode, lBucket.Vertices.First(), 0, lBucket)); } @@ -360,7 +363,6 @@ protected override Unit VisitLeafCall(Call expr) bucket = callCluster.CreateCluster(SearchGraphKind.Bucket); var node = new SearchableNode(new Boxing(nType), nType); bucket.AddVertex(node); - var linked = false; foreach (var addedBucket in addedBuckets) { @@ -389,6 +391,9 @@ protected override Unit VisitLeafCall(Call expr) return default; } + /// + /// some times we didn't use all args. + /// private IEnumerable<(Call Call, bool[] Used)> BuildEquivalentCalls(Expr target, Expr[] tempArgs) { IEnumerable<(Call Call, bool[] Used)> calls = [(new Call(target, tempArgs), Enumerable.Repeat(true, tempArgs.Length).ToArray())]; @@ -402,6 +407,15 @@ protected override Unit VisitLeafCall(Call expr) return calls; } + private IReadOnlyList> GetDiverseCandidateSBPs(DistributedType distributedType, IEnumerable placements) + { + return placements.Select( + placement => + DistributedUtility.GetLeafCandidateNDSBPs(distributedType.TensorType, placement). + Where(ndsbp => ndsbp != distributedType.NdSBP)). + SelectMany(e => e).ToArray(); + } + private DistributedSearchGraph VisitLeafArgument(ParameterKind parameterKind, Expr expr, bool isSupported) { DistributedSearchGraph argCluster; @@ -739,49 +753,17 @@ private Expr SolveAndExtract(DistributedSearchGraph rootCluster) } // 3. no cycle + foreach (var cluster in _rootSearchGraph.Clusters.OfType()) { - var hgraph = ToHyperGraph(_rootSearchGraph, rootCluster); - var class_cycles = hgraph.FindCycles(); - foreach (var cycle in class_cycles) + foreach (var sourceBucket in cluster.Clusters.OfType()) { - if (cycle.Count == 1) + foreach (var destBucket in cluster.Clusters.OfType().Where(b => !ReferenceEquals(b, sourceBucket))) { - foreach (var n in cycle[0].Vertices) + foreach (var (src, dest) in sourceBucket.Vertices.Where(v => v.IsBidirect).Zip(destBucket.Vertices.Where(v => v.IsBidirect))) { - _rootSearchGraph.TryGetOutEdges(n, out var edgs); - if (edgs.Select(e => e.InputGraph).Contains(cycle[0])) - { - cpmodel.AddAssumption(varMemo[n].Not()); - } + cpmodel.AddBoolAnd([varMemo[src].Not(), varMemo[dest].Not()]); } } - else - { - // build clauses. - var clauses = new List>(); - for (int i = 0; i < cycle.Count; i++) - { - var next_hop = (i + 1) % cycle.Count; - var u = hgraph.Edges(cycle[i])!; - var v = u[cycle[next_hop]]; - clauses.Add(v.Select(n => varMemo[n]).ToList()); - } - - var clauseMemo = new Dictionary(); - for (int i = 0; i < clauses.Count; i++) - { - var clause = clauses[i]; - if (clause.Count > 1) - { - var tmpV = cpmodel.NewBoolVar(string.Empty); - cpmodel.AddBoolAnd(clause.Select(c => c.Not())).OnlyEnforceIf(tmpV); - cpmodel.AddBoolOr(clause).OnlyEnforceIf(tmpV.Not()); - clauseMemo.Add(i, tmpV); - } - } - - cpmodel.AddBoolOr(clauses.Select((c, i) => (c, i)).Select(p => p.c.Count == 1 ? p.c[0].Not() : clauseMemo[p.i])); - } } } diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldMatmulReduce.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldMatmulReduce.cs index a811e1efa..d99c28473 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldMatmulReduce.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldMatmulReduce.cs @@ -41,7 +41,7 @@ public sealed partial class SwapUnpackReduce : IRewriteRule public IPattern Pattern { get; } = IsBoxing( target_name: "boxing", - op => op.NewType is DistributedType dt && dt.NdSBP.All(s => s != SBP.P), + op => op.NewType is DistributedType dt && dt.NdSBP.All(s => s is not SBPPartial), IsUnpack( target_name: "unpack", _ => true, @@ -54,10 +54,10 @@ public sealed partial class SwapUnpackReduce : IRewriteRule public Expr? GetReplace(Call call, Boxing boxing, Unpack unpack) { - if (call.CheckedType is DistributedType dt && dt.NdSBP.Any(s => s == SBP.P)) + if (call.CheckedType is DistributedType dt && dt.NdSBP.Any(s => s is SBPPartial)) { - var newType = new DistributedType(dt.TensorType, dt.NdSBP.Select(s => s is SBPPartialSum ? SBP.B : s).ToArray(), dt.Placement); - var newBoxing = IR.F.CPU.Boxing(call, newType, boxing.IsReshape); + var newType = new DistributedType(dt.TensorType, dt.NdSBP.Select(s => s is SBPPartial ? SBP.B : s).ToArray(), dt.Placement); + var newBoxing = IR.F.CPU.Boxing(call, newType); return IR.F.CPU.Unpack(newBoxing, [.. unpack.Lanes], [.. unpack.Axes]); } diff --git a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs index d153f8d98..87944d01b 100644 --- a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs +++ b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs @@ -535,7 +535,7 @@ public async Task TestTranspose(int[] shape, int[] perm, int rank, int number) } [Theory] - [InlineData([new[] { 4 }, 0])] + [InlineData([new[] { 2, 4 }, 0])] public async Task TestTransposeMatmul(int[] hierarchy, int number) { var targetOptions = (CpuTargetOptions)CompileOptions.TargetOptions;