Skip to content

Commit

Permalink
fix cycle
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Dec 31, 2024
1 parent f1bc58f commit 2b31ad5
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 45 deletions.
62 changes: 22 additions & 40 deletions modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,18 @@ protected override Task<BaseFunction> RunCoreAsync(BaseFunction input, RunPassCo

internal sealed class SearchableNode
{
public SearchableNode(Expr expr, IRType type)
public SearchableNode(Expr expr, IRType type, bool isBidirect = false)
{
Expr = expr;
IRType = type;
IsBidirect = isBidirect;
}

public Expr Expr { get; }

public IRType IRType { get; }

public bool IsBidirect { get; }
}

internal sealed record CrossEdge : IEdge<SearchableNode>
Expand Down Expand Up @@ -343,7 +346,7 @@ protected override Unit VisitLeafCall(Call expr)
{
if (Evaluator.IR.CPU.BoxingEvaluator.VisitType(lType, rType) is not InvalidType)
{
var rnode = new SearchableNode(new Boxing(rType), rType);
var rnode = new SearchableNode(new Boxing(rType), rType, true);
rBucket.AddVertex(rnode);
callCluster.AddEdge(new(rnode, lBucket.Vertices.First(), 0, lBucket));
}
Expand All @@ -360,7 +363,6 @@ protected override Unit VisitLeafCall(Call expr)
bucket = callCluster.CreateCluster<DistributedSearchGraph>(SearchGraphKind.Bucket);
var node = new SearchableNode(new Boxing(nType), nType);
bucket.AddVertex(node);

var linked = false;
foreach (var addedBucket in addedBuckets)
{
Expand Down Expand Up @@ -389,6 +391,9 @@ protected override Unit VisitLeafCall(Call expr)
return default;
}

/// <summary>
/// some times we didn't use all args.
/// </summary>
private IEnumerable<(Call Call, bool[] Used)> BuildEquivalentCalls(Expr target, Expr[] tempArgs)
{
IEnumerable<(Call Call, bool[] Used)> calls = [(new Call(target, tempArgs), Enumerable.Repeat(true, tempArgs.Length).ToArray())];
Expand All @@ -402,6 +407,15 @@ protected override Unit VisitLeafCall(Call expr)
return calls;
}

private IReadOnlyList<IRArray<SBP>> GetDiverseCandidateSBPs(DistributedType distributedType, IEnumerable<Placement> placements)
{
return placements.Select(
placement =>
DistributedUtility.GetLeafCandidateNDSBPs(distributedType.TensorType, placement).
Where(ndsbp => ndsbp != distributedType.NdSBP)).
SelectMany(e => e).ToArray();
}

private DistributedSearchGraph VisitLeafArgument(ParameterKind parameterKind, Expr expr, bool isSupported)
{
DistributedSearchGraph argCluster;
Expand Down Expand Up @@ -739,49 +753,17 @@ private Expr SolveAndExtract(DistributedSearchGraph rootCluster)
}

// 3. no cycle
foreach (var cluster in _rootSearchGraph.Clusters.OfType<DistributedSearchGraph>())
{
var hgraph = ToHyperGraph(_rootSearchGraph, rootCluster);
var class_cycles = hgraph.FindCycles();
foreach (var cycle in class_cycles)
foreach (var sourceBucket in cluster.Clusters.OfType<DistributedSearchGraph>())
{
if (cycle.Count == 1)
foreach (var destBucket in cluster.Clusters.OfType<DistributedSearchGraph>().Where(b => !ReferenceEquals(b, sourceBucket)))
{
foreach (var n in cycle[0].Vertices)
foreach (var (src, dest) in sourceBucket.Vertices.Where(v => v.IsBidirect).Zip(destBucket.Vertices.Where(v => v.IsBidirect)))
{
_rootSearchGraph.TryGetOutEdges(n, out var edgs);
if (edgs.Select(e => e.InputGraph).Contains(cycle[0]))
{
cpmodel.AddAssumption(varMemo[n].Not());
}
cpmodel.AddBoolAnd([varMemo[src].Not(), varMemo[dest].Not()]);
}
}
else
{
// build clauses.
var clauses = new List<List<BoolVar>>();
for (int i = 0; i < cycle.Count; i++)
{
var next_hop = (i + 1) % cycle.Count;
var u = hgraph.Edges(cycle[i])!;
var v = u[cycle[next_hop]];
clauses.Add(v.Select(n => varMemo[n]).ToList());
}

var clauseMemo = new Dictionary<int, BoolVar>();
for (int i = 0; i < clauses.Count; i++)
{
var clause = clauses[i];
if (clause.Count > 1)
{
var tmpV = cpmodel.NewBoolVar(string.Empty);
cpmodel.AddBoolAnd(clause.Select(c => c.Not())).OnlyEnforceIf(tmpV);
cpmodel.AddBoolOr(clause).OnlyEnforceIf(tmpV.Not());
clauseMemo.Add(i, tmpV);
}
}

cpmodel.AddBoolOr(clauses.Select((c, i) => (c, i)).Select(p => p.c.Count == 1 ? p.c[0].Not() : clauseMemo[p.i]));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public sealed partial class SwapUnpackReduce : IRewriteRule
public IPattern Pattern { get; } =
IsBoxing(
target_name: "boxing",
op => op.NewType is DistributedType dt && dt.NdSBP.All(s => s != SBP.P),
op => op.NewType is DistributedType dt && dt.NdSBP.All(s => s is not SBPPartial),
IsUnpack(
target_name: "unpack",
_ => true,
Expand All @@ -54,10 +54,10 @@ public sealed partial class SwapUnpackReduce : IRewriteRule

public Expr? GetReplace(Call call, Boxing boxing, Unpack unpack)
{
if (call.CheckedType is DistributedType dt && dt.NdSBP.Any(s => s == SBP.P))
if (call.CheckedType is DistributedType dt && dt.NdSBP.Any(s => s is SBPPartial))
{
var newType = new DistributedType(dt.TensorType, dt.NdSBP.Select(s => s is SBPPartialSum ? SBP.B : s).ToArray(), dt.Placement);
var newBoxing = IR.F.CPU.Boxing(call, newType, boxing.IsReshape);
var newType = new DistributedType(dt.TensorType, dt.NdSBP.Select(s => s is SBPPartial ? SBP.B : s).ToArray(), dt.Placement);
var newBoxing = IR.F.CPU.Boxing(call, newType);
return IR.F.CPU.Unpack(newBoxing, [.. unpack.Lanes], [.. unpack.Axes]);
}

Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ public async Task TestTranspose(int[] shape, int[] perm, int rank, int number)
}

[Theory]
[InlineData([new[] { 4 }, 0])]
[InlineData([new[] { 2, 4 }, 0])]
public async Task TestTransposeMatmul(int[] hierarchy, int number)
{
var targetOptions = (CpuTargetOptions)CompileOptions.TargetOptions;
Expand Down

0 comments on commit 2b31ad5

Please sign in to comment.