Skip to content

Commit

Permalink
add shape const
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Jan 16, 2025
1 parent afd5b60 commit 1ca0154
Show file tree
Hide file tree
Showing 25 changed files with 869 additions and 118 deletions.
22 changes: 13 additions & 9 deletions src/Nncase.Core/IR/Const.cs
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,20 @@ public static TensorConst FromTensor(Tensor tensor)
/// <returns>Created constant expression.</returns>
public static Const FromValue(IValue value)
{
if (value is TensorValue tv)
switch (value)
{
return tv.Type is DistributedType distributedType
? new TensorConst(tv.AsTensor(), distributedType.NdSBP, distributedType.Placement)
: new TensorConst(tv.AsTensor());
}
else
{
var tpv = (TupleValue)value;
return new TupleConst(tpv);
case TensorValue tv:
return tv.Type is DistributedType distributedType
? new TensorConst(tv.AsTensor(), distributedType.NdSBP, distributedType.Placement)
: new TensorConst(tv.AsTensor());
case TupleValue tpv:
return new TupleConst(tpv);
case ShapeValue sv:
return new ShapeConst(sv.Dimensions.ToArray());
case DimensionValue dv:
return new DimensionConst(dv.Dimension);
default:
throw new ArgumentOutOfRangeException(nameof(value));
}
}
}
6 changes: 5 additions & 1 deletion src/Nncase.Core/IR/Dimension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ public long FixedValue
/// Convert <see cref="Expr"/> to a <see cref="Dimension"/> expression.
/// </summary>
/// <param name="value">Dimension value.</param>
public static implicit operator Dimension(Expr value) => new(value);
public static implicit operator Dimension(Expr value) => value switch
{
DimensionConst dc => dc.Value,
_ => new(value),
};

public static bool operator ==(Dimension left, Dimension right)
{
Expand Down
1 change: 1 addition & 0 deletions src/Nncase.Core/IR/Expr.Operators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public partial class Expr
{
TensorConst tc => Tensor.FromScalar(tc.Value.ElementType, tc.Value[indices]),
TupleConst tc => tc.Value[(int)indices.Single()].AsTensor(),
ShapeConst sc => new DimensionConst(sc.Value[(int)indices.Single()]),
IR.Tuple t => t.Fields[(int)indices.Single()],
Call { Target: Concat { Axis: 0 } } c when indices.Length == 1 => c[Concat.Input][indices[0]][0],
Call { Target: Reshape } c when c[Reshape.Shape] is TensorConst tc && tc.Value.Length == 1 && tc.Value.ToScalar<long>() == 1 => c[Reshape.Input],
Expand Down
14 changes: 14 additions & 0 deletions src/Nncase.Core/IR/ExprCloner.g.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,20 @@ protected override Expr VisitLeafTensorConst(TensorConst expr, TContext context)
);
}

/// <inheritdoc />
protected override Expr VisitLeafShapeConst(ShapeConst expr, TContext context)
{
return expr.With(
);
}

/// <inheritdoc />
protected override Expr VisitLeafDimensionConst(DimensionConst expr, TContext context)
{
return expr.With(
);
}

/// <inheritdoc />
protected override Expr VisitLeafTuple(IR.Tuple expr, TContext context)
{
Expand Down
24 changes: 24 additions & 0 deletions src/Nncase.Core/IR/ExprFunctor.g.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ public partial class ExprFunctor<TExprResult, TTypeResult, TContext>
/// </summary>
internal protected virtual TExprResult VisitTensorConst(TensorConst expr, TContext context) => VisitConst(expr, context);

/// <summary>
/// Visit <see cref="ShapeConst"/>.
/// </summary>
internal protected virtual TExprResult VisitShapeConst(ShapeConst expr, TContext context) => VisitConst(expr, context);

/// <summary>
/// Visit <see cref="DimensionConst"/>.
/// </summary>
internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr, TContext context) => VisitConst(expr, context);

/// <summary>
/// Visit <see cref="IR.Tuple"/>.
/// </summary>
Expand Down Expand Up @@ -305,6 +315,20 @@ public partial class ExprFunctor<TExprResult, TTypeResult>
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr);
/// <summary>
/// Visit <see cref="ShapeConst"/>.
/// </summary>
internal protected virtual TExprResult VisitShapeConst(ShapeConst expr) => base.VisitShapeConst(expr, default);

/// <inheritdoc/>
internal protected sealed override TExprResult VisitShapeConst(ShapeConst expr, Unit context) => VisitShapeConst(expr);
/// <summary>
/// Visit <see cref="DimensionConst"/>.
/// </summary>
internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr) => base.VisitDimensionConst(expr, default);

/// <inheritdoc/>
internal protected sealed override TExprResult VisitDimensionConst(DimensionConst expr, Unit context) => VisitDimensionConst(expr);
/// <summary>
/// Visit <see cref="IR.Tuple"/>.
/// </summary>
internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default);
Expand Down
38 changes: 38 additions & 0 deletions src/Nncase.Core/IR/ExprRewriter.g.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ protected sealed override Expr VisitLeafTensorConst(TensorConst expr, TContext c
return RewriteLeafTensorConst(expr, context);
}

/// <inheritdoc/>
protected sealed override Expr VisitLeafShapeConst(ShapeConst expr, TContext context)
{
return RewriteLeafShapeConst(expr, context);
}

/// <inheritdoc/>
protected sealed override Expr VisitLeafDimensionConst(DimensionConst expr, TContext context)
{
return RewriteLeafDimensionConst(expr, context);
}

/// <inheritdoc/>
protected sealed override Expr VisitLeafTuple(IR.Tuple expr, TContext context)
{
Expand Down Expand Up @@ -320,6 +332,16 @@ protected sealed override Expr VisitLeafBufferOf(Buffers.BufferOf expr, TContext
/// </summary>
protected virtual Expr RewriteLeafTensorConst(TensorConst expr, TContext context) => RewriteLeafConst(expr, context);

/// <summary>
/// Rewrite leaf <see cref="ShapeConst"/>.
/// </summary>
protected virtual Expr RewriteLeafShapeConst(ShapeConst expr, TContext context) => RewriteLeafConst(expr, context);

/// <summary>
/// Rewrite leaf <see cref="DimensionConst"/>.
/// </summary>
protected virtual Expr RewriteLeafDimensionConst(DimensionConst expr, TContext context) => RewriteLeafConst(expr, context);

/// <summary>
/// Rewrite leaf <see cref="IR.Tuple"/>.
/// </summary>
Expand Down Expand Up @@ -567,6 +589,22 @@ public partial class ExprRewriter
/// <inheritdoc />
protected sealed override Expr RewriteLeafTensorConst(TensorConst expr, Unit context) => RewriteLeafTensorConst(expr);

/// <summary>
/// Rewrite leaf <see cref="ShapeConst"/>.
/// </summary>
protected virtual Expr RewriteLeafShapeConst(ShapeConst expr) => RewriteLeafConst(expr);

/// <inheritdoc />
protected sealed override Expr RewriteLeafShapeConst(ShapeConst expr, Unit context) => RewriteLeafShapeConst(expr);

/// <summary>
/// Rewrite leaf <see cref="DimensionConst"/>.
/// </summary>
protected virtual Expr RewriteLeafDimensionConst(DimensionConst expr) => RewriteLeafConst(expr);

/// <inheritdoc />
protected sealed override Expr RewriteLeafDimensionConst(DimensionConst expr, Unit context) => RewriteLeafDimensionConst(expr);

/// <summary>
/// Rewrite leaf <see cref="IR.Tuple"/>.
/// </summary>
Expand Down
54 changes: 54 additions & 0 deletions src/Nncase.Core/IR/ExprVisitor.g.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,20 @@ protected internal override TExprResult VisitTensorConst(TensorConst expr, TCont
return VisitLeafTensorConst(expr, context);
}

/// <inheritdoc />
protected internal override TExprResult VisitShapeConst(ShapeConst expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafShapeConst(expr, context);
}

/// <inheritdoc />
protected internal override TExprResult VisitDimensionConst(DimensionConst expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafDimensionConst(expr, context);
}

/// <inheritdoc />
protected internal override TExprResult VisitTuple(IR.Tuple expr, TContext context)
{
Expand Down Expand Up @@ -350,6 +364,16 @@ protected internal override TExprResult VisitBufferOf(Buffers.BufferOf expr, TCo
/// </summary>
protected virtual TExprResult VisitLeafTensorConst(TensorConst expr, TContext context) => VisitLeafConst(expr, context);

/// <summary>
/// Visit leaf <see cref="ShapeConst"/>.
/// </summary>
protected virtual TExprResult VisitLeafShapeConst(ShapeConst expr, TContext context) => VisitLeafConst(expr, context);

/// <summary>
/// Visit leaf <see cref="DimensionConst"/>.
/// </summary>
protected virtual TExprResult VisitLeafDimensionConst(DimensionConst expr, TContext context) => VisitLeafConst(expr, context);

/// <summary>
/// Visit leaf <see cref="IR.Tuple"/>.
/// </summary>
Expand Down Expand Up @@ -573,6 +597,20 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr);
/// <summary>
/// Visit <see cref="ShapeConst"/>.
/// </summary>
internal protected virtual TExprResult VisitShapeConst(ShapeConst expr) => base.VisitShapeConst(expr, default);

/// <inheritdoc/>
internal protected sealed override TExprResult VisitShapeConst(ShapeConst expr, Unit context) => VisitShapeConst(expr);
/// <summary>
/// Visit <see cref="DimensionConst"/>.
/// </summary>
internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr) => base.VisitDimensionConst(expr, default);

/// <inheritdoc/>
internal protected sealed override TExprResult VisitDimensionConst(DimensionConst expr, Unit context) => VisitDimensionConst(expr);
/// <summary>
/// Visit <see cref="IR.Tuple"/>.
/// </summary>
internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default);
Expand Down Expand Up @@ -863,6 +901,22 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr);

/// <summary>
/// Visit leaf <see cref="ShapeConst"/>.
/// </summary>
protected virtual TExprResult VisitLeafShapeConst(ShapeConst expr) => base.VisitLeafShapeConst(expr, default);

/// <inheritdoc/>
protected sealed override TExprResult VisitLeafShapeConst(ShapeConst expr, Unit context) => VisitLeafShapeConst(expr);

/// <summary>
/// Visit leaf <see cref="DimensionConst"/>.
/// </summary>
protected virtual TExprResult VisitLeafDimensionConst(DimensionConst expr) => base.VisitLeafDimensionConst(expr, default);

/// <inheritdoc/>
protected sealed override TExprResult VisitLeafDimensionConst(DimensionConst expr, Unit context) => VisitLeafDimensionConst(expr);

/// <summary>
/// Visit leaf <see cref="IR.Tuple"/>.
/// </summary>
Expand Down
2 changes: 2 additions & 0 deletions src/Nncase.Core/IR/IRList.csv
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ None,true,false,Default,,
Op,true,false,Default,,
PrimFunctionWrapper,true,true,BaseFunction,,Target
TensorConst,true,false,Const,,
ShapeConst,true,false,Const,,
DimensionConst,true,false,Const,,
Tuple,true,false,Default,IR.,@Fields
TupleConst,true,false,Const,,
MemSpan,true,false,Default,TIR.,Start;Size;
Expand Down
82 changes: 82 additions & 0 deletions src/Nncase.Core/IR/ShapeConst.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Nncase.IR;

/// <summary>
/// Constant of shape.
/// </summary>
public sealed class ShapeConst : Const, IEquatable<ShapeConst?>
{
public ShapeConst(Shape shape)
: base(new TensorType(DataTypes.Int64, new[] { shape.Rank }))
{
Value = shape;
}

public Shape Value { get; }

/// <inheritdoc/>
public override string ToString()
{
return Value.ToString();
}

/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitShapeConst(this, context);

public ShapeConst With(Shape? value = null)
{
return new ShapeConst(value ?? Value);
}

public bool Equals(ShapeConst? other) => other is ShapeConst o && Value.Equals(o.Value);

public override bool Equals(object? obj)
{
return Equals(obj as ShapeConst);
}
}

/// <summary>
/// Constant of tensor.
/// </summary>
public sealed class DimensionConst : Const, IEquatable<DimensionConst?>
{
public DimensionConst(Dimension value)
: base(new TensorType(DataTypes.Int64, Shape.Scalar))
{
Value = value;
}

public Dimension Value { get; }

/// <inheritdoc/>
public override string ToString()
{
return Value.ToString();
}

/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitDimensionConst(this, context);

public DimensionConst With(Dimension? value = null)
{
return new DimensionConst(value ?? Value);
}

public bool Equals(DimensionConst? other) => other is DimensionConst o && Value.Equals(o.Value);

public override bool Equals(object? obj)
{
return Equals(obj as DimensionConst);
}
}
Loading

0 comments on commit 1ca0154

Please sign in to comment.