Skip to content

Commit

Permalink
Fix shape bucket & pad
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Jan 5, 2025
1 parent 0cb1eeb commit 10686c5
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 126 deletions.
72 changes: 48 additions & 24 deletions src/Native/src/kernels/stackvm/reference/pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,31 +116,35 @@ template <class T> void set_data_v(T *dst, int len, T value) {
}

template <class T>
void pad_data2(T *in, T *out, int cl, int dl, int hl, int wl, int ch, int dh,
int hh, int wh, T value) {
void pad_data2(T *in, T *out, int nl, int cl, int dl, int hl, int wl, int nh,
int ch, int dh, int hh, int wh, T value) {
(void)ch;
int blocks_in = wl;

int blocks_out = wh;

for (int c = 0; c < cl; ++c) {
for (int d = 0; d < dl; ++d) {
for (int h = 0; h < hl; ++h) {
int index_out = h + d * hh + c * dh * hh;
int index_in = c * hl * dl + d * hl + h;
T *inptr = in + index_in * blocks_in;
T *outptr = out + index_out * blocks_out;
copy_data_v(inptr, outptr, blocks_in, blocks_out, value);
for (int n = 0; n < nl; ++n) {
for (int c = 0; c < cl; ++c) {
for (int d = 0; d < dl; ++d) {
for (int h = 0; h < hl; ++h) {
int index_out = h + d * hh + c * dh * hh + n * ch * dh * hh;
int index_in = n * hl * dl * cl + c * hl * dl + d * hl + h;
T *inptr = in + index_in * blocks_in;
T *outptr = out + index_out * blocks_out;
copy_data_v(inptr, outptr, blocks_in, blocks_out, value);
}
}
}
}
for (int c = 0; c < ch; ++c) {
for (int d = 0; d < dh; ++d) {
for (int h = 0; h < hh; ++h) {
int index = h + d * hh + c * dh * hh;
T *outptr = out + index * blocks_out;
if (h >= hl || d >= dl || c >= cl) {
set_data_v(outptr, blocks_out, value);
for (int n = 0; n < nh; ++n) {
for (int c = 0; c < ch; ++c) {
for (int d = 0; d < dh; ++d) {
for (int h = 0; h < hh; ++h) {
int index = h + d * hh + c * dh * hh + n * ch * dh * hh;
T *outptr = out + index * blocks_out;
if (h >= hl || d >= dl || c >= cl) {
set_data_v(outptr, blocks_out, value);
}
}
}
}
Expand All @@ -150,48 +154,67 @@ void pad_data2(T *in, T *out, int cl, int dl, int hl, int wl, int ch, int dh,
template <class T>
void padding_impl_opt(T *in, T *out, gsl::span<const size_t> in_shape,
                      gsl::span<const size_t> out_shape, T value) {
    // Normalize the input/output ranks (1-5 dims) onto a fixed NCDHW layout
    // so pad_data2 can handle every case with one loop nest. Dimensions that
    // are absent in the original shape are mapped to extent 1.
    int nl, cl, dl, hl, wl; // input  N, C, D, H, W
    int nh, ch, dh, hh, wh; // output N, C, D, H, W
    if (in_shape.size() == 3 ||
        (in_shape.size() == 4 && in_shape[in_shape.size() - 1] == 1)) {
        // 3-D input, or 4-D whose innermost extent is 1: treat as D/H/W.
        // NOTE(review): the 4-D branch assumes the innermost dimension is
        // never padded (out_shape indices 0-2 are used) — confirm at caller.
        nl = 1;
        cl = 1;
        dl = in_shape[0];
        hl = in_shape[1];
        wl = in_shape[2];
        nh = 1;
        ch = 1;
        dh = out_shape[0];
        hh = out_shape[1];
        wh = out_shape[2];
    } else if (in_shape.size() == 5) {
        nl = in_shape[0];
        cl = in_shape[1];
        dl = in_shape[2];
        hl = in_shape[3];
        wl = in_shape[4];
        nh = out_shape[0];
        ch = out_shape[1];
        dh = out_shape[2];
        hh = out_shape[3];
        wh = out_shape[4];
    } else if (in_shape.size() == 4) {
        nl = 1;
        cl = in_shape[0];
        dl = in_shape[1];
        hl = in_shape[2];
        wl = in_shape[3];
        nh = 1;
        ch = out_shape[0];
        dh = out_shape[1];
        hh = out_shape[2];
        wh = out_shape[3];
    } else if (in_shape.size() == 2) {
        nl = 1;
        cl = 1;
        dl = 1;
        hl = in_shape[0];
        wl = in_shape[1];
        nh = 1;
        ch = 1;
        dh = 1;
        hh = out_shape[0];
        wh = out_shape[1];
    } else {
        // 1-D fallback: only W exists.
        nl = 1;
        cl = 1;
        dl = 1;
        hl = 1;
        wl = in_shape[0];
        nh = 1;
        ch = 1;
        dh = 1;
        hh = 1;
        // BUG FIX: was out_shape[1], which reads past the end of a rank-1
        // shape; the padded width lives at index 0, matching wl above.
        wh = out_shape[0];
    }

    pad_data2(in, out, nl, cl, dl, hl, wl, nh, ch, dh, hh, wh, value);
}

template <class T>
Expand Down Expand Up @@ -239,9 +262,10 @@ result<void> nncase::kernels::stackvm::reference::pad(
if (std::all_of(paddings.begin(), paddings.end(),
[](const padding &p) { return p.interior == 0; })) {
auto out_shape = get_padded_shape(in_shape, paddings);
auto can_opt = out_shape.size() < 6;
switch (unit) {
case 1:
if (padding_before_is_zero) {
if (padding_before_is_zero && can_opt) {
padding_impl_opt((int8_t *)input, (int8_t *)output, in_shape,
out_shape, *(int8_t *)pad_value);
} else {
Expand All @@ -252,7 +276,7 @@ result<void> nncase::kernels::stackvm::reference::pad(
}
break;
case 2:
if (padding_before_is_zero) {
if (padding_before_is_zero && can_opt) {
padding_impl_opt((int16_t *)input, (int16_t *)output, in_shape,
out_shape, *(int16_t *)pad_value);
} else {
Expand All @@ -263,7 +287,7 @@ result<void> nncase::kernels::stackvm::reference::pad(
}
break;
case 4:
if (padding_before_is_zero) {
if (padding_before_is_zero && can_opt) {
padding_impl_opt((int32_t *)input, (int32_t *)output, in_shape,
out_shape, *(int32_t *)pad_value);
} else {
Expand All @@ -274,7 +298,7 @@ result<void> nncase::kernels::stackvm::reference::pad(
}
break;
case 8:
if (padding_before_is_zero) {
if (padding_before_is_zero && can_opt) {
padding_impl_opt((int64_t *)input, (int64_t *)output, in_shape,
out_shape, *(int64_t *)pad_value);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Core/IR/PrimFunctionWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public PrimFunctionWrapper(string name, PrimFunction target, int parametersCount
/// <param name="target">Target.</param>
/// <param name="parametersCount">Arguments count.</param>
public PrimFunctionWrapper(PrimFunction target, int parametersCount)
    : this($"primfunc_wrapper_{_globalFuncIndex++}", target, parametersCount)
{
    // Name is auto-generated; the distinct "primfunc_wrapper_" prefix keeps
    // generated wrapper names from colliding with other generated symbols.
}

Expand Down
83 changes: 28 additions & 55 deletions src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,56 +22,21 @@ namespace Nncase.Passes.Rules.ShapeBucket;

/// <summary>
/// Captured shape/value information for one BucketFusion call: the evaluated
/// output shape plus, per input, its shape, its (optional) concrete value,
/// and whether that value was derived from a shape computation.
/// </summary>
public class FusionShapeData
{
    public FusionShapeData(IValue outshape, IValue[] inputShapes, IValue?[] inputValues, bool[] inputFromShapes)
    {
        Outshape = outshape;
        InputShapes = inputShapes;
        InputValues = inputValues;
        InputFromShapes = inputFromShapes;
    }

    /// <summary>Gets the evaluated output shape of the fusion call.</summary>
    public IValue Outshape { get; }

    /// <summary>Gets the shape of each fusion argument.</summary>
    public IValue[] InputShapes { get; }

    /// <summary>Gets the concrete value of each argument, when one is known; null otherwise.</summary>
    public IValue?[] InputValues { get; }

    /// <summary>Gets, per argument, whether its value originated from a shape computation.</summary>
    public bool[] InputFromShapes { get; }
}

public class FusionShapeUpdater2 : ExprVisitor<Expr, Unit>
Expand All @@ -91,13 +56,14 @@ protected override Expr VisitLeafCall(Call expr)
{
if (expr.Target is BucketFusion f)
{
var argShape = expr.Arguments.ToArray().Select(arg =>
var argData = expr.Arguments.ToArray().Select(arg =>
{
var exp = arg is Marker m ? m.Target : arg;
return GetValueOfShape(_memo[exp].IRType!);
var valueOrShape = _memo[exp];
return (Shape: GetValueOfShape(valueOrShape.IRType!), Value: valueOrShape.Value, FromShape: valueOrShape.FromShape);
}).ToArray();
var shape = GetValueOfShape(_memo[expr].IRType!);
FusionShape[f] = new FusionShapeData(shape, argShape);
FusionShape[f] = new FusionShapeData(shape, argData.Select(x => x.Shape).ToArray(), argData.Select(x => x.Value).ToArray(), argData.Select(x => x.FromShape).ToArray());
}

return expr;
Expand Down Expand Up @@ -239,13 +205,13 @@ protected override Task<BaseFunction> RunCoreAsync(BaseFunction main, RunPassCon
var f = new FusionShapeUpdater(ConcatDictionary(memo, exprValues));
#else
var input = MakeDummyInputType(varMap, varValues);
var eval = new PartialShapeEvaluator(input.ToDictionary(p => p.Key, p => new ValueOrShape(p.Value, null)), varValues);
var eval = new PartialShapeEvaluator(input.ToDictionary(p => p.Key, p => new ValueOrShape(p.Value, null, false)), varValues);
eval.Visit(body);
var memo = eval.ExprMemo;
foreach (var (k, v) in exprValues)
{
var x = v.AsTensor();
memo.Add(k, new(new TensorType(x.ElementType, x.Shape), v));
memo.Add(k, new(new TensorType(x.ElementType, x.Shape), v, true));
}

var f = new FusionShapeUpdater2(memo);
Expand All @@ -270,7 +236,7 @@ public record ValueOrShape
{
private IValue? _concreteValue;

public ValueOrShape(IRType? irType, IValue? value)
public ValueOrShape(IRType? irType, IValue? value, bool fromShape)
{
if (irType is InvalidType)
{
Expand All @@ -280,6 +246,7 @@ public ValueOrShape(IRType? irType, IValue? value)
IRType = irType;
Value = value;
_concreteValue = null;
FromShape = fromShape;
}

public IRType? IRType { get; }
Expand All @@ -288,6 +255,8 @@ public ValueOrShape(IRType? irType, IValue? value)

public bool HasValue => Value != null;

public bool FromShape { get; }

public IValue Concrete()
{
if (_concreteValue != null)
Expand Down Expand Up @@ -325,9 +294,9 @@ public PartialShapeEvaluator(Dictionary<Var, ValueOrShape> inputDict, Dictionary

protected override ValueOrShape VisitLeafMarker(Marker expr) => Visit(expr.Target);

// Functions evaluate to their checked type only: no concrete value, not shape-derived.
protected override ValueOrShape VisitLeafBaseFunction(BaseFunction expr) => new(expr.CheckedType, null, false);

// Ops evaluate to their checked type only: no concrete value, not shape-derived.
protected override ValueOrShape VisitLeafOp(Op expr) => new(expr.CheckedType, null, false);

protected override ValueOrShape VisitLeafVar(Var expr)
{
Expand All @@ -337,7 +306,7 @@ protected override ValueOrShape VisitLeafVar(Var expr)
}
else if (DimDict.TryGetValue(expr, out var dimValue) && dimValue is TensorValue dimtv)
{
return new(dimtv.Type, dimtv);
return new(dimtv.Type, dimtv, true);
}
else
{
Expand All @@ -349,8 +318,11 @@ protected override ValueOrShape VisitLeafVar(Var expr)

protected override ValueOrShape VisitLeafTuple(IR.Tuple expr)
{
    // Evaluate every field, then pack the concrete tensors into a tuple value.
    var valueOrShapes = expr.Fields.AsValueEnumerable().Select(Visit).ToArray();
    var value = Value.FromTensors(valueOrShapes.Select(vs => vs.Concrete().AsTensor()).ToArray());

    // FIXME(original author note): marking the tuple shape-derived when *any*
    // field is shape-derived is known to be imprecise for TupleType — revisit.
    return new(value.Type, value, valueOrShapes.Any(x => x.FromShape));
}

protected override ValueOrShape VisitLeafCall(Call expr)
Expand All @@ -363,26 +335,27 @@ protected override ValueOrShape VisitLeafCall(Call expr)
{
var shapeArr = ((TensorType)args[0].IRType!).Shape.Select(x => (long)x.FixedValue).ToArray();
var value = Value.FromTensor(Tensor.From<long>(shapeArr));
result = new(value.Type, value);
result = new(value.Type, value, true);
}

break;
case Op op:
{
if (args.All(x => x is { HasValue: true }))
{
var fromShape = args.All(x => x.FromShape);
var tmpCall = new Call(op, args.Select(a => Const.FromValue(a.Concrete())).ToArray());
var ctx = new EvaluateContext(args)
{
CurrentCall = tmpCall,
};
var value = CompilerServices.EvaluateOp(op, ctx);
result = new(value.Type, value);
result = new(value.Type, value, fromShape);
}
else
{
var ctx = new TypeInferenceContext(args);
result = new(CompilerServices.InferenceOp(op, ctx, new()), null);
result = new(CompilerServices.InferenceOp(op, ctx, new()), null, false);
}
}

Expand All @@ -408,7 +381,7 @@ protected override ValueOrShape VisitLeafCall(Call expr)
return result;
}

protected override ValueOrShape VisitLeafConst(Const expr) => new(expr.CheckedType, Value.FromConst(expr));
protected override ValueOrShape VisitLeafConst(Const expr) => new(expr.CheckedType, Value.FromConst(expr), true);
}

internal sealed class EvaluateContext : IEvaluateContext
Expand Down
Loading

0 comments on commit 10686c5

Please sign in to comment.