diff --git a/src/Nncase.Core/IR/Const.cs b/src/Nncase.Core/IR/Const.cs index fa761ab306..071631ea1f 100644 --- a/src/Nncase.Core/IR/Const.cs +++ b/src/Nncase.Core/IR/Const.cs @@ -148,16 +148,20 @@ public static TensorConst FromTensor(Tensor tensor) /// Created constant expression. public static Const FromValue(IValue value) { - if (value is TensorValue tv) + switch (value) { - return tv.Type is DistributedType distributedType - ? new TensorConst(tv.AsTensor(), distributedType.NdSBP, distributedType.Placement) - : new TensorConst(tv.AsTensor()); - } - else - { - var tpv = (TupleValue)value; - return new TupleConst(tpv); + case TensorValue tv: + return tv.Type is DistributedType distributedType + ? new TensorConst(tv.AsTensor(), distributedType.NdSBP, distributedType.Placement) + : new TensorConst(tv.AsTensor()); + case TupleValue tpv: + return new TupleConst(tpv); + case ShapeValue sv: + return new ShapeConst(sv.Dimensions.ToArray()); + case DimensionValue dv: + return new DimensionConst(dv.Dimension); + default: + throw new ArgumentOutOfRangeException(nameof(value)); } } } diff --git a/src/Nncase.Core/IR/Dimension.cs b/src/Nncase.Core/IR/Dimension.cs index d527464393..d818fab574 100644 --- a/src/Nncase.Core/IR/Dimension.cs +++ b/src/Nncase.Core/IR/Dimension.cs @@ -114,7 +114,11 @@ public long FixedValue /// Convert to a expression. /// /// Dimension value. - public static implicit operator Dimension(Expr value) => new(value); + public static implicit operator Dimension(Expr value) => value switch + { + DimensionConst dc => dc.Value, + _ => new(value), + }; public static bool operator ==(Dimension left, Dimension right) { diff --git a/src/Nncase.Core/IR/Expr.Operators.cs b/src/Nncase.Core/IR/Expr.Operators.cs index 769c6da15c..14409ce2eb 100644 --- a/src/Nncase.Core/IR/Expr.Operators.cs +++ b/src/Nncase.Core/IR/Expr.Operators.cs @@ -31,6 +31,7 @@ public partial class Expr { TensorConst tc => Tensor.FromScalar(tc.Value.ElementType, tc.Value[indices]), TupleConst tc => tc.Value[(int)indices.Single()].AsTensor(), + ShapeConst sc => new DimensionConst(sc.Value[(int)indices.Single()]), IR.Tuple t => t.Fields[(int)indices.Single()], Call { Target: Concat { Axis: 0 } } c when indices.Length == 1 => c[Concat.Input][indices[0]][0], Call { Target: Reshape } c when c[Reshape.Shape] is TensorConst tc && tc.Value.Length == 1 && tc.Value.ToScalar() == 1 => c[Reshape.Input], diff --git a/src/Nncase.Core/IR/ExprCloner.g.cs b/src/Nncase.Core/IR/ExprCloner.g.cs index e9673f0f47..902c3f4183 100644 --- a/src/Nncase.Core/IR/ExprCloner.g.cs +++ b/src/Nncase.Core/IR/ExprCloner.g.cs @@ -104,6 +104,20 @@ protected override Expr VisitLeafTensorConst(TensorConst expr, TContext context) ); } + /// + protected override Expr VisitLeafShapeConst(ShapeConst expr, TContext context) + { + return expr.With( + ); + } + + /// + protected override Expr VisitLeafDimensionConst(DimensionConst expr, TContext context) + { + return expr.With( + ); + } + /// protected override Expr VisitLeafTuple(IR.Tuple expr, TContext context) { diff --git a/src/Nncase.Core/IR/ExprFunctor.g.cs b/src/Nncase.Core/IR/ExprFunctor.g.cs index 15ca14c5e3..b26a670a3b 100644 --- a/src/Nncase.Core/IR/ExprFunctor.g.cs +++ b/src/Nncase.Core/IR/ExprFunctor.g.cs @@ -68,6 +68,16 @@ public partial class ExprFunctor /// internal protected virtual TExprResult VisitTensorConst(TensorConst expr, TContext context) => VisitConst(expr, context); + /// + /// Visit . + /// + internal protected virtual TExprResult VisitShapeConst(ShapeConst expr, TContext context) => VisitConst(expr, context); + + /// + /// Visit . + /// + internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr, TContext context) => VisitConst(expr, context); + /// /// Visit . /// @@ -305,6 +315,20 @@ public partial class ExprFunctor /// internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr); /// + /// Visit . + /// + internal protected virtual TExprResult VisitShapeConst(ShapeConst expr) => base.VisitShapeConst(expr, default); + + /// + internal protected sealed override TExprResult VisitShapeConst(ShapeConst expr, Unit context) => VisitShapeConst(expr); + /// + /// Visit . + /// + internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr) => base.VisitDimensionConst(expr, default); + + /// + internal protected sealed override TExprResult VisitDimensionConst(DimensionConst expr, Unit context) => VisitDimensionConst(expr); + /// /// Visit . /// internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default); diff --git a/src/Nncase.Core/IR/ExprRewriter.g.cs b/src/Nncase.Core/IR/ExprRewriter.g.cs index 11a6754b79..a41f4a5d1b 100644 --- a/src/Nncase.Core/IR/ExprRewriter.g.cs +++ b/src/Nncase.Core/IR/ExprRewriter.g.cs @@ -79,6 +79,18 @@ protected sealed override Expr VisitLeafTensorConst(TensorConst expr, TContext c return RewriteLeafTensorConst(expr, context); } + /// + protected sealed override Expr VisitLeafShapeConst(ShapeConst expr, TContext context) + { + return RewriteLeafShapeConst(expr, context); + } + + /// + protected sealed override Expr VisitLeafDimensionConst(DimensionConst expr, TContext context) + { + return RewriteLeafDimensionConst(expr, context); + } + /// protected sealed override Expr VisitLeafTuple(IR.Tuple expr, TContext context) { @@ -320,6 +332,16 @@ protected sealed override Expr VisitLeafBufferOf(Buffers.BufferOf expr, TContext /// protected virtual Expr RewriteLeafTensorConst(TensorConst expr, TContext context) => RewriteLeafConst(expr, context); + /// + /// Rewrite leaf . + /// + protected virtual Expr RewriteLeafShapeConst(ShapeConst expr, TContext context) => RewriteLeafConst(expr, context); + + /// + /// Rewrite leaf . + /// + protected virtual Expr RewriteLeafDimensionConst(DimensionConst expr, TContext context) => RewriteLeafConst(expr, context); + /// /// Rewrite leaf . /// @@ -567,6 +589,22 @@ public partial class ExprRewriter /// protected sealed override Expr RewriteLeafTensorConst(TensorConst expr, Unit context) => RewriteLeafTensorConst(expr); + /// + /// Rewrite leaf . + /// + protected virtual Expr RewriteLeafShapeConst(ShapeConst expr) => RewriteLeafConst(expr); + + /// + protected sealed override Expr RewriteLeafShapeConst(ShapeConst expr, Unit context) => RewriteLeafShapeConst(expr); + + /// + /// Rewrite leaf . + /// + protected virtual Expr RewriteLeafDimensionConst(DimensionConst expr) => RewriteLeafConst(expr); + + /// + protected sealed override Expr RewriteLeafDimensionConst(DimensionConst expr, Unit context) => RewriteLeafDimensionConst(expr); + /// /// Rewrite leaf . /// diff --git a/src/Nncase.Core/IR/ExprVisitor.g.cs b/src/Nncase.Core/IR/ExprVisitor.g.cs index a4dfbce2ce..32d702eb7d 100644 --- a/src/Nncase.Core/IR/ExprVisitor.g.cs +++ b/src/Nncase.Core/IR/ExprVisitor.g.cs @@ -88,6 +88,20 @@ protected internal override TExprResult VisitTensorConst(TensorConst expr, TCont return VisitLeafTensorConst(expr, context); } + /// + protected internal override TExprResult VisitShapeConst(ShapeConst expr, TContext context) + { + VisitOperands(expr, context); + return VisitLeafShapeConst(expr, context); + } + + /// + protected internal override TExprResult VisitDimensionConst(DimensionConst expr, TContext context) + { + VisitOperands(expr, context); + return VisitLeafDimensionConst(expr, context); + } + /// protected internal override TExprResult VisitTuple(IR.Tuple expr, TContext context) { @@ -350,6 +364,16 @@ protected internal override TExprResult VisitBufferOf(Buffers.BufferOf expr, TCo /// protected virtual TExprResult VisitLeafTensorConst(TensorConst expr, TContext context) => VisitLeafConst(expr, context); + /// + /// Visit leaf . + /// + protected virtual TExprResult VisitLeafShapeConst(ShapeConst expr, TContext context) => VisitLeafConst(expr, context); + + /// + /// Visit leaf . + /// + protected virtual TExprResult VisitLeafDimensionConst(DimensionConst expr, TContext context) => VisitLeafConst(expr, context); + /// /// Visit leaf . /// @@ -573,6 +597,20 @@ public partial class ExprVisitor /// internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr); /// + /// Visit . + /// + internal protected virtual TExprResult VisitShapeConst(ShapeConst expr) => base.VisitShapeConst(expr, default); + + /// + internal protected sealed override TExprResult VisitShapeConst(ShapeConst expr, Unit context) => VisitShapeConst(expr); + /// + /// Visit . + /// + internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr) => base.VisitDimensionConst(expr, default); + + /// + internal protected sealed override TExprResult VisitDimensionConst(DimensionConst expr, Unit context) => VisitDimensionConst(expr); + /// /// Visit . /// internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default); @@ -863,6 +901,22 @@ public partial class ExprVisitor /// protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr); + /// + /// Visit leaf . + /// + protected virtual TExprResult VisitLeafShapeConst(ShapeConst expr) => base.VisitLeafShapeConst(expr, default); + + /// + protected sealed override TExprResult VisitLeafShapeConst(ShapeConst expr, Unit context) => VisitLeafShapeConst(expr); + + /// + /// Visit leaf . + /// + protected virtual TExprResult VisitLeafDimensionConst(DimensionConst expr) => base.VisitLeafDimensionConst(expr, default); + + /// + protected sealed override TExprResult VisitLeafDimensionConst(DimensionConst expr, Unit context) => VisitLeafDimensionConst(expr); + /// /// Visit leaf . /// diff --git a/src/Nncase.Core/IR/IRList.csv b/src/Nncase.Core/IR/IRList.csv index dbca71c4e3..253041d488 100644 --- a/src/Nncase.Core/IR/IRList.csv +++ b/src/Nncase.Core/IR/IRList.csv @@ -9,6 +9,8 @@ None,true,false,Default,, Op,true,false,Default,, PrimFunctionWrapper,true,true,BaseFunction,,Target TensorConst,true,false,Const,, +ShapeConst,true,false,Const,, +DimensionConst,true,false,Const,, Tuple,true,false,Default,IR.,@Fields TupleConst,true,false,Const,, MemSpan,true,false,Default,TIR.,Start;Size; diff --git a/src/Nncase.Core/IR/ShapeConst.cs b/src/Nncase.Core/IR/ShapeConst.cs new file mode 100644 index 0000000000..e1f2d4abcf --- /dev/null +++ b/src/Nncase.Core/IR/ShapeConst.cs @@ -0,0 +1,82 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Nncase.IR; + +/// +/// Constant of shape. +/// +public sealed class ShapeConst : Const, IEquatable +{ + public ShapeConst(Shape shape) + : base(new TensorType(DataTypes.Int64, new[] { shape.Rank })) + { + Value = shape; + } + + public Shape Value { get; } + + /// + public override string ToString() + { + return Value.ToString(); + } + + /// + public override TExprResult Accept(ExprFunctor functor, TContext context) + => functor.VisitShapeConst(this, context); + + public ShapeConst With(Shape? value = null) + { + return new ShapeConst(value ?? Value); + } + + public bool Equals(ShapeConst? other) => other is ShapeConst o && Value.Equals(o.Value); + + public override bool Equals(object? obj) + { + return Equals(obj as ShapeConst); + } +} + +/// +/// Constant of tensor. +/// +public sealed class DimensionConst : Const, IEquatable +{ + public DimensionConst(Dimension value) + : base(new TensorType(DataTypes.Int64, Shape.Scalar)) + { + Value = value; + } + + public Dimension Value { get; } + + /// + public override string ToString() + { + return Value.ToString(); + } + + /// + public override TExprResult Accept(ExprFunctor functor, TContext context) + => functor.VisitDimensionConst(this, context); + + public DimensionConst With(Dimension? value = null) + { + return new DimensionConst(value ?? Value); + } + + public bool Equals(DimensionConst? other) => other is DimensionConst o && Value.Equals(o.Value); + + public override bool Equals(object? obj) + { + return Equals(obj as DimensionConst); + } +} diff --git a/src/Nncase.Core/IValue.cs b/src/Nncase.Core/IValue.cs index 9089817fa5..f58cd1272b 100644 --- a/src/Nncase.Core/IValue.cs +++ b/src/Nncase.Core/IValue.cs @@ -92,14 +92,18 @@ public static TupleValue FromTensors(params Tensor[] tensors) /// Created value. public static IValue FromConst(Const @const) { - if (@const is TensorConst tc) + switch (@const) { - return FromTensor(tc.Value); - } - else - { - var tpc = (TupleConst)@const; - return tpc.Value; + case TensorConst tc: + return FromTensor(tc.Value); + case TupleConst tpc: + return tpc.Value; + case ShapeConst spc: + return new ShapeValue(spc.Value.ToArray()); + case DimensionConst dc: + return new DimensionValue(dc.Value); + default: + throw new ArgumentOutOfRangeException(nameof(@const)); } } } @@ -255,6 +259,109 @@ public override string ToString() } } +public sealed class DimensionValue : IValue, IEquatable +{ + private readonly Dimension _value; + + public DimensionValue(Dimension value) + { + _value = value; + } + + public IRType Type => new TensorType(DataTypes.Int64, Shape.Scalar); + + public int Count => 0; + + public Dimension Dimension => _value; + + public IValue this[int index] => throw new NotSupportedException("scalar can't index"); + + public Tensor AsTensor() => throw new NotImplementedException(); + + public Tensor[] AsTensors() => throw new NotImplementedException(); + + public bool Equals(DimensionValue? other) => EqualityComparer.Default.Equals(_value, other?._value); + + public override int GetHashCode() => EqualityComparer.Default.GetHashCode(_value); + + public override bool Equals(object? obj) + { + return Equals(obj as DimensionValue); + } + + public IEnumerator GetEnumerator() + { + yield break; + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); +} + +public sealed class ShapeValue : IValue, IEquatable +{ + private readonly Dimension[] _values; + + public ShapeValue(params Dimension[] values) + { + _values = values; + } + + public ShapeValue(IEnumerable values) + { + var dims = new List(); + foreach (var item in values) + { + if (item is DimensionValue dimValue) + { + dims.Add(dimValue.Dimension); + } + else if (item is ShapeValue shapeValue) + { + dims.AddRange(shapeValue._values); + } + else + { + throw new NotSupportedException("only support dimension/shape value for constructor"); + } + } + + _values = dims.ToArray(); + } + + public IRType Type => new TensorType(DataTypes.Int64, new[] { _values.Length }); + + public int Count => _values.Length; + + public Span Dimensions => _values; + + public IValue this[int index] => new DimensionValue(_values[index]); + + public Tensor AsTensor() => throw new NotImplementedException(); + + public Tensor[] AsTensors() => throw new NotImplementedException(); + + public bool Equals(ShapeValue? other) => StructuralComparisons.StructuralEqualityComparer.Equals(_values, other?._values); + + public override int GetHashCode() => StructuralComparisons.StructuralEqualityComparer.GetHashCode(_values); + + public override bool Equals(object? obj) + { + return Equals(obj as ShapeValue); + } + + public IEnumerator GetEnumerator() + { + foreach (var item in _values) + { + yield return new DimensionValue(item); + } + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public override string ToString() => $"[{string.Join(",", _values.AsValueEnumerable().Select(v => v.ToString()))}]"; +} + /// /// Tuple value. /// diff --git a/src/Nncase.Core/Utilities/IRUtility.cs b/src/Nncase.Core/Utilities/IRUtility.cs new file mode 100644 index 0000000000..1ce420031d --- /dev/null +++ b/src/Nncase.Core/Utilities/IRUtility.cs @@ -0,0 +1,131 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +namespace Nncase.Utilities; + +public static class IRUtility +{ + /// + /// find the reshape's shape transform matrix. + /// + /// input shape. + /// new shape. + /// mat [new shape dim, old shpe dim]. + /// bool. + public static bool TryGetShapeMapMatrix(int[] inShape, int[] newShape, out int[,] mat) + { + int ProdIn(int[,] cmat, int i) + { + var prod = 1; + for (int j = 0; j < inShape.Length; j++) + { + var v = cmat[i, j] * inShape[j]; + if (v != 0) + { + prod *= v; + } + } + + return prod; + } + + int ProdOut(int[,] cmat, int j) + { + var prod = 1; + for (int i = 0; i < newShape.Length; i++) + { + var v = cmat[i, j] * newShape[i]; + if (v != 0) + { + prod *= v; + } + } + + return prod; + } + + mat = new int[newShape.Length, inShape.Length]; + int i = 0, j = 0; + var paths = new List<(int, int)>(); + while (i < newShape.Length && j < inShape.Length) + { + if (paths.IndexOf((i, j)) != -1) + { + return false; + } + + mat[i, j] = 1; + paths.Add((i, j)); + var inDim = ProdIn(mat, i); + var outDim = ProdOut(mat, j); + switch (inDim - outDim) + { + case 0: + i++; j++; + break; + case < 0: + j++; + break; + case > 0: + if (inDim % newShape[i] == 0) + { + i++; + } + else + { + mat[i, j] = 0; + j--; + paths.RemoveAt(paths.Count - 1); + } + + break; + } + } + + return i == newShape.Length && j == inShape.Length; + } + + /// + /// convert the mapping matrix as a dictionary. + /// the key is in dim, value is not dim. + /// + /// mat. + /// dict. + public static (Dictionary> Forward, Dictionary> Backward) ShapeMapMatrixAsDict(int[,] mat) + { + var forward = new Dictionary>(); + var backward = new Dictionary>(); + for (int i = 0; i < mat.GetLength(0); i++) + { + for (int j = 0; j < mat.GetLength(1); j++) + { + if (mat[i, j] == 0) + { + continue; + } + + if (!forward.TryGetValue(j, out var l1)) + { + l1 = new() { i }; + forward.Add(j, l1); + } + else + { + l1.Add(i); + } + + if (!backward.TryGetValue(i, out var l2)) + { + l2 = new() { j }; + backward.Add(i, l2); + } + else + { + l2.Add(j); + } + } + } + + return (forward, backward); + } +} diff --git a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs index d04de6aec3..fad4ab909e 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs @@ -365,6 +365,8 @@ protected override string VisitConst(Const expr) { TensorConst tc => tc.Value.Shape.Size <= 8 ? tc.Value.GetArrayString(false) : string.Empty, TupleConst => string.Empty, + ShapeConst sc => VisitShape(sc.Value), + DimensionConst dc => VisitDimension(dc.Value), _ => throw new ArgumentOutOfRangeException(nameof(expr)), }; valueStr = valueStr != string.Empty ? " : " + valueStr : string.Empty; diff --git a/src/Nncase.EGraph/CostModel/EGraphCostEvaluator.cs b/src/Nncase.EGraph/CostModel/EGraphCostEvaluator.cs index 0774e2abc6..7a04dfd944 100644 --- a/src/Nncase.EGraph/CostModel/EGraphCostEvaluator.cs +++ b/src/Nncase.EGraph/CostModel/EGraphCostEvaluator.cs @@ -119,8 +119,7 @@ private void TryEvaluateAll() return enode.Expr switch { Var var => Visit(enode, var), - TensorConst con => Visit(enode, con), - TupleConst con => Visit(enode, con), + Const con => Visit(enode, con), Function func => Visit(enode, func), Call call => Visit(enode, call, returnType), IR.Tuple tuple => Visit(enode, tuple), @@ -138,7 +137,7 @@ private Cost Visit(ENode enode, Var var) return VisitLeaf(enode, () => Cost.Zero); } - private Cost Visit(ENode enode, TensorConst tc) + private Cost Visit(ENode enode, Const @const) { return VisitLeaf(enode, () => Cost.Zero); } @@ -153,11 +152,6 @@ private Cost Visit(ENode enode, Op op) return Visit(enode, costs => Cost.Zero); } - private Cost? Visit(ENode enode, TupleConst tc) - { - return Visit(enode, costs => costs.Sum()); - } - private Cost? Visit(ENode enode, IR.Tuple tuple) { return Visit(enode, costs => _accumulate ? costs.Sum() : Cost.Zero); diff --git a/src/Nncase.EGraph/Passes/EGraphExtractor.cs b/src/Nncase.EGraph/Passes/EGraphExtractor.cs index cbe4249199..42c283042a 100644 --- a/src/Nncase.EGraph/Passes/EGraphExtractor.cs +++ b/src/Nncase.EGraph/Passes/EGraphExtractor.cs @@ -621,6 +621,12 @@ public Expr Visit(EClass root) case Marker mk: expr = mk.With(target: children[0], attribute: children[1], metadata: mk.Metadata); break; + case ShapeConst sc: + expr = sc; + break; + case DimensionConst dc: + expr = dc; + break; default: throw new NotSupportedException(enode.Expr.GetType().Name); } diff --git a/src/Nncase.EGraph/Passes/EGraphPrinter.cs b/src/Nncase.EGraph/Passes/EGraphPrinter.cs index 17f9cee6fd..2b3c13b29d 100644 --- a/src/Nncase.EGraph/Passes/EGraphPrinter.cs +++ b/src/Nncase.EGraph/Passes/EGraphPrinter.cs @@ -258,6 +258,8 @@ protected override string VisitConst(Const expr) { TensorConst tc => tc.Value.Shape.Size <= 8 ? tc.Value.GetArrayString(false) : string.Empty, TupleConst => string.Empty, + ShapeConst sc => VisitShape(sc.Value), + DimensionConst dc => VisitDimension(dc.Value), _ => throw new ArgumentOutOfRangeException(nameof(expr)), }; valueStr = valueStr != string.Empty ? " : " + valueStr : string.Empty; @@ -282,5 +284,23 @@ protected override string VisitOp(Op expr) protected override string VisitTuple(IR.Tuple expr) => "Tuple"; protected override string VisitNone(None expr) => "None"; + + private string VisitShape(Shape shape) => + shape.Kind switch + { + ShapeKind.Invalid => "Invalid", + ShapeKind.Unranked => "Unranked", + _ => $"[{string.Join(',', shape.Select(VisitDimension))}]", + }; + + private string VisitDimension(Dimension dimension) => + dimension.Kind switch + { + DimensionKind.Any => "any", + DimensionKind.Fixed => dimension.FixedValue.ToString(), + DimensionKind.Unknown => dimension.Value is Var var ? $"%{var.Name}" : "?", + _ => throw new NotSupportedException(dimension.Kind.ToString()), + }; + } } diff --git a/src/Nncase.Evaluator/Math/Binary.cs b/src/Nncase.Evaluator/Math/Binary.cs index 4343cd168b..3d78be5d7b 100755 --- a/src/Nncase.Evaluator/Math/Binary.cs +++ b/src/Nncase.Evaluator/Math/Binary.cs @@ -79,45 +79,61 @@ public static IRType CheckSBP(BinaryOp op, TensorType tensorType, DistributedTyp /// public IValue Visit(IEvaluateContext context, Binary binary) { - var lhs = context.GetArgumentValueAsTensor(binary, Binary.Lhs); - var rhs = context.GetArgumentValueAsTensor(binary, Binary.Rhs); - if (lhs.Shape.IsScalar && rhs.Shape.IsScalar) + var lhsValue = context.GetArgumentValue(binary, Binary.Lhs); + var rhsValue = context.GetArgumentValue(binary, Binary.Rhs); + switch (lhsValue, rhsValue) { - if (lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if (lhs.ElementType == DataTypes.Int64 && rhs.ElementType == DataTypes.Int64) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if (lhs.ElementType == DataTypes.Float32 && rhs.ElementType == DataTypes.Float32) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if (lhs.ElementType == DataTypes.Boolean && rhs.ElementType == DataTypes.Boolean) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if (lhs.ElementType == DataTypes.UInt32 && rhs.ElementType == DataTypes.UInt32) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if (lhs.ElementType is PointerType && (rhs.ElementType == DataTypes.UInt32 || rhs.ElementType == DataTypes.UInt64)) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if ((lhs.ElementType == DataTypes.UInt32 || lhs.ElementType == DataTypes.UInt64) && rhs.ElementType is PointerType) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else - { + case (TensorValue lhsTV, TensorValue rhsTV): + var lhs = lhsTV.AsTensor(); + var rhs = rhsTV.AsTensor(); + if (lhs.Shape.IsScalar && rhs.Shape.IsScalar) + { + if (lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if (lhs.ElementType == DataTypes.Int64 && rhs.ElementType == DataTypes.Int64) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if (lhs.ElementType == DataTypes.Float32 && rhs.ElementType == DataTypes.Float32) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if (lhs.ElementType == DataTypes.Boolean && rhs.ElementType == DataTypes.Boolean) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if (lhs.ElementType == DataTypes.UInt32 && rhs.ElementType == DataTypes.UInt32) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if (lhs.ElementType is PointerType && (rhs.ElementType == DataTypes.UInt32 || rhs.ElementType == DataTypes.UInt64)) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if ((lhs.ElementType == DataTypes.UInt32 || lhs.ElementType == DataTypes.UInt64) && rhs.ElementType is PointerType) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else + { + return Ort_compute(binary, lhs, rhs); + } + } + return Ort_compute(binary, lhs, rhs); - } + case (DimensionValue { Dimension: { IsFixed: true } } lhsDV, TensorValue { Type: TensorType { Shape: { IsScalar: true } } } rhsTV): + return Value.FromTensor(Compute(binary.BinaryOp, lhsDV.Dimension.FixedValue, rhsTV.AsTensor().ToScalar())); + case (TensorValue { Type: TensorType { Shape: { IsScalar: true } } } lhsTV, DimensionValue { Dimension: { IsFixed: true } } rhsDV): + return Value.FromTensor(Compute(binary.BinaryOp, lhsTV.AsTensor().ToScalar(), rhsDV.Dimension.FixedValue)); + case (DimensionValue { Dimension: { IsFixed: true } } lhsDV, DimensionValue { Dimension: { IsFixed: true } } rhsDV): + return Value.FromTensor(Compute(binary.BinaryOp, lhsDV.Dimension.FixedValue, rhsDV.Dimension.FixedValue)); + default: + break; } - return Ort_compute(binary, lhs, rhs); + throw new NotSupportedException($"binary notsupport {lhsValue} {rhsValue}"); } /// diff --git a/src/Nncase.Evaluator/Math/Compare.cs b/src/Nncase.Evaluator/Math/Compare.cs index 69322ec184..0d34dc5b4b 100644 --- a/src/Nncase.Evaluator/Math/Compare.cs +++ b/src/Nncase.Evaluator/Math/Compare.cs @@ -2,6 +2,7 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; +using System.Linq; using Nncase.CostModel; using Nncase.IR; using Nncase.IR.Math; @@ -67,25 +68,31 @@ public static IRType CheckSBP(TensorType tensorType, DistributedType a, Distribu /// public IValue Visit(IEvaluateContext context, Compare target) { - var lhs = context.GetArgumentValueAsTensor(target, Compare.Lhs); - var rhs = context.GetArgumentValueAsTensor(target, Compare.Rhs); - if (lhs.Shape.IsScalar && rhs.Shape.IsScalar && lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32) + var lhsValue = context.GetArgumentValue(target, Compare.Lhs); + var rhsValue = context.GetArgumentValue(target, Compare.Rhs); + switch (lhsValue, rhsValue) { - return Value.FromTensor(Tensor.FromScalar(Compute(target.CompareOp, lhs.ToScalar(), rhs.ToScalar()))); + case (TensorValue lhsTV, TensorValue rhsTV): + var lhs = lhsTV.AsTensor(); + var rhs = rhsTV.AsTensor(); + if (lhs.Shape.IsScalar && rhs.Shape.IsScalar && lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32) + { + return Value.FromTensor(Tensor.FromScalar(Compute(target.CompareOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else + { + return Compute(target.CompareOp, lhs.ToOrtTensor(), rhs.ToOrtTensor()); + } + + case (ShapeValue lhsTV, TensorValue rhsTV): + return Value.FromTensor(Tensor.FromArray(lhsTV.Dimensions.ToArray().Zip(rhsTV.AsTensor().ToArray()).Select(p => Compute(target.CompareOp, p.First, p.Second)).ToArray())); + case (TensorValue lhsTV, ShapeValue rhsTV): + return Value.FromTensor(Tensor.FromArray(lhsTV.AsTensor().ToArray().Zip(rhsTV.Dimensions.ToArray()).Select(p => Compute(target.CompareOp, p.First, p.Second)).ToArray())); + default: + break; } - var a = context.GetOrtArgumentValue(target, Compare.Lhs); - var b = context.GetOrtArgumentValue(target, Compare.Rhs); - return target.CompareOp switch - { - CompareOp.Equal => OrtKI.Equal(a, b).ToValue(), - CompareOp.LowerOrEqual => OrtKI.LessOrEqual(a, b).ToValue(), - CompareOp.GreaterOrEqual => OrtKI.GreaterOrEqual(a, b).ToValue(), - CompareOp.GreaterThan => OrtKI.Greater(a, b).ToValue(), - CompareOp.LowerThan => OrtKI.Less(a, b).ToValue(), - CompareOp.NotEqual => OrtKI.Not(OrtKI.Equal(a, b)).ToValue(), - _ => throw new ArgumentOutOfRangeException(target.CompareOp.ToString()), - }; + throw new NotSupportedException(); } /// @@ -149,7 +156,41 @@ public Expr Visit(IShapeEvaluateContext context, Compare target) return ShapeExprUtility.BroadcastShape(lhs, rhs); } - private bool Compute(CompareOp op, int a, int b) => op switch + private bool Compute(CompareOp op, Dimension a, long b) + { + return (a, b) switch + { + (Dimension { IsFixed: true } da, _) => Compute(op, da.FixedValue, b), + (Dimension { IsFixed: false } da, _) => Compute(op, da.Value, b), + }; + } + + private bool Compute(CompareOp op, long a, Dimension b) + { + return (a, b) switch + { + (_, Dimension { IsFixed: true } db) => Compute(op, a, db.FixedValue), + (_, Dimension { IsFixed: false } db) => Compute(op, a, db.Value), + }; + } + + private IValue Compute(CompareOp op, OrtKISharp.Tensor a, OrtKISharp.Tensor b) + { + return op switch + { + CompareOp.Equal => OrtKI.Equal(a, b).ToValue(), + CompareOp.LowerOrEqual => OrtKI.LessOrEqual(a, b).ToValue(), + CompareOp.GreaterOrEqual => OrtKI.GreaterOrEqual(a, b).ToValue(), + CompareOp.GreaterThan => OrtKI.Greater(a, b).ToValue(), + CompareOp.LowerThan => OrtKI.Less(a, b).ToValue(), + CompareOp.NotEqual => OrtKI.Not(OrtKI.Equal(a, b)).ToValue(), + _ => throw new ArgumentOutOfRangeException(op.ToString()), + }; + } + + private bool Compute(CompareOp op, T a, T b) + where T : System.Numerics.IEqualityOperators, System.Numerics.IComparisonOperators + => op switch { CompareOp.Equal => a == b, CompareOp.LowerOrEqual => a <= b, @@ -160,6 +201,20 @@ public Expr Visit(IShapeEvaluateContext context, Compare target) _ => throw new ArgumentOutOfRangeException(nameof(op)), }; + private bool Compute(CompareOp op, Expr a, long b) + => (op, a, b) switch + { + (CompareOp.Equal, Expr, -1) => false, + _ => throw new ArgumentOutOfRangeException(nameof(op)), + }; + + private bool Compute(CompareOp op, long a, Expr b) + => (op, a, b) switch + { + (CompareOp.Equal, -1, Expr) => false, + _ => throw new ArgumentOutOfRangeException(nameof(op)), + }; + private IRType Visit(TensorType lhs, TensorType rhs) { var broadcastType = TypeInference.BroadcastType(lhs, rhs); diff --git a/src/Nncase.Evaluator/Tensors/Concat.cs b/src/Nncase.Evaluator/Tensors/Concat.cs index 0a30397101..3c74ad6755 100644 --- a/src/Nncase.Evaluator/Tensors/Concat.cs +++ b/src/Nncase.Evaluator/Tensors/Concat.cs @@ -23,11 +23,45 @@ public class ConcatEvaluator : IEvaluator, ITypeInferencer, ICos IShapeEvaluator, IMetricEvaluator { /// - public IValue Visit(IEvaluateContext context, Concat cat) + public IValue Visit(IEvaluateContext context, Concat target) { - var inputs = context.GetArgumentValueAsTensors(cat, Concat.Input); - var axis = cat.Axis; - return OrtKI.Concat(inputs.Select(t => t.ToOrtTensor()).ToArray(), axis).ToValue(); + var inputValue = context.GetArgumentValue(target, Concat.Input); + var axis = target.Axis; + switch (inputValue) + { + case TupleValue tpv: + if (tpv.All(v => v is TensorValue)) + { + var inputs = tpv.AsTensors(); + return OrtKI.Concat(inputs.Select(t => t.ToOrtTensor()).ToArray(), axis).ToValue(); + } + else if (tpv.Any(v => v is ShapeValue)) + { + var dims = new List(); + foreach (var fv in tpv) + { + switch (fv) + { + case TensorValue ftv: + dims.Add(new Dimension(ftv.AsTensor().Cast()[axis])); + break; + case ShapeValue fsv: + dims.Add(fsv.Dimensions[axis]); + break; + default: + throw new ArgumentOutOfRangeException(nameof(target), "ShapeValue's field not support"); + } + } + + return new ShapeValue(dims.ToArray()); + } + + break; + default: + break; + } + + throw new ArgumentOutOfRangeException(nameof(target)); } /// diff --git a/src/Nncase.Evaluator/Tensors/Gather.cs b/src/Nncase.Evaluator/Tensors/Gather.cs index 3aebcbd9b4..2d6e9e0965 100644 --- a/src/Nncase.Evaluator/Tensors/Gather.cs +++ b/src/Nncase.Evaluator/Tensors/Gather.cs @@ -22,10 +22,33 @@ public class GatherEvaluator : IEvaluator, ITypeInferencer, ICos /// public IValue Visit(IEvaluateContext context, Gather gather) { - var input = context.GetOrtArgumentValue(gather, Gather.Input); + var inputValue = context.GetArgumentValue(gather, Gather.Input); + var indexValue = context.GetArgumentValue(gather, Gather.Index); var axis = gather.Axis; - var index = context.GetOrtArgumentValue(gather, Gather.Index); - return OrtKI.Gather(input, index, axis).ToValue(); + switch (inputValue, indexValue) + { + case (_, TensorValue indexTValue): + if (inputValue is TensorValue inputTValue) + { + return OrtKI.Gather(inputTValue.AsTensor().ToOrtTensor(), indexTValue.AsTensor().ToOrtTensor(), axis).ToValue(); + } + else if (inputValue is ShapeValue inputSValue && axis == 0) + { + var indexTensor = indexTValue.AsTensor(); + if (!indexTensor.Shape.IsScalar) + { + throw new NotSupportedException("Gather ShapeConst the index must be scalar!"); + } + + return inputSValue[indexTensor.ToScalar()]; + } + + break; + default: + break; + } + + throw new NotSupportedException(); } /// diff --git a/src/Nncase.Evaluator/Tensors/Reshape.cs b/src/Nncase.Evaluator/Tensors/Reshape.cs index ac28fcbfe3..fb1959dc2d 100644 --- a/src/Nncase.Evaluator/Tensors/Reshape.cs +++ b/src/Nncase.Evaluator/Tensors/Reshape.cs @@ -97,6 +97,7 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i var shapeDims = new Shape(Enumerable.Range(0, rank).Select(i => (Dimension)shape[i]).ToArray()); var outputShape = new Dimension[rank]; + // todo use egraph simplify. var minus1Dim = FixedAndDynamicDimension.Abs(input.Shape.ProdFixedAndDynamic() / shapeDims.ProdFixedAndDynamic()); for (var i = 0; i < rank; i++) { @@ -107,13 +108,21 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i } else { - var outputDim = ShapeExprUtility.If( - Equal(shapeDim.Value, -1L), - (shapeDim, minus1Dim) => minus1Dim, - (shapeDim, minus1Dim) => shapeDim, - shapeDim.Value, - minus1Dim.ToExpr()); - outputShape[i] = outputDim; + switch (shapeDim) + { + case Dimension { Value: Var }: + outputShape[i] = shapeDim; + break; + default: + var outputDim = ShapeExprUtility.If( + Equal(shapeDim.Value, -1L), + (shapeDim, minus1Dim) => minus1Dim, + (shapeDim, minus1Dim) => shapeDim, + shapeDim.Value, + minus1Dim.ToExpr()); + outputShape[i] = outputDim; + break; + } } } diff --git a/src/Nncase.Evaluator/Tensors/Slice.cs b/src/Nncase.Evaluator/Tensors/Slice.cs index 7b16051bf2..b1763790a3 100644 --- a/src/Nncase.Evaluator/Tensors/Slice.cs +++ b/src/Nncase.Evaluator/Tensors/Slice.cs @@ -30,13 +30,55 @@ public class SliceEvaluator : IEvaluator, ITypeInferencer, ICostEv /// public IValue Visit(IEvaluateContext context, Slice sl) { - var input = context.GetOrtArgumentValue(sl, Slice.Input); - var begins = context.GetInt64OrtTensorArgumentValue(sl, Slice.Begins); - var ends = context.GetInt64OrtTensorArgumentValue(sl, Slice.Ends); - var axes = context.GetInt64OrtTensorArgumentValue(sl, Slice.Axes); - var strides = context.GetInt64OrtTensorArgumentValue(sl, Slice.Strides); - var sliced = OrtKI.Slice(input, begins, ends, axes, strides); - return Value.FromTensor(context.CurrentCall.CheckedType is AnyType ? sliced.ToTensor() : sliced.ToTensor(context.CurrentCall.CheckedTensorType)); + var inputValue = context.GetArgumentValue(sl, Slice.Input); + var beginsValue = context.GetArgumentValue(sl, Slice.Begins); + var endsValue = context.GetArgumentValue(sl, Slice.Ends); + var axesValue = context.GetArgumentValue(sl, Slice.Axes); + var stridesValue = context.GetArgumentValue(sl, Slice.Strides); + switch (inputValue, beginsValue, endsValue, axesValue, stridesValue) + { + case (_, TensorValue beginsTValue, TensorValue endsTValue, TensorValue axesTValue, TensorValue stridesTValue): + var beginsTensor = beginsTValue.AsTensor(); + var endsTensor = endsTValue.AsTensor(); + var axesTensor = axesTValue.AsTensor(); + var stridesTensor = stridesTValue.AsTensor(); + if (inputValue is ShapeValue inputSValue && beginsTensor.Shape.Rank == 1 && endsTensor.Shape.Rank == 1 && axesTensor.Shape.Rank == 1 && stridesTensor.Shape.Rank == 1) + { + // var input = inputShapeValue.AsTensor().Cast().ToOrtTensor(); + var begins = beginsTensor.ToScalar(); + var ends = endsTensor.ToScalar(); + var axes = axesTensor.ToScalar(); + if (axes != 0) + { + throw new NotSupportedException("slice ShapeConst Axes != 0"); + } + + var strides = stridesTensor.ToScalar(); + var sliced = new List(); + for (long i = begins; i < ends; i += strides) + { + sliced.Add(inputSValue[checked((int)i)]); + } + + return new ShapeValue(sliced); + } + else if (inputValue is TensorValue inputTValue) + { + var input = inputTValue.AsTensor().Cast().ToOrtTensor(); + var begins = beginsTensor.Cast().ToOrtTensor(); + var ends = endsTensor.Cast().ToOrtTensor(); + var axes = axesTensor.Cast().ToOrtTensor(); + var strides = stridesTensor.Cast().ToOrtTensor(); + var sliced = OrtKI.Slice(input, begins, ends, axes, strides); + return Value.FromTensor(context.CurrentCall.CheckedType is AnyType ? sliced.ToTensor() : sliced.ToTensor(context.CurrentCall.CheckedTensorType)); + } + + break; + default: + break; + } + + throw new NotSupportedException("input value is neither shapevalue or tensorvalue"); } /// diff --git a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs index 76e1c7ad5b..c87c93d0e3 100644 --- a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs +++ b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs @@ -19,9 +19,35 @@ public class UnsqueezeEvaluator : IEvaluator, ITypeInferencer public IValue Visit(IEvaluateContext context, Unsqueeze unSqueeze) { - var input = context.GetOrtArgumentValue(unSqueeze, Unsqueeze.Input); - var axes = context.GetInt64OrtTensorArgumentValue(unSqueeze, Unsqueeze.Dim); - return Value.FromTensor(OrtKI.Unsqueeze(input, axes).ToTensor(context.CurrentCall.CheckedTensorType)); + var inputValue = context.GetArgumentValue(unSqueeze, Unsqueeze.Input); + var axesValue = context.GetArgumentValue(unSqueeze, Unsqueeze.Dim); + + switch (inputValue, axesValue) + { + case (_, TensorValue axesTValue): + var axesTensor = axesTValue.AsTensor(); + if (inputValue is TensorValue inputTValue) + { + var input = inputTValue.AsTensor().ToOrtTensor(); + var axes = axesTensor.Cast().ToOrtTensor(); + return Value.FromTensor(OrtKI.Unsqueeze(input, axes).ToTensor(context.CurrentCall.CheckedTensorType)); + } + else if (inputValue is DimensionValue inputDValue) + { + if (axesTensor.Shape.Rank > 1 || axesTensor.ToScalar() != 0) + { + throw new NotSupportedException("only support scalar dim when input is DimensionValue!"); + } + + return new ShapeValue(new[] { inputDValue }); + } + + break; + default: + break; + } + + throw new NotSupportedException(); } /// diff --git a/src/Nncase.Evaluator/Tensors/Where.cs b/src/Nncase.Evaluator/Tensors/Where.cs index d80c9fcd19..7258d271e4 100644 --- a/src/Nncase.Evaluator/Tensors/Where.cs +++ b/src/Nncase.Evaluator/Tensors/Where.cs @@ -23,24 +23,41 @@ public class WhereEvaluator : IEvaluator, ITypeInferencer, ICostEv /// public IValue Visit(IEvaluateContext context, Where where) { - var xt = context.GetArgumentValueAsTensor(where, Where.X); - var yt = context.GetArgumentValueAsTensor(where, Where.Y); - if (where.IsTfWhere) + var condValue = context.GetArgumentValue(where, Where.Cond); + var xValue = context.GetArgumentValue(where, Where.X); + var yValue = context.GetArgumentValue(where, Where.Y); + switch (condValue, xValue, yValue) { - var condTensor = context.GetArgumentValueAsTensor(where, Where.Cond); - if (condTensor.Rank > 1) - { - throw new NotImplementedException(); - } - - var result = condTensor.Select((b, i) => (b, i)).Where(t => t.b).Select(t => (long)t.i).ToArray(); - return Value.FromTensor(Tensor.From(result, new Shape(result.Length, condTensor.Rank))); + case (TensorValue condTV, TensorValue xTV, _): + if (yValue is TensorValue yTV) + { + if (where.IsTfWhere) + { + var condTensor = condTV.AsTensor().Cast(); + if (condTensor.Rank > 1) + { + throw new NotImplementedException(); + } + + var result = condTensor.Select((b, i) => (b, i)).Where(t => t.b).Select(t => (long)t.i).ToArray(); + return Value.FromTensor(Tensor.From(result, new Shape(result.Length, condTensor.Rank))); + } + else + { + return OrtKI.Where(condTV.AsTensor().ToOrtTensor(), xTV.AsTensor().ToOrtTensor(), yTV.AsTensor().ToOrtTensor()).ToValue(); + } + } + else if (yValue is ShapeValue ySV) + { + return new ShapeValue(condTV.AsTensor().Cast().Zip(xValue.AsTensor().Cast().Zip(ySV.Dimensions.ToArray())).Select(tp => tp.First ? new Dimension(tp.Second.First) : tp.Second.Second).ToArray()); + } + + break; + default: + break; } - var cond = context.GetOrtArgumentValue(where, Where.Cond); - var x = context.GetOrtArgumentValue(where, Where.X); - var y = context.GetOrtArgumentValue(where, Where.Y); - return OrtKI.Where(cond, x, y).ToValue(); + throw new NotSupportedException(); } /// diff --git a/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs b/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs index 1d6f507d3e..bdcf65f8c7 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs @@ -48,10 +48,39 @@ private Const GetReplace(Call call, IReadOnlyList constArgs) public partial class FoldShapeOf : RewriteRule { /// - public override CallPattern Pattern { get; } = IsShapeOf(IsWildcard("wc") with { TypePattern = HasFixedShape() }); + public override CallPattern Pattern { get; } = IsShapeOf(IsWildcard("wc") with { TypePattern = HasRank() }); private Const GetReplace(Expr wc) { - return Const.FromTensor(wc.CheckedShape.ToValueArray().Select(x => (long)x).ToArray()); + if (wc.CheckedShape.IsFixed) + { + return Const.FromTensor(wc.CheckedShape.ToValueArray().Select(x => (long)x).ToArray()); + } + else + { + return new ShapeConst(wc.CheckedShape); + } } } + +/// +/// concat([1, seq_len, 2, 3], 0) => tuple([1,seq_len,2,3]). +/// +// [RuleGenerator] +// public partial class FoldConcatShape : RewriteRule +// { +// /// +// public override CallPattern Pattern { get; } = IsConcat("concat", _ => true, IsTuple("tuple", IsVArgsRepeat("fileds", exprs => +// { +// var patterns = new Pattern[exprs.Length]; +// for (int i = 0; i < exprs.Length; i++) +// { +// patterns[i] = IsWildcard($"input_{i}") with { TypePattern = IsIntegralScalar() }; +// } +// return patterns; +// }))); +// private Const GetReplace(Expr wc) +// { +// return Const.FromTensor(wc.CheckedShape.ToValueArray().Select(x => (long)x).ToArray()); +// } +// } diff --git a/src/Nncase.Tests/Core/UnitTestTypeInfer.cs b/src/Nncase.Tests/Core/UnitTestTypeInfer.cs index 261543a43b..0e290ae13c 100644 --- a/src/Nncase.Tests/Core/UnitTestTypeInfer.cs +++ b/src/Nncase.Tests/Core/UnitTestTypeInfer.cs @@ -326,6 +326,23 @@ public void TestBroadcastInfer2() Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimUnk1, 8192 }), result); } + [Fact] + public void TestReshapeInfer() + { + var dimVar = new Var("seq_len", new TensorType(DataTypes.Int64, Shape.Scalar)); + var dimC = new Dimension(dimVar); + var a = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 128 })); + var constShape = new ShapeConst(new[] { 1, dimC, 2, 64 }); + var reshape = Reshape(a, constShape); + var result = reshape.CheckedType; + Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 2, 64 }), result); + + var b = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 14, 64 })); + var reshapeb = Reshape(b, new ShapeConst(new[] { 1, dimC, -1 })); + var resultb = reshapeb.CheckedType; + Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 896 }), resultb); + } + private void CheckInferShape(Expr expr, params Dimension[] shapeDimensions) { CheckInferShape(expr, new Shape(shapeDimensions));