Skip to content

Commit

Permalink
Merge branch 'dev/3.0' into feature/refactor_graph_partition
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 authored Jan 14, 2025
2 parents 43f9b24 + 4b6966b commit c52f56e
Show file tree
Hide file tree
Showing 24 changed files with 976 additions and 281 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

#include <nncase/ntt/runtime.h>
#include "topo_aware_runtime.h"
#include "../device.h"
@foreach(var (s,i) in Model.Options.MemoryCapacities.Select((s,i) => (s,i)).Skip(1).SkipLast(1)){
@:uint8_t L@(i)Data[@(s)];
@foreach(var (s,i) in Model.Options.MemoryCapacities.Select((s,i) => (s,i)).SkipLast(1)){
@:uint8_t L@(i+1)Data[@(s)];
}
#include "../device.h"
#include "kernel.h"

//alignas(@(Model.Alignment)) static thread_local uint8_t local_data[@(Model.DataSize)];
Expand Down
1 change: 1 addition & 0 deletions modules/Nncase.Modules.CPU/TIR/CPU/Unary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

namespace Nncase.TIR.CPU;

[ParameterInPlace(0, 1)]
public sealed partial class Unary : CPUKernelOp
{
public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input");
Expand Down
14 changes: 14 additions & 0 deletions src/Nncase.Core/IR/Op.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@ public enum ParameterKind : int
Attribute,
}

[AttributeUsage(System.AttributeTargets.Class, Inherited = false, AllowMultiple = true)]
public sealed class ParameterInPlaceAttribute : System.Attribute
{
public ParameterInPlaceAttribute(int sourceIndex, int destIndex)
{
SourceIndex = sourceIndex;
DestIndex = destIndex;
}

public int SourceIndex { get; }

public int DestIndex { get; }
}

/// <summary>
/// Parameter information.
/// </summary>
Expand Down
9 changes: 9 additions & 0 deletions src/Nncase.Core/TIR/TIRExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ namespace Nncase.TIR;
/// </summary>
public static class TIRExtensions
{
/// <summary>
/// Get the tir op buffer allocation reuse information.
/// </summary>
/// <returns> map dest index to source index. </returns>
public static Dictionary<int, int> GetInPlaceMemo(this Op op)
{
return op.GetType().GetCustomAttributes(typeof(ParameterInPlaceAttribute), true).OfType<ParameterInPlaceAttribute>().ToDictionary(a => a.DestIndex, a => a.SourceIndex);
}

/// <summary>
/// convert IEnumerable to tir Sequential.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ protected override void InternalCompute()

var dfs = new DepthFirstSearchAlgorithm<TVertex, TEdge>(this, VisitedGraph, new Dictionary<TVertex, GraphColor>(VisitedGraph.VertexCount));
dfs.TreeEdge += TreeEdge;
dfs.ForwardOrCrossEdge += TreeEdge;
dfs.Compute();
}

Expand Down
370 changes: 232 additions & 138 deletions src/Nncase.Schedule/Schedule/GraphTiler.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;

namespace Nncase.Schedule.MonteCarloTreeSearch;

public interface IEnvironmentState<TAction>
where TAction : class
{
int LegalActions();

TAction GetNextAction(int index);

IEnvironmentState<TAction>? PerformAction(TAction action);

double RollOut();
}
44 changes: 44 additions & 0 deletions src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/SearchNode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;

namespace Nncase.Schedule.MonteCarloTreeSearch;

public abstract class SearchNode<T>
where T : class
{
public SearchNode(IEnvironmentState<T> state)
{
Parent = null;
Children = new List<SearchNode<T>>();
VisitTimes = 0;
QualityValue = 0.0;
State = state;
}

public SearchNode(SearchNode<T> parent, IEnvironmentState<T> state)
{
Parent = parent;
Children = new List<SearchNode<T>>();
VisitTimes = 0;
QualityValue = 0.0;
State = state;
Parent.Children.Add(this);
}

public SearchNode<T>? Parent { get; }

public List<SearchNode<T>> Children { get; }

public int VisitTimes { get; set; }

public double QualityValue { get; set; }

public IEnvironmentState<T> State { get; }

public bool IsRootNode => Parent is null;

public abstract void Update(double reward);
}
47 changes: 47 additions & 0 deletions src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/Searcher.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;

namespace Nncase.Schedule.MonteCarloTreeSearch;

public abstract class Searcher<T>
where T : class
{
public Searcher(int searchTimes = 20)
{
SearchTimes = searchTimes;
}

public int SearchTimes { get; }

public void Search(SearchNode<T> rootNode)
{
for (int i = 0; i < SearchTimes; i++)
{
if (!Selection(rootNode, out var node))
{
return;
}

var expanded = Expand(node);
if (expanded is not null)
{
BackPropagate(expanded, Simulation(expanded));
}
else
{
BackPropagate(node, double.PositiveInfinity);
}
}
}

public abstract bool Selection(SearchNode<T> node, out SearchNode<T> selected);

public abstract SearchNode<T>? Expand(SearchNode<T> node);

public abstract double Simulation(SearchNode<T> node);

public abstract void BackPropagate(SearchNode<T> node, double reward);
}
5 changes: 4 additions & 1 deletion src/Nncase.Schedule/Schedule/OrToolsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ public static long[] Value(this Assignment sol, IntExpr[] inputs)
var vec = new long[inputs.Length];
for (int i = 0; i < inputs.Length; i++)
{
vec[i] = sol.Value(inputs[i].Var());
if (inputs[i] is not null)
{
vec[i] = sol.Value(inputs[i].Var());
}
}

return vec;
Expand Down
22 changes: 18 additions & 4 deletions src/Nncase.Schedule/Schedule/TileGraph/BufferizationAlgorithm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ private void Visit(TieredTileGraph rootGraph)
{
if (!BufferGraphMemo.TryGetValue(rootGraph, out _))
{
var wrappedGraph = new AdjacencyGraph<BufferIdentity, EquatableTaggedEdge<BufferIdentity, BufferEdgeKind>>();
var wrappedGraph = new AdjacencyGraph<BufferIdentity, EquatableTaggedEdge<BufferIdentity, BufferEdgeKind>>(allowParallelEdges: false);
var rootBufferGraph = new BufferGraph(rootGraph.Level, wrappedGraph);
Visit(rootGraph, rootBufferGraph);
Visit(rootGraph, rootBufferGraph, rootGraph);
foreach (var edge in rootGraph.Edges)
{
var source = new BufferIdentity(edge.Source, edge.Source.ReadAccesses.Length);
Expand All @@ -77,12 +77,14 @@ private void Visit(TieredTileGraph rootGraph)
}
}

private void Visit(TieredTileGraph graph, BufferGraph bufferGraph)
private HashSet<TileGrid> Visit(TieredTileGraph graph, BufferGraph bufferGraph, TieredTileGraph rootGraph)
{
var opnodes = new HashSet<TileGrid>();
if (graph.ClustersCount == 0)
{
foreach (var item in graph.Vertices)
{
opnodes.Add(item);
var outBid = new BufferIdentity(item, item.ReadAccesses.Length);
for (int i = 0; i < item.ReadAccesses.Length; i++)
{
Expand All @@ -97,10 +99,22 @@ private void Visit(TieredTileGraph graph, BufferGraph bufferGraph)
if (!BufferGraphMemo.TryGetValue(graph, out _))
{
var childBufferGraph = bufferGraph.CreateCluster<BufferGraph>(childGraph.Level, childGraph.OpId);
Visit(childGraph, childBufferGraph);
opnodes.UnionWith(Visit(childGraph, childBufferGraph, rootGraph));
BufferGraphMemo.Add(childGraph, childBufferGraph);
}
}

foreach (var edge in rootGraph.Edges)
{
if (opnodes.Contains(edge.Source) && opnodes.Contains(edge.Target))
{
var source = new BufferIdentity(edge.Source, edge.Source.ReadAccesses.Length);
var target = new BufferIdentity(edge.Target, edge.Tag);
bufferGraph.AddEdge(new(source, target, BufferEdgeKind.Outer));
}
}
}

return opnodes;
}
}
31 changes: 31 additions & 0 deletions src/Nncase.Schedule/Schedule/TileGraph/GraphExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,37 @@ public static bool Merge(this TieredTileGraph graph, MergePoint mergePoint)
return merger.Visit(graph);
}

public static List<MergePoint> GetMergePoints(this TieredTileGraph graph)
{
var mergePoints = new List<MergePoint>();
if (graph.Level != -1)
{
throw new InvalidOperationException("only can merge at top level!");
}

var children = graph.Clusters.OfType<TieredTileGraph>().ToArray();
foreach (var producer in children)
{
foreach (var comsumer in children)
{
if (ReferenceEquals(producer, comsumer))
{
continue;
}

foreach (var edge in graph.Edges)
{
if (comsumer.ContainsVertex(edge.Source) && producer.ContainsVertex(edge.Target))
{
mergePoints.Add(new(edge.Target, edge.Source, producer.Level));
}
}
}
}

return mergePoints;
}

public static void Walk(this TieredTileGraph graph, Action<ITileable> func, bool postOrder = false)
{
if (!postOrder)
Expand Down
Loading

0 comments on commit c52f56e

Please sign in to comment.