Skip to content

Commit

Permalink
fix graph partition
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Jan 15, 2025
1 parent aad5731 commit 08302f3
Show file tree
Hide file tree
Showing 12 changed files with 61 additions and 18 deletions.
6 changes: 3 additions & 3 deletions modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public record BufferRenderInfo(string Name, string ElemType, ulong Offset, ulong
{
}

public record KernelMainModel(TIR.PrimFunction PrimFunction, TIR.Buffer[] RDataBuffers, CpuTargetOptions Options, ulong Alignment, ulong DataSize, ulong RDataSize)
public record KernelMainModel(TIR.PrimFunction PrimFunction, TIR.Buffer[] RDataBuffers, CpuTargetOptions Options, ulong Alignment, ulong DataSize, ulong RDataSize, ulong LocalRdataPoolSize)
{
public BufferRenderInfo GetInfo(TIR.Buffer buffer)
{
Expand Down Expand Up @@ -64,9 +64,9 @@ public static string CMakeDef(string name)
return content;
}

public static string MakeMain(TIR.PrimFunction primFunction, ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, IEnumerable<TIR.Buffer> rdataBuffers, CpuTargetOptions options)
public static string MakeMain(TIR.PrimFunction primFunction, ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, ulong localRdataPoolSize, IEnumerable<TIR.Buffer> rdataBuffers, CpuTargetOptions options)
{
var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/thread_main.cpp.cshtml", new KernelMainModel(primFunction, rdataBuffers.ToArray(), options, dataAlign, dataUsage, rdataPoolSize)).Result;
var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/thread_main.cpp.cshtml", new KernelMainModel(primFunction, rdataBuffers.ToArray(), options, dataAlign, dataUsage, rdataPoolSize, localRdataPoolSize)).Result;
return content;
}

Expand Down
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public unsafe ILinkableFunction Build(TIR.PrimFunction function)
}

// 3. build function.
var visitor = new KernelCSourceConvertVisitor(function.SchedResult.DataAlign, function.SchedResult.DataUsage, rdataPoolSize, TargetOptions);
var visitor = new KernelCSourceConvertVisitor(function.SchedResult.DataAlign, function.SchedResult.DataUsage, rdataPoolSize, localRdataPoolSize, TargetOptions);
visitor.Visit(function);
var functionCSource = visitor.GetCSource();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,12 @@ internal sealed class KernelCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>,
private readonly StringWriter _sharedWriter;
private ulong _collective_pool_size;

public KernelCSourceConvertVisitor(ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, CpuTargetOptions targetOptions)
public KernelCSourceConvertVisitor(ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, ulong localRdataPoolSize, CpuTargetOptions targetOptions)
{
DataAlign = dataAlign;
DataUsage = dataUsage;
RdataPoolSize = rdataPoolSize;
LocalRdataPoolSize = localRdataPoolSize;
_kernelBuilder = new StringBuilder();
_sharedBuilder = new StringBuilder();
_sharedWriter = new StringWriter(_sharedBuilder);
Expand All @@ -145,11 +146,13 @@ public KernelCSourceConvertVisitor(ulong dataAlign, ulong dataUsage, ulong rdata

public ulong RdataPoolSize { get; }

public ulong LocalRdataPoolSize { get; }

public KernelCSource GetCSource()
{
var ctype = $"void {VisitEntry.Name}({string.Join(", ", VisitEntry.Parameters.AsValueEnumerable().Select(Visit).Select(s => $"{s.Type} {s.Name}").ToArray().Concat(_exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata).Select(Visit).Select(s => $" {s.Type} {s.Name}").ToArray()))}, uint8_t* data)";
return new(
CSourceBuiltn.MakeMain(VisitEntry, DataAlign, DataUsage, RdataPoolSize, _exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata), TargetOptions),
CSourceBuiltn.MakeMain(VisitEntry, DataAlign, DataUsage, RdataPoolSize, LocalRdataPoolSize, _exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata), TargetOptions),
CSourceBuiltn.MakeKernel(ctype, _kernelBuilder.ToString()),
CSourceBuiltn.TopoAwareRuntimeDef(TargetOptions, DataAlign, _collective_pool_size),
CSourceBuiltn.TopologyDef(TargetOptions));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) {
}

std::byte* rdata = (std::byte *)nncase::ntt::runtime::thread_alloc(@Model.RDataSize, align);
std::byte* local_rdata = (std::byte *)nncase::ntt::runtime::thread_alloc(@Model.LocalRdataPoolSize, align);
uint64_t local_rdata_header[@Model.Options.Hierarchies[0][^1] * 2];
for (size_t tid = 0; tid < tdim(); tid++) {
local_rdata_header[tid * 2] = tid * ( @Model.LocalRdataPoolSize / tdim());
}

#ifdef __APPLE__
pthread_key_t cpu_thread_context_key_ = {};
Expand All @@ -73,7 +78,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) {
std::vector<std::thread> blocks;
for (size_t cid = 0; cid < cdim(); cid++) {
for (size_t bid = 0; bid < bdim(); bid++) {
blocks.emplace_back([cid, bid, inputs, rdata
blocks.emplace_back([cid, bid, inputs, rdata, local_rdata_header, local_rdata
#ifdef __APPLE__
, &cpu_thread_context_key_
#endif
Expand All @@ -87,6 +92,8 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) {
.cpu_id_offset = (cid * bdim() + bid) * tdim(),
.inouts = inputs,
.rdata = rdata,
.local_rdata_header = local_rdata_header,
.local_rdata = local_rdata,
#ifdef __APPLE__
.cpu_thread_context_key = cpu_thread_context_key_,
#endif
Expand Down
7 changes: 6 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Nncase.Evaluator.TIR.CPU;

public sealed class BinaryEvaluator : ITypeInferencer<Binary>, IKernelInfoEvaluator<Binary>
public sealed class BinaryEvaluator : ITypeInferencer<Binary>, IKernelInfoEvaluator<Binary>, IOpPrinter<Binary>
{
public IRType Visit(ITypeInferenceContext context, Binary target)
{
Expand All @@ -29,6 +29,11 @@ public MicroKernelInfo Visit(Binary op, MicroKernelContext context)
return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle);
}

/// <summary>
/// Renders a <see cref="Binary"/> op as a single line of text: the op's display property
/// followed by its printed lhs, rhs and output arguments.
/// </summary>
/// <param name="context">Printer context used to look up the printed form of each argument.</param>
/// <param name="target">The op being printed.</param>
/// <param name="iLmode">Not consulted by this printer.</param>
/// <returns>The textual form <c>Binary(property, lhs, rhs, output)</c>.</returns>
public string Visit(IIRPrinterContext context, Binary target, bool iLmode) =>
    $"Binary({target.DisplayProperty()}, {context.GetArgument(target, Binary.Lhs)}, {context.GetArgument(target, Binary.Rhs)}, {context.GetArgument(target, Binary.Output)})";

private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
{
var factora = System.Math.Min(context.BufferShapes[0][^1], 32);
Expand Down
7 changes: 6 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace Nncase.Evaluator.TIR.CPU;

public sealed class UnaryEvaluator : ITypeInferencer<Unary>, IKernelInfoEvaluator<Unary>
public sealed class UnaryEvaluator : ITypeInferencer<Unary>, IKernelInfoEvaluator<Unary>, IOpPrinter<Unary>
{
public IRType Visit(ITypeInferenceContext context, Unary target)
{
Expand All @@ -31,6 +31,11 @@ public MicroKernelInfo Visit(Unary op, MicroKernelContext context)
return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle);
}

/// <summary>
/// Renders a <see cref="Unary"/> op as a single line of text: the op's display property
/// followed by its printed input and output arguments.
/// </summary>
/// <param name="context">Printer context used to look up the printed form of each argument.</param>
/// <param name="target">The op being printed.</param>
/// <param name="iLmode">Not consulted by this printer.</param>
/// <returns>The textual form <c>Unary(property, input, output)</c>.</returns>
public string Visit(IIRPrinterContext context, Unary target, bool iLmode) =>
    $"Unary({target.DisplayProperty()}, {context.GetArgument(target, Unary.Input)}, {context.GetArgument(target, Unary.Output)})";

private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
{
var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
Expand Down
4 changes: 2 additions & 2 deletions src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,14 @@ protected override string VisitBufferOf(BufferOf expr)
/// <inheritdoc/>
protected override string VisitGrid(IR.Affine.Grid expr)
{
var reads = expr.Reads.AsValueEnumerable().Select(Visit).ToArray();
var buffers = expr.Buffers.AsValueEnumerable().Select(Visit).ToArray();
if (_names.TryGetValue(expr, out var name))
{
return name;
}

name = AllocateTempVar(expr);
var reads = expr.Reads.AsValueEnumerable().Select(Visit).ToArray();
var buffers = expr.Buffers.AsValueEnumerable().Select(Visit).ToArray();
_scope.Push();

// 1. For Loop signature
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Evaluator/Tensors/GetItem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ private IValue Visit(IValue input, IValue index)

private IRType Visit(ITypeInferenceContext context, GetItem target, IRType input, TensorType index)
{
IRType ret = new InvalidType("Need Be Reset!");
IRType ret = new InvalidType("GetItem typeinfer error!");
var indexExpr = context.GetArgument(target, GetItem.Index);
switch (input)
{
Expand Down
19 changes: 19 additions & 0 deletions src/Nncase.Graph/Graphs/GraphExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,26 @@ public static class GraphExtensions
/// <summary>
/// Collects the edges that enter <paramref name="subGraph"/> from the rest of <paramref name="parentGraph"/>.
/// </summary>
/// <remarks>
/// For every vertex of the sub graph, its in-edges in the parent graph are taken and the
/// in-edges it already has inside the sub graph are removed from them. The result is lazily evaluated.
/// </remarks>
/// <param name="subGraph">The graph whose vertices are inspected.</param>
/// <param name="parentGraph">The enclosing graph that supplies the full set of in-edges.</param>
/// <returns>The in-edges present in the parent graph but not in the sub graph.</returns>
public static IEnumerable<TEdge> InEdges<TVertex, TEdge>(this IBidirectionalGraph<TVertex, TEdge> subGraph, IBidirectionalGraph<TVertex, TEdge> parentGraph)
    where TEdge : IEdge<TVertex>
{
    return subGraph.Vertices.SelectMany(vertex => parentGraph.InEdges(vertex).Except(subGraph.InEdges(vertex)));
}

/// <summary>
/// Collects the edges that leave <paramref name="subGraph"/> towards the rest of <paramref name="parentGraph"/>.
/// </summary>
/// <remarks>
/// An edge qualifies when it is an out-edge (in the parent graph) of a sub graph vertex and its
/// target is not a vertex of the sub graph. The result is lazily evaluated.
/// </remarks>
/// <param name="subGraph">The graph whose vertices are inspected.</param>
/// <param name="parentGraph">The enclosing graph that supplies the full set of out-edges.</param>
/// <returns>The parent-graph out-edges whose target lies outside the sub graph.</returns>
public static IEnumerable<TEdge> OutEdges<TVertex, TEdge>(this IBidirectionalGraph<TVertex, TEdge> subGraph, IBidirectionalGraph<TVertex, TEdge> parentGraph)
    where TEdge : IEdge<TVertex>
{
    return subGraph.Vertices
        .SelectMany(vertex => parentGraph.OutEdges(vertex))
        .Where(edge => !subGraph.ContainsVertex(edge.Target));
}

/// <summary>
/// Returns the vertices of <paramref name="graph"/> whose in-degree is zero.
/// </summary>
/// <param name="graph">The graph to inspect.</param>
/// <returns>A lazily evaluated sequence of the vertices that have no in-edges in <paramref name="graph"/>.</returns>
public static IEnumerable<TVertex> InVertices<TVertex, TEdge>(this IBidirectionalGraph<TVertex, TEdge> graph)
    where TEdge : IEdge<TVertex>
{
    return graph.Vertices.Where(vertex => graph.InDegree(vertex) == 0);
}

/// <summary>
/// Returns the vertices of <paramref name="graph"/> whose out-degree is zero.
/// </summary>
/// <param name="graph">The graph to inspect.</param>
/// <returns>A lazily evaluated sequence of the vertices that have no out-edges in <paramref name="graph"/>.</returns>
public static IEnumerable<TVertex> OutVertices<TVertex, TEdge>(this IBidirectionalGraph<TVertex, TEdge> graph)
    where TEdge : IEdge<TVertex>
{
    return graph.Vertices.Where(vertex => graph.OutDegree(vertex) == 0);
}

/// <summary>
/// Returns the output vertices of <paramref name="subGraph"/> as seen from <paramref name="parentGraph"/>.
/// </summary>
/// <remarks>
/// When at least one edge leaves the sub graph, the result is the distinct source vertices of those
/// edges, in first-occurrence order. When no edge leaves the sub graph, the result falls back to the
/// sub graph's own vertices with zero out-degree.
/// </remarks>
/// <param name="subGraph">The graph whose output vertices are requested.</param>
/// <param name="parentGraph">The enclosing graph used to find edges leaving the sub graph.</param>
/// <returns>The vertices of the sub graph regarded as its outputs.</returns>
public static IEnumerable<TVertex> OutVertices<TVertex, TEdge>(this IBidirectionalGraph<TVertex, TEdge> subGraph, IBidirectionalGraph<TVertex, TEdge> parentGraph)
    where TEdge : IEdge<TVertex>
{
    // Materialize once so the emptiness check and the projection share the same snapshot.
    var leavingEdges = OutEdges(subGraph, parentGraph).ToArray();

    return leavingEdges.Length != 0
        ? leavingEdges.DistinctBy(edge => edge.Source).Select(edge => edge.Source)
        : OutVertices(subGraph);
}
}
11 changes: 6 additions & 5 deletions src/Nncase.Passes/GraphPartition/ExprReConstructor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,24 @@ public Expr Construct()
return ClusterMemo[dfsVisitor.SortedVertices[^1]];
}

protected IEnumerable<(Expr Pre, Expr Post)> GetClusterArgumentPairs(ClusteredBidirectionalGraph<TVertex, TEdge> cluster)
protected virtual IEnumerable<(Expr Pre, Expr Post)> GetClusterArgumentPairs(ClusteredBidirectionalGraph<TVertex, TEdge> cluster)
{
var pairs = new List<(Expr Pre, Expr Post)>();
foreach (var inEdge in cluster.InEdges(Algo.ClusteredGraph))
{
// get in Expr
Expr postArg;
var sourceCluster = Algo.VertexMap[inEdge.Source];
var sourcerOutVertices = sourceCluster.OutVertices().ToArray();
if (sourcerOutVertices.Length == 1)
var sourceOutVertices = sourceCluster.OutVertices(Algo.ClusteredGraph).ToArray();
if (sourceOutVertices.Length == 1)
{
postArg = ClusterMemo[sourceCluster];
}
else
{
var sourceOutIndex = sourcerOutVertices.IndexOf(inEdge.Source);
postArg = IR.F.Tensors.GetItem(ClusterMemo[sourceCluster], sourceOutIndex);
var sourceOutIndex = sourceOutVertices.IndexOf(inEdge.Source);
var postResult = ClusterMemo[sourceCluster];
postArg = postResult is IR.Tuple tp ? tp.Fields[sourceOutIndex] : IR.F.Tensors.GetItem(postResult, sourceOutIndex);
}

pairs.Add((inEdge.Source.Expr, postArg));
Expand Down
3 changes: 2 additions & 1 deletion src/Nncase.Schedule/Transforms/AutoTilePass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,9 @@ protected override Expr OnComplexCluster(ClusteredBidirectionalGraph<ExprVertex,
}
}

// TODO: sometimes an internal grid has outside dependencies, so we can't fuse it when tiling.
var cloner = new ExprClusterCloner(extractDict);
var outVertices = cluster.OutVertices().ToArray();
var outVertices = cluster.OutVertices(Algo.ClusteredGraph).ToArray();
var clones = new List<Expr>();
foreach (var outVertex in outVertices)
{
Expand Down
2 changes: 2 additions & 0 deletions src/Nncase.Tests/Affine/UnitTestTileGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,9 @@ public void TestMCTS(Func<Function> functor, int count)
var rootNode = new MCTNode(state);
var searcher = new MCTSearcher();
searcher.Search(rootNode);
#if DEBUG
rootNode.Dump("mct");
#endif
}

[Theory]
Expand Down

0 comments on commit 08302f3

Please sign in to comment.