diff --git a/src/Nncase.Core/IR/Const.cs b/src/Nncase.Core/IR/Const.cs
index fa761ab306..071631ea1f 100644
--- a/src/Nncase.Core/IR/Const.cs
+++ b/src/Nncase.Core/IR/Const.cs
@@ -148,16 +148,20 @@ public static TensorConst FromTensor(Tensor tensor)
/// Created constant expression.
public static Const FromValue(IValue value)
{
- if (value is TensorValue tv)
+ switch (value)
{
- return tv.Type is DistributedType distributedType
- ? new TensorConst(tv.AsTensor(), distributedType.NdSBP, distributedType.Placement)
- : new TensorConst(tv.AsTensor());
- }
- else
- {
- var tpv = (TupleValue)value;
- return new TupleConst(tpv);
+ case TensorValue tv:
+ return tv.Type is DistributedType distributedType
+ ? new TensorConst(tv.AsTensor(), distributedType.NdSBP, distributedType.Placement)
+ : new TensorConst(tv.AsTensor());
+ case TupleValue tpv:
+ return new TupleConst(tpv);
+ case ShapeValue sv:
+ return new ShapeConst(sv.Dimensions.ToArray());
+ case DimensionValue dv:
+ return new DimensionConst(dv.Dimension);
+ default:
+ throw new ArgumentOutOfRangeException(nameof(value));
}
}
}
diff --git a/src/Nncase.Core/IR/Dimension.cs b/src/Nncase.Core/IR/Dimension.cs
index d527464393..d818fab574 100644
--- a/src/Nncase.Core/IR/Dimension.cs
+++ b/src/Nncase.Core/IR/Dimension.cs
@@ -114,7 +114,11 @@ public long FixedValue
/// Convert to a expression.
///
/// Dimension value.
- public static implicit operator Dimension(Expr value) => new(value);
+ public static implicit operator Dimension(Expr value) => value switch
+ {
+ DimensionConst dc => dc.Value,
+ _ => new(value),
+ };
public static bool operator ==(Dimension left, Dimension right)
{
diff --git a/src/Nncase.Core/IR/Expr.Operators.cs b/src/Nncase.Core/IR/Expr.Operators.cs
index 769c6da15c..14409ce2eb 100644
--- a/src/Nncase.Core/IR/Expr.Operators.cs
+++ b/src/Nncase.Core/IR/Expr.Operators.cs
@@ -31,6 +31,7 @@ public partial class Expr
{
TensorConst tc => Tensor.FromScalar(tc.Value.ElementType, tc.Value[indices]),
TupleConst tc => tc.Value[(int)indices.Single()].AsTensor(),
+ ShapeConst sc => new DimensionConst(sc.Value[(int)indices.Single()]),
IR.Tuple t => t.Fields[(int)indices.Single()],
Call { Target: Concat { Axis: 0 } } c when indices.Length == 1 => c[Concat.Input][indices[0]][0],
Call { Target: Reshape } c when c[Reshape.Shape] is TensorConst tc && tc.Value.Length == 1 && tc.Value.ToScalar() == 1 => c[Reshape.Input],
diff --git a/src/Nncase.Core/IR/ExprCloner.g.cs b/src/Nncase.Core/IR/ExprCloner.g.cs
index e9673f0f47..902c3f4183 100644
--- a/src/Nncase.Core/IR/ExprCloner.g.cs
+++ b/src/Nncase.Core/IR/ExprCloner.g.cs
@@ -104,6 +104,20 @@ protected override Expr VisitLeafTensorConst(TensorConst expr, TContext context)
);
}
+ ///
+ protected override Expr VisitLeafShapeConst(ShapeConst expr, TContext context)
+ {
+ return expr.With(
+ );
+ }
+
+ ///
+ protected override Expr VisitLeafDimensionConst(DimensionConst expr, TContext context)
+ {
+ return expr.With(
+ );
+ }
+
///
protected override Expr VisitLeafTuple(IR.Tuple expr, TContext context)
{
diff --git a/src/Nncase.Core/IR/ExprFunctor.g.cs b/src/Nncase.Core/IR/ExprFunctor.g.cs
index 15ca14c5e3..b26a670a3b 100644
--- a/src/Nncase.Core/IR/ExprFunctor.g.cs
+++ b/src/Nncase.Core/IR/ExprFunctor.g.cs
@@ -68,6 +68,16 @@ public partial class ExprFunctor
///
internal protected virtual TExprResult VisitTensorConst(TensorConst expr, TContext context) => VisitConst(expr, context);
+ ///
+ /// Visit .
+ ///
+ internal protected virtual TExprResult VisitShapeConst(ShapeConst expr, TContext context) => VisitConst(expr, context);
+
+ ///
+ /// Visit .
+ ///
+ internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr, TContext context) => VisitConst(expr, context);
+
///
/// Visit .
///
@@ -305,6 +315,20 @@ public partial class ExprFunctor
///
internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr);
///
+ /// Visit .
+ ///
+ internal protected virtual TExprResult VisitShapeConst(ShapeConst expr) => base.VisitShapeConst(expr, default);
+
+ ///
+ internal protected sealed override TExprResult VisitShapeConst(ShapeConst expr, Unit context) => VisitShapeConst(expr);
+ ///
+ /// Visit .
+ ///
+ internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr) => base.VisitDimensionConst(expr, default);
+
+ ///
+ internal protected sealed override TExprResult VisitDimensionConst(DimensionConst expr, Unit context) => VisitDimensionConst(expr);
+ ///
/// Visit .
///
internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default);
diff --git a/src/Nncase.Core/IR/ExprRewriter.g.cs b/src/Nncase.Core/IR/ExprRewriter.g.cs
index 11a6754b79..a41f4a5d1b 100644
--- a/src/Nncase.Core/IR/ExprRewriter.g.cs
+++ b/src/Nncase.Core/IR/ExprRewriter.g.cs
@@ -79,6 +79,18 @@ protected sealed override Expr VisitLeafTensorConst(TensorConst expr, TContext c
return RewriteLeafTensorConst(expr, context);
}
+ ///
+ protected sealed override Expr VisitLeafShapeConst(ShapeConst expr, TContext context)
+ {
+ return RewriteLeafShapeConst(expr, context);
+ }
+
+ ///
+ protected sealed override Expr VisitLeafDimensionConst(DimensionConst expr, TContext context)
+ {
+ return RewriteLeafDimensionConst(expr, context);
+ }
+
///
protected sealed override Expr VisitLeafTuple(IR.Tuple expr, TContext context)
{
@@ -320,6 +332,16 @@ protected sealed override Expr VisitLeafBufferOf(Buffers.BufferOf expr, TContext
///
protected virtual Expr RewriteLeafTensorConst(TensorConst expr, TContext context) => RewriteLeafConst(expr, context);
+ ///
+ /// Rewrite leaf .
+ ///
+ protected virtual Expr RewriteLeafShapeConst(ShapeConst expr, TContext context) => RewriteLeafConst(expr, context);
+
+ ///
+ /// Rewrite leaf .
+ ///
+ protected virtual Expr RewriteLeafDimensionConst(DimensionConst expr, TContext context) => RewriteLeafConst(expr, context);
+
///
/// Rewrite leaf .
///
@@ -567,6 +589,22 @@ public partial class ExprRewriter
///
protected sealed override Expr RewriteLeafTensorConst(TensorConst expr, Unit context) => RewriteLeafTensorConst(expr);
+ ///
+ /// Rewrite leaf .
+ ///
+ protected virtual Expr RewriteLeafShapeConst(ShapeConst expr) => RewriteLeafConst(expr);
+
+ ///
+ protected sealed override Expr RewriteLeafShapeConst(ShapeConst expr, Unit context) => RewriteLeafShapeConst(expr);
+
+ ///
+ /// Rewrite leaf .
+ ///
+ protected virtual Expr RewriteLeafDimensionConst(DimensionConst expr) => RewriteLeafConst(expr);
+
+ ///
+ protected sealed override Expr RewriteLeafDimensionConst(DimensionConst expr, Unit context) => RewriteLeafDimensionConst(expr);
+
///
/// Rewrite leaf .
///
diff --git a/src/Nncase.Core/IR/ExprVisitor.g.cs b/src/Nncase.Core/IR/ExprVisitor.g.cs
index a4dfbce2ce..32d702eb7d 100644
--- a/src/Nncase.Core/IR/ExprVisitor.g.cs
+++ b/src/Nncase.Core/IR/ExprVisitor.g.cs
@@ -88,6 +88,20 @@ protected internal override TExprResult VisitTensorConst(TensorConst expr, TCont
return VisitLeafTensorConst(expr, context);
}
+ ///
+ protected internal override TExprResult VisitShapeConst(ShapeConst expr, TContext context)
+ {
+ VisitOperands(expr, context);
+ return VisitLeafShapeConst(expr, context);
+ }
+
+ ///
+ protected internal override TExprResult VisitDimensionConst(DimensionConst expr, TContext context)
+ {
+ VisitOperands(expr, context);
+ return VisitLeafDimensionConst(expr, context);
+ }
+
///
protected internal override TExprResult VisitTuple(IR.Tuple expr, TContext context)
{
@@ -350,6 +364,16 @@ protected internal override TExprResult VisitBufferOf(Buffers.BufferOf expr, TCo
///
protected virtual TExprResult VisitLeafTensorConst(TensorConst expr, TContext context) => VisitLeafConst(expr, context);
+ ///
+ /// Visit leaf .
+ ///
+ protected virtual TExprResult VisitLeafShapeConst(ShapeConst expr, TContext context) => VisitLeafConst(expr, context);
+
+ ///
+ /// Visit leaf .
+ ///
+ protected virtual TExprResult VisitLeafDimensionConst(DimensionConst expr, TContext context) => VisitLeafConst(expr, context);
+
///
/// Visit leaf .
///
@@ -573,6 +597,20 @@ public partial class ExprVisitor
///
internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr);
///
+ /// Visit .
+ ///
+ internal protected virtual TExprResult VisitShapeConst(ShapeConst expr) => base.VisitShapeConst(expr, default);
+
+ ///
+ internal protected sealed override TExprResult VisitShapeConst(ShapeConst expr, Unit context) => VisitShapeConst(expr);
+ ///
+ /// Visit .
+ ///
+ internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr) => base.VisitDimensionConst(expr, default);
+
+ ///
+ internal protected sealed override TExprResult VisitDimensionConst(DimensionConst expr, Unit context) => VisitDimensionConst(expr);
+ ///
/// Visit .
///
internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default);
@@ -863,6 +901,22 @@ public partial class ExprVisitor
///
protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr);
+ ///
+ /// Visit leaf .
+ ///
+ protected virtual TExprResult VisitLeafShapeConst(ShapeConst expr) => base.VisitLeafShapeConst(expr, default);
+
+ ///
+ protected sealed override TExprResult VisitLeafShapeConst(ShapeConst expr, Unit context) => VisitLeafShapeConst(expr);
+
+ ///
+ /// Visit leaf .
+ ///
+ protected virtual TExprResult VisitLeafDimensionConst(DimensionConst expr) => base.VisitLeafDimensionConst(expr, default);
+
+ ///
+ protected sealed override TExprResult VisitLeafDimensionConst(DimensionConst expr, Unit context) => VisitLeafDimensionConst(expr);
+
///
/// Visit leaf .
///
diff --git a/src/Nncase.Core/IR/IRList.csv b/src/Nncase.Core/IR/IRList.csv
index dbca71c4e3..253041d488 100644
--- a/src/Nncase.Core/IR/IRList.csv
+++ b/src/Nncase.Core/IR/IRList.csv
@@ -9,6 +9,8 @@ None,true,false,Default,,
Op,true,false,Default,,
PrimFunctionWrapper,true,true,BaseFunction,,Target
TensorConst,true,false,Const,,
+ShapeConst,true,false,Const,,
+DimensionConst,true,false,Const,,
Tuple,true,false,Default,IR.,@Fields
TupleConst,true,false,Const,,
MemSpan,true,false,Default,TIR.,Start;Size;
diff --git a/src/Nncase.Core/IR/ShapeConst.cs b/src/Nncase.Core/IR/ShapeConst.cs
new file mode 100644
index 0000000000..e1f2d4abcf
--- /dev/null
+++ b/src/Nncase.Core/IR/ShapeConst.cs
@@ -0,0 +1,82 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Nncase.IR;
+
+///
+/// Constant of shape.
+///
+public sealed class ShapeConst : Const, IEquatable
+{
+ public ShapeConst(Shape shape)
+ : base(new TensorType(DataTypes.Int64, new[] { shape.Rank }))
+ {
+ Value = shape;
+ }
+
+ public Shape Value { get; }
+
+ ///
+ public override string ToString()
+ {
+ return Value.ToString();
+ }
+
+ ///
+ public override TExprResult Accept(ExprFunctor functor, TContext context)
+ => functor.VisitShapeConst(this, context);
+
+ public ShapeConst With(Shape? value = null)
+ {
+ return new ShapeConst(value ?? Value);
+ }
+
+ public bool Equals(ShapeConst? other) => other is ShapeConst o && Value.Equals(o.Value);
+
+ public override bool Equals(object? obj)
+ {
+ return Equals(obj as ShapeConst);
+ }
+}
+
+///
+/// Constant of tensor.
+///
+public sealed class DimensionConst : Const, IEquatable
+{
+ public DimensionConst(Dimension value)
+ : base(new TensorType(DataTypes.Int64, Shape.Scalar))
+ {
+ Value = value;
+ }
+
+ public Dimension Value { get; }
+
+ ///
+ public override string ToString()
+ {
+ return Value.ToString();
+ }
+
+ ///
+ public override TExprResult Accept(ExprFunctor functor, TContext context)
+ => functor.VisitDimensionConst(this, context);
+
+ public DimensionConst With(Dimension? value = null)
+ {
+ return new DimensionConst(value ?? Value);
+ }
+
+ public bool Equals(DimensionConst? other) => other is DimensionConst o && Value.Equals(o.Value);
+
+ public override bool Equals(object? obj)
+ {
+ return Equals(obj as DimensionConst);
+ }
+}
diff --git a/src/Nncase.Core/IValue.cs b/src/Nncase.Core/IValue.cs
index 9089817fa5..f58cd1272b 100644
--- a/src/Nncase.Core/IValue.cs
+++ b/src/Nncase.Core/IValue.cs
@@ -92,14 +92,18 @@ public static TupleValue FromTensors(params Tensor[] tensors)
/// Created value.
public static IValue FromConst(Const @const)
{
- if (@const is TensorConst tc)
+ switch (@const)
{
- return FromTensor(tc.Value);
- }
- else
- {
- var tpc = (TupleConst)@const;
- return tpc.Value;
+ case TensorConst tc:
+ return FromTensor(tc.Value);
+ case TupleConst tpc:
+ return tpc.Value;
+ case ShapeConst spc:
+ return new ShapeValue(spc.Value.ToArray());
+ case DimensionConst dc:
+ return new DimensionValue(dc.Value);
+ default:
+ throw new ArgumentOutOfRangeException(nameof(@const));
}
}
}
@@ -255,6 +259,109 @@ public override string ToString()
}
}
+public sealed class DimensionValue : IValue, IEquatable
+{
+ private readonly Dimension _value;
+
+ public DimensionValue(Dimension value)
+ {
+ _value = value;
+ }
+
+ public IRType Type => new TensorType(DataTypes.Int64, Shape.Scalar);
+
+ public int Count => 0;
+
+ public Dimension Dimension => _value;
+
+ public IValue this[int index] => throw new NotSupportedException("scalar can't index");
+
+ public Tensor AsTensor() => throw new NotImplementedException();
+
+ public Tensor[] AsTensors() => throw new NotImplementedException();
+
+ public bool Equals(DimensionValue? other) => EqualityComparer.Default.Equals(_value, other?._value);
+
+ public override int GetHashCode() => EqualityComparer.Default.GetHashCode(_value);
+
+ public override bool Equals(object? obj)
+ {
+ return Equals(obj as DimensionValue);
+ }
+
+ public IEnumerator GetEnumerator()
+ {
+ yield break;
+ }
+
+ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+}
+
+public sealed class ShapeValue : IValue, IEquatable
+{
+ private readonly Dimension[] _values;
+
+ public ShapeValue(params Dimension[] values)
+ {
+ _values = values;
+ }
+
+ public ShapeValue(IEnumerable values)
+ {
+ var dims = new List();
+ foreach (var item in values)
+ {
+ if (item is DimensionValue dimValue)
+ {
+ dims.Add(dimValue.Dimension);
+ }
+ else if (item is ShapeValue shapeValue)
+ {
+ dims.AddRange(shapeValue._values);
+ }
+ else
+ {
+ throw new NotSupportedException("only support dimension/shape value for constructor");
+ }
+ }
+
+ _values = dims.ToArray();
+ }
+
+ public IRType Type => new TensorType(DataTypes.Int64, new[] { _values.Length });
+
+ public int Count => _values.Length;
+
+ public Span Dimensions => _values;
+
+ public IValue this[int index] => new DimensionValue(_values[index]);
+
+ public Tensor AsTensor() => throw new NotImplementedException();
+
+ public Tensor[] AsTensors() => throw new NotImplementedException();
+
+ public bool Equals(ShapeValue? other) => StructuralComparisons.StructuralEqualityComparer.Equals(_values, other?._values);
+
+ public override int GetHashCode() => StructuralComparisons.StructuralEqualityComparer.GetHashCode(_values);
+
+ public override bool Equals(object? obj)
+ {
+ return Equals(obj as ShapeValue);
+ }
+
+ public IEnumerator GetEnumerator()
+ {
+ foreach (var item in _values)
+ {
+ yield return new DimensionValue(item);
+ }
+ }
+
+ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+
+ public override string ToString() => $"[{string.Join(",", _values.AsValueEnumerable().Select(v => v.ToString()))}]";
+}
+
///
/// Tuple value.
///
diff --git a/src/Nncase.Core/Utilities/IRUtility.cs b/src/Nncase.Core/Utilities/IRUtility.cs
new file mode 100644
index 0000000000..1ce420031d
--- /dev/null
+++ b/src/Nncase.Core/Utilities/IRUtility.cs
@@ -0,0 +1,131 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+namespace Nncase.Utilities;
+
+public static class IRUtility
+{
+ ///
+ /// find the reshape's shape transform matrix.
+ ///
+ /// input shape.
+ /// new shape.
+ /// mat [new shape dim, old shpe dim].
+ /// bool.
+ public static bool TryGetShapeMapMatrix(int[] inShape, int[] newShape, out int[,] mat)
+ {
+ int ProdIn(int[,] cmat, int i)
+ {
+ var prod = 1;
+ for (int j = 0; j < inShape.Length; j++)
+ {
+ var v = cmat[i, j] * inShape[j];
+ if (v != 0)
+ {
+ prod *= v;
+ }
+ }
+
+ return prod;
+ }
+
+ int ProdOut(int[,] cmat, int j)
+ {
+ var prod = 1;
+ for (int i = 0; i < newShape.Length; i++)
+ {
+ var v = cmat[i, j] * newShape[i];
+ if (v != 0)
+ {
+ prod *= v;
+ }
+ }
+
+ return prod;
+ }
+
+ mat = new int[newShape.Length, inShape.Length];
+ int i = 0, j = 0;
+ var paths = new List<(int, int)>();
+ while (i < newShape.Length && j < inShape.Length)
+ {
+ if (paths.IndexOf((i, j)) != -1)
+ {
+ return false;
+ }
+
+ mat[i, j] = 1;
+ paths.Add((i, j));
+ var inDim = ProdIn(mat, i);
+ var outDim = ProdOut(mat, j);
+ switch (inDim - outDim)
+ {
+ case 0:
+ i++; j++;
+ break;
+ case < 0:
+ j++;
+ break;
+ case > 0:
+ if (inDim % newShape[i] == 0)
+ {
+ i++;
+ }
+ else
+ {
+ mat[i, j] = 0;
+ j--;
+ paths.RemoveAt(paths.Count - 1);
+ }
+
+ break;
+ }
+ }
+
+ return i == newShape.Length && j == inShape.Length;
+ }
+
+ ///
+ /// convert the mapping matrix as a dictionary.
+ /// the key is in dim, value is not dim.
+ ///
+ /// mat.
+ /// dict.
+ public static (Dictionary> Forward, Dictionary> Backward) ShapeMapMatrixAsDict(int[,] mat)
+ {
+ var forward = new Dictionary>();
+ var backward = new Dictionary>();
+ for (int i = 0; i < mat.GetLength(0); i++)
+ {
+ for (int j = 0; j < mat.GetLength(1); j++)
+ {
+ if (mat[i, j] == 0)
+ {
+ continue;
+ }
+
+ if (!forward.TryGetValue(j, out var l1))
+ {
+ l1 = new() { i };
+ forward.Add(j, l1);
+ }
+ else
+ {
+ l1.Add(i);
+ }
+
+ if (!backward.TryGetValue(i, out var l2))
+ {
+ l2 = new() { j };
+ backward.Add(i, l2);
+ }
+ else
+ {
+ l2.Add(j);
+ }
+ }
+ }
+
+ return (forward, backward);
+ }
+}
diff --git a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs
index d04de6aec3..fad4ab909e 100644
--- a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs
+++ b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs
@@ -365,6 +365,8 @@ protected override string VisitConst(Const expr)
{
TensorConst tc => tc.Value.Shape.Size <= 8 ? tc.Value.GetArrayString(false) : string.Empty,
TupleConst => string.Empty,
+ ShapeConst sc => VisitShape(sc.Value),
+ DimensionConst dc => VisitDimension(dc.Value),
_ => throw new ArgumentOutOfRangeException(nameof(expr)),
};
valueStr = valueStr != string.Empty ? " : " + valueStr : string.Empty;
diff --git a/src/Nncase.EGraph/CostModel/EGraphCostEvaluator.cs b/src/Nncase.EGraph/CostModel/EGraphCostEvaluator.cs
index 0774e2abc6..7a04dfd944 100644
--- a/src/Nncase.EGraph/CostModel/EGraphCostEvaluator.cs
+++ b/src/Nncase.EGraph/CostModel/EGraphCostEvaluator.cs
@@ -119,8 +119,7 @@ private void TryEvaluateAll()
return enode.Expr switch
{
Var var => Visit(enode, var),
- TensorConst con => Visit(enode, con),
- TupleConst con => Visit(enode, con),
+ Const con => Visit(enode, con),
Function func => Visit(enode, func),
Call call => Visit(enode, call, returnType),
IR.Tuple tuple => Visit(enode, tuple),
@@ -138,7 +137,7 @@ private Cost Visit(ENode enode, Var var)
return VisitLeaf(enode, () => Cost.Zero);
}
- private Cost Visit(ENode enode, TensorConst tc)
+ private Cost Visit(ENode enode, Const @const)
{
return VisitLeaf(enode, () => Cost.Zero);
}
@@ -153,11 +152,6 @@ private Cost Visit(ENode enode, Op op)
return Visit(enode, costs => Cost.Zero);
}
- private Cost? Visit(ENode enode, TupleConst tc)
- {
- return Visit(enode, costs => costs.Sum());
- }
-
private Cost? Visit(ENode enode, IR.Tuple tuple)
{
return Visit(enode, costs => _accumulate ? costs.Sum() : Cost.Zero);
diff --git a/src/Nncase.EGraph/Passes/EGraphExtractor.cs b/src/Nncase.EGraph/Passes/EGraphExtractor.cs
index cbe4249199..42c283042a 100644
--- a/src/Nncase.EGraph/Passes/EGraphExtractor.cs
+++ b/src/Nncase.EGraph/Passes/EGraphExtractor.cs
@@ -621,6 +621,12 @@ public Expr Visit(EClass root)
case Marker mk:
expr = mk.With(target: children[0], attribute: children[1], metadata: mk.Metadata);
break;
+ case ShapeConst sc:
+ expr = sc;
+ break;
+ case DimensionConst dc:
+ expr = dc;
+ break;
default:
throw new NotSupportedException(enode.Expr.GetType().Name);
}
diff --git a/src/Nncase.EGraph/Passes/EGraphPrinter.cs b/src/Nncase.EGraph/Passes/EGraphPrinter.cs
index 17f9cee6fd..2b3c13b29d 100644
--- a/src/Nncase.EGraph/Passes/EGraphPrinter.cs
+++ b/src/Nncase.EGraph/Passes/EGraphPrinter.cs
@@ -258,6 +258,8 @@ protected override string VisitConst(Const expr)
{
TensorConst tc => tc.Value.Shape.Size <= 8 ? tc.Value.GetArrayString(false) : string.Empty,
TupleConst => string.Empty,
+ ShapeConst sc => VisitShape(sc.Value),
+ DimensionConst dc => VisitDimension(dc.Value),
_ => throw new ArgumentOutOfRangeException(nameof(expr)),
};
valueStr = valueStr != string.Empty ? " : " + valueStr : string.Empty;
@@ -282,5 +284,23 @@ protected override string VisitOp(Op expr)
protected override string VisitTuple(IR.Tuple expr) => "Tuple";
protected override string VisitNone(None expr) => "None";
+
+ private string VisitShape(Shape shape) =>
+ shape.Kind switch
+ {
+ ShapeKind.Invalid => "Invalid",
+ ShapeKind.Unranked => "Unranked",
+ _ => $"[{string.Join(',', shape.Select(VisitDimension))}]",
+ };
+
+ private string VisitDimension(Dimension dimension) =>
+ dimension.Kind switch
+ {
+ DimensionKind.Any => "any",
+ DimensionKind.Fixed => dimension.FixedValue.ToString(),
+ DimensionKind.Unknown => dimension.Value is Var var ? $"%{var.Name}" : "?",
+ _ => throw new NotSupportedException(dimension.Kind.ToString()),
+ };
+
}
}
diff --git a/src/Nncase.Evaluator/Math/Binary.cs b/src/Nncase.Evaluator/Math/Binary.cs
index 4343cd168b..3d78be5d7b 100755
--- a/src/Nncase.Evaluator/Math/Binary.cs
+++ b/src/Nncase.Evaluator/Math/Binary.cs
@@ -79,45 +79,61 @@ public static IRType CheckSBP(BinaryOp op, TensorType tensorType, DistributedTyp
///
public IValue Visit(IEvaluateContext context, Binary binary)
{
- var lhs = context.GetArgumentValueAsTensor(binary, Binary.Lhs);
- var rhs = context.GetArgumentValueAsTensor(binary, Binary.Rhs);
- if (lhs.Shape.IsScalar && rhs.Shape.IsScalar)
+ var lhsValue = context.GetArgumentValue(binary, Binary.Lhs);
+ var rhsValue = context.GetArgumentValue(binary, Binary.Rhs);
+ switch (lhsValue, rhsValue)
{
- if (lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32)
- {
- return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
- }
- else if (lhs.ElementType == DataTypes.Int64 && rhs.ElementType == DataTypes.Int64)
- {
- return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
- }
- else if (lhs.ElementType == DataTypes.Float32 && rhs.ElementType == DataTypes.Float32)
- {
- return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
- }
- else if (lhs.ElementType == DataTypes.Boolean && rhs.ElementType == DataTypes.Boolean)
- {
- return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
- }
- else if (lhs.ElementType == DataTypes.UInt32 && rhs.ElementType == DataTypes.UInt32)
- {
- return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
- }
- else if (lhs.ElementType is PointerType && (rhs.ElementType == DataTypes.UInt32 || rhs.ElementType == DataTypes.UInt64))
- {
- return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
- }
- else if ((lhs.ElementType == DataTypes.UInt32 || lhs.ElementType == DataTypes.UInt64) && rhs.ElementType is PointerType)
- {
- return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
- }
- else
- {
+ case (TensorValue lhsTV, TensorValue rhsTV):
+ var lhs = lhsTV.AsTensor();
+ var rhs = rhsTV.AsTensor();
+ if (lhs.Shape.IsScalar && rhs.Shape.IsScalar)
+ {
+ if (lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32)
+ {
+ return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
+ }
+ else if (lhs.ElementType == DataTypes.Int64 && rhs.ElementType == DataTypes.Int64)
+ {
+ return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
+ }
+ else if (lhs.ElementType == DataTypes.Float32 && rhs.ElementType == DataTypes.Float32)
+ {
+ return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
+ }
+ else if (lhs.ElementType == DataTypes.Boolean && rhs.ElementType == DataTypes.Boolean)
+ {
+ return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
+ }
+ else if (lhs.ElementType == DataTypes.UInt32 && rhs.ElementType == DataTypes.UInt32)
+ {
+ return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
+ }
+ else if (lhs.ElementType is PointerType && (rhs.ElementType == DataTypes.UInt32 || rhs.ElementType == DataTypes.UInt64))
+ {
+ return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
+ }
+ else if ((lhs.ElementType == DataTypes.UInt32 || lhs.ElementType == DataTypes.UInt64) && rhs.ElementType is PointerType)
+ {
+ return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar())));
+ }
+ else
+ {
+ return Ort_compute(binary, lhs, rhs);
+ }
+ }
+
return Ort_compute(binary, lhs, rhs);
- }
+ case (DimensionValue { Dimension: { IsFixed: true } } lhsDV, TensorValue { Type: TensorType { Shape: { IsScalar: true } } } rhsTV):
+ return Value.FromTensor(Compute(binary.BinaryOp, lhsDV.Dimension.FixedValue, rhsTV.AsTensor().ToScalar()));
+ case (TensorValue { Type: TensorType { Shape: { IsScalar: true } } } lhsTV, DimensionValue { Dimension: { IsFixed: true } } rhsDV):
+ return Value.FromTensor(Compute(binary.BinaryOp, lhsTV.AsTensor().ToScalar(), rhsDV.Dimension.FixedValue));
+ case (DimensionValue { Dimension: { IsFixed: true } } lhsDV, DimensionValue { Dimension: { IsFixed: true } } rhsDV):
+ return Value.FromTensor(Compute(binary.BinaryOp, lhsDV.Dimension.FixedValue, rhsDV.Dimension.FixedValue));
+ default:
+ break;
}
- return Ort_compute(binary, lhs, rhs);
+ throw new NotSupportedException($"binary notsupport {lhsValue} {rhsValue}");
}
///
diff --git a/src/Nncase.Evaluator/Math/Compare.cs b/src/Nncase.Evaluator/Math/Compare.cs
index 69322ec184..0d34dc5b4b 100644
--- a/src/Nncase.Evaluator/Math/Compare.cs
+++ b/src/Nncase.Evaluator/Math/Compare.cs
@@ -2,6 +2,7 @@
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
+using System.Linq;
using Nncase.CostModel;
using Nncase.IR;
using Nncase.IR.Math;
@@ -67,25 +68,31 @@ public static IRType CheckSBP(TensorType tensorType, DistributedType a, Distribu
///
public IValue Visit(IEvaluateContext context, Compare target)
{
- var lhs = context.GetArgumentValueAsTensor(target, Compare.Lhs);
- var rhs = context.GetArgumentValueAsTensor(target, Compare.Rhs);
- if (lhs.Shape.IsScalar && rhs.Shape.IsScalar && lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32)
+ var lhsValue = context.GetArgumentValue(target, Compare.Lhs);
+ var rhsValue = context.GetArgumentValue(target, Compare.Rhs);
+ switch (lhsValue, rhsValue)
{
- return Value.FromTensor(Tensor.FromScalar(Compute(target.CompareOp, lhs.ToScalar(), rhs.ToScalar())));
+ case (TensorValue lhsTV, TensorValue rhsTV):
+ var lhs = lhsTV.AsTensor();
+ var rhs = rhsTV.AsTensor();
+ if (lhs.Shape.IsScalar && rhs.Shape.IsScalar && lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32)
+ {
+ return Value.FromTensor(Tensor.FromScalar(Compute(target.CompareOp, lhs.ToScalar(), rhs.ToScalar())));
+ }
+ else
+ {
+ return Compute(target.CompareOp, lhs.ToOrtTensor(), rhs.ToOrtTensor());
+ }
+
+ case (ShapeValue lhsTV, TensorValue rhsTV):
+ return Value.FromTensor(Tensor.FromArray(lhsTV.Dimensions.ToArray().Zip(rhsTV.AsTensor().ToArray()).Select(p => Compute(target.CompareOp, p.First, p.Second)).ToArray()));
+ case (TensorValue lhsTV, ShapeValue rhsTV):
+ return Value.FromTensor(Tensor.FromArray(lhsTV.AsTensor().ToArray().Zip(rhsTV.Dimensions.ToArray()).Select(p => Compute(target.CompareOp, p.First, p.Second)).ToArray()));
+ default:
+ break;
}
- var a = context.GetOrtArgumentValue(target, Compare.Lhs);
- var b = context.GetOrtArgumentValue(target, Compare.Rhs);
- return target.CompareOp switch
- {
- CompareOp.Equal => OrtKI.Equal(a, b).ToValue(),
- CompareOp.LowerOrEqual => OrtKI.LessOrEqual(a, b).ToValue(),
- CompareOp.GreaterOrEqual => OrtKI.GreaterOrEqual(a, b).ToValue(),
- CompareOp.GreaterThan => OrtKI.Greater(a, b).ToValue(),
- CompareOp.LowerThan => OrtKI.Less(a, b).ToValue(),
- CompareOp.NotEqual => OrtKI.Not(OrtKI.Equal(a, b)).ToValue(),
- _ => throw new ArgumentOutOfRangeException(target.CompareOp.ToString()),
- };
+ throw new NotSupportedException();
}
///
@@ -149,7 +156,41 @@ public Expr Visit(IShapeEvaluateContext context, Compare target)
return ShapeExprUtility.BroadcastShape(lhs, rhs);
}
- private bool Compute(CompareOp op, int a, int b) => op switch
+ private bool Compute(CompareOp op, Dimension a, long b)
+ {
+ return (a, b) switch
+ {
+ (Dimension { IsFixed: true } da, _) => Compute(op, da.FixedValue, b),
+ (Dimension { IsFixed: false } da, _) => Compute(op, da.Value, b),
+ };
+ }
+
+ private bool Compute(CompareOp op, long a, Dimension b)
+ {
+ return (a, b) switch
+ {
+ (_, Dimension { IsFixed: true } db) => Compute(op, a, db.FixedValue),
+ (_, Dimension { IsFixed: false } db) => Compute(op, a, db.Value),
+ };
+ }
+
+ private IValue Compute(CompareOp op, OrtKISharp.Tensor a, OrtKISharp.Tensor b)
+ {
+ return op switch
+ {
+ CompareOp.Equal => OrtKI.Equal(a, b).ToValue(),
+ CompareOp.LowerOrEqual => OrtKI.LessOrEqual(a, b).ToValue(),
+ CompareOp.GreaterOrEqual => OrtKI.GreaterOrEqual(a, b).ToValue(),
+ CompareOp.GreaterThan => OrtKI.Greater(a, b).ToValue(),
+ CompareOp.LowerThan => OrtKI.Less(a, b).ToValue(),
+ CompareOp.NotEqual => OrtKI.Not(OrtKI.Equal(a, b)).ToValue(),
+ _ => throw new ArgumentOutOfRangeException(op.ToString()),
+ };
+ }
+
+ private bool Compute(CompareOp op, T a, T b)
+ where T : System.Numerics.IEqualityOperators, System.Numerics.IComparisonOperators
+ => op switch
{
CompareOp.Equal => a == b,
CompareOp.LowerOrEqual => a <= b,
@@ -160,6 +201,20 @@ public Expr Visit(IShapeEvaluateContext context, Compare target)
_ => throw new ArgumentOutOfRangeException(nameof(op)),
};
+ private bool Compute(CompareOp op, Expr a, long b)
+ => (op, a, b) switch
+ {
+ (CompareOp.Equal, Expr, -1) => false,
+ _ => throw new ArgumentOutOfRangeException(nameof(op)),
+ };
+
+ private bool Compute(CompareOp op, long a, Expr b)
+ => (op, a, b) switch
+ {
+ (CompareOp.Equal, -1, Expr) => false,
+ _ => throw new ArgumentOutOfRangeException(nameof(op)),
+ };
+
private IRType Visit(TensorType lhs, TensorType rhs)
{
var broadcastType = TypeInference.BroadcastType(lhs, rhs);
diff --git a/src/Nncase.Evaluator/Tensors/Concat.cs b/src/Nncase.Evaluator/Tensors/Concat.cs
index 0a30397101..3c74ad6755 100644
--- a/src/Nncase.Evaluator/Tensors/Concat.cs
+++ b/src/Nncase.Evaluator/Tensors/Concat.cs
@@ -23,11 +23,45 @@ public class ConcatEvaluator : IEvaluator, ITypeInferencer, ICos
IShapeEvaluator, IMetricEvaluator
{
///
- public IValue Visit(IEvaluateContext context, Concat cat)
+ public IValue Visit(IEvaluateContext context, Concat target)
{
- var inputs = context.GetArgumentValueAsTensors(cat, Concat.Input);
- var axis = cat.Axis;
- return OrtKI.Concat(inputs.Select(t => t.ToOrtTensor()).ToArray(), axis).ToValue();
+ var inputValue = context.GetArgumentValue(target, Concat.Input);
+ var axis = target.Axis;
+ switch (inputValue)
+ {
+ case TupleValue tpv:
+ if (tpv.All(v => v is TensorValue))
+ {
+ var inputs = tpv.AsTensors();
+ return OrtKI.Concat(inputs.Select(t => t.ToOrtTensor()).ToArray(), axis).ToValue();
+ }
+ else if (tpv.Any(v => v is ShapeValue))
+ {
+ var dims = new List();
+ foreach (var fv in tpv)
+ {
+ switch (fv)
+ {
+ case TensorValue ftv:
+ dims.Add(new Dimension(ftv.AsTensor().Cast()[axis]));
+ break;
+ case ShapeValue fsv:
+ dims.Add(fsv.Dimensions[axis]);
+ break;
+ default:
+ throw new ArgumentOutOfRangeException(nameof(target), "ShapeValue's field not support");
+ }
+ }
+
+ return new ShapeValue(dims.ToArray());
+ }
+
+ break;
+ default:
+ break;
+ }
+
+ throw new ArgumentOutOfRangeException(nameof(target));
}
///
diff --git a/src/Nncase.Evaluator/Tensors/Gather.cs b/src/Nncase.Evaluator/Tensors/Gather.cs
index 3aebcbd9b4..2d6e9e0965 100644
--- a/src/Nncase.Evaluator/Tensors/Gather.cs
+++ b/src/Nncase.Evaluator/Tensors/Gather.cs
@@ -22,10 +22,33 @@ public class GatherEvaluator : IEvaluator, ITypeInferencer, ICos
///
public IValue Visit(IEvaluateContext context, Gather gather)
{
- var input = context.GetOrtArgumentValue(gather, Gather.Input);
+ var inputValue = context.GetArgumentValue(gather, Gather.Input);
+ var indexValue = context.GetArgumentValue(gather, Gather.Index);
var axis = gather.Axis;
- var index = context.GetOrtArgumentValue(gather, Gather.Index);
- return OrtKI.Gather(input, index, axis).ToValue();
+ switch (inputValue, indexValue)
+ {
+ case (_, TensorValue indexTValue):
+ if (inputValue is TensorValue inputTValue)
+ {
+ return OrtKI.Gather(inputTValue.AsTensor().ToOrtTensor(), indexTValue.AsTensor().ToOrtTensor(), axis).ToValue();
+ }
+ else if (inputValue is ShapeValue inputSValue && axis == 0)
+ {
+ var indexTensor = indexTValue.AsTensor();
+ if (!indexTensor.Shape.IsScalar)
+ {
+ throw new NotSupportedException("Gather ShapeConst the index must be scalar!");
+ }
+
+ return inputSValue[indexTensor.ToScalar()];
+ }
+
+ break;
+ default:
+ break;
+ }
+
+ throw new NotSupportedException();
}
///
diff --git a/src/Nncase.Evaluator/Tensors/Reshape.cs b/src/Nncase.Evaluator/Tensors/Reshape.cs
index ac28fcbfe3..fb1959dc2d 100644
--- a/src/Nncase.Evaluator/Tensors/Reshape.cs
+++ b/src/Nncase.Evaluator/Tensors/Reshape.cs
@@ -97,6 +97,7 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i
var shapeDims = new Shape(Enumerable.Range(0, rank).Select(i => (Dimension)shape[i]).ToArray());
var outputShape = new Dimension[rank];
+ // todo use egraph simplify.
var minus1Dim = FixedAndDynamicDimension.Abs(input.Shape.ProdFixedAndDynamic() / shapeDims.ProdFixedAndDynamic());
for (var i = 0; i < rank; i++)
{
@@ -107,13 +108,21 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i
}
else
{
- var outputDim = ShapeExprUtility.If(
- Equal(shapeDim.Value, -1L),
- (shapeDim, minus1Dim) => minus1Dim,
- (shapeDim, minus1Dim) => shapeDim,
- shapeDim.Value,
- minus1Dim.ToExpr());
- outputShape[i] = outputDim;
+ switch (shapeDim)
+ {
+ case Dimension { Value: Var }:
+ outputShape[i] = shapeDim;
+ break;
+ default:
+ var outputDim = ShapeExprUtility.If(
+ Equal(shapeDim.Value, -1L),
+ (shapeDim, minus1Dim) => minus1Dim,
+ (shapeDim, minus1Dim) => shapeDim,
+ shapeDim.Value,
+ minus1Dim.ToExpr());
+ outputShape[i] = outputDim;
+ break;
+ }
}
}
diff --git a/src/Nncase.Evaluator/Tensors/Slice.cs b/src/Nncase.Evaluator/Tensors/Slice.cs
index 7b16051bf2..b1763790a3 100644
--- a/src/Nncase.Evaluator/Tensors/Slice.cs
+++ b/src/Nncase.Evaluator/Tensors/Slice.cs
@@ -30,13 +30,55 @@ public class SliceEvaluator : IEvaluator, ITypeInferencer, ICostEv
///
public IValue Visit(IEvaluateContext context, Slice sl)
{
- var input = context.GetOrtArgumentValue(sl, Slice.Input);
- var begins = context.GetInt64OrtTensorArgumentValue(sl, Slice.Begins);
- var ends = context.GetInt64OrtTensorArgumentValue(sl, Slice.Ends);
- var axes = context.GetInt64OrtTensorArgumentValue(sl, Slice.Axes);
- var strides = context.GetInt64OrtTensorArgumentValue(sl, Slice.Strides);
- var sliced = OrtKI.Slice(input, begins, ends, axes, strides);
- return Value.FromTensor(context.CurrentCall.CheckedType is AnyType ? sliced.ToTensor() : sliced.ToTensor(context.CurrentCall.CheckedTensorType));
+ var inputValue = context.GetArgumentValue(sl, Slice.Input);
+ var beginsValue = context.GetArgumentValue(sl, Slice.Begins);
+ var endsValue = context.GetArgumentValue(sl, Slice.Ends);
+ var axesValue = context.GetArgumentValue(sl, Slice.Axes);
+ var stridesValue = context.GetArgumentValue(sl, Slice.Strides);
+ switch (inputValue, beginsValue, endsValue, axesValue, stridesValue)
+ {
+ case (_, TensorValue beginsTValue, TensorValue endsTValue, TensorValue axesTValue, TensorValue stridesTValue):
+ var beginsTensor = beginsTValue.AsTensor();
+ var endsTensor = endsTValue.AsTensor();
+ var axesTensor = axesTValue.AsTensor();
+ var stridesTensor = stridesTValue.AsTensor();
+ if (inputValue is ShapeValue inputSValue && beginsTensor.Shape.Rank == 1 && endsTensor.Shape.Rank == 1 && axesTensor.Shape.Rank == 1 && stridesTensor.Shape.Rank == 1)
+ {
+ // var input = inputShapeValue.AsTensor().Cast().ToOrtTensor();
+ var begins = beginsTensor.ToScalar();
+ var ends = endsTensor.ToScalar();
+ var axes = axesTensor.ToScalar();
+ if (axes != 0)
+ {
+ throw new NotSupportedException("slice ShapeConst Axes != 0");
+ }
+
+ var strides = stridesTensor.ToScalar();
+ var sliced = new List();
+ for (long i = begins; i < ends; i += strides)
+ {
+ sliced.Add(inputSValue[checked((int)i)]);
+ }
+
+ return new ShapeValue(sliced);
+ }
+ else if (inputValue is TensorValue inputTValue)
+ {
+ var input = inputTValue.AsTensor().Cast().ToOrtTensor();
+ var begins = beginsTensor.Cast().ToOrtTensor();
+ var ends = endsTensor.Cast().ToOrtTensor();
+ var axes = axesTensor.Cast().ToOrtTensor();
+ var strides = stridesTensor.Cast().ToOrtTensor();
+ var sliced = OrtKI.Slice(input, begins, ends, axes, strides);
+ return Value.FromTensor(context.CurrentCall.CheckedType is AnyType ? sliced.ToTensor() : sliced.ToTensor(context.CurrentCall.CheckedTensorType));
+ }
+
+ break;
+ default:
+ break;
+ }
+
+ throw new NotSupportedException("input value is neither shapevalue or tensorvalue");
}
///
diff --git a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs
index 76e1c7ad5b..c87c93d0e3 100644
--- a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs
+++ b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs
@@ -19,9 +19,35 @@ public class UnsqueezeEvaluator : IEvaluator, ITypeInferencer
public IValue Visit(IEvaluateContext context, Unsqueeze unSqueeze)
{
- var input = context.GetOrtArgumentValue(unSqueeze, Unsqueeze.Input);
- var axes = context.GetInt64OrtTensorArgumentValue(unSqueeze, Unsqueeze.Dim);
- return Value.FromTensor(OrtKI.Unsqueeze(input, axes).ToTensor(context.CurrentCall.CheckedTensorType));
+ var inputValue = context.GetArgumentValue(unSqueeze, Unsqueeze.Input);
+ var axesValue = context.GetArgumentValue(unSqueeze, Unsqueeze.Dim);
+
+ switch (inputValue, axesValue)
+ {
+ case (_, TensorValue axesTValue):
+ var axesTensor = axesTValue.AsTensor();
+ if (inputValue is TensorValue inputTValue)
+ {
+ var input = inputTValue.AsTensor().ToOrtTensor();
+ var axes = axesTensor.Cast().ToOrtTensor();
+ return Value.FromTensor(OrtKI.Unsqueeze(input, axes).ToTensor(context.CurrentCall.CheckedTensorType));
+ }
+ else if (inputValue is DimensionValue inputDValue)
+ {
+ if (axesTensor.Shape.Rank > 1 || axesTensor.ToScalar() != 0)
+ {
+ throw new NotSupportedException("only support scalar dim when input is DimensionValue!");
+ }
+
+ return new ShapeValue(new[] { inputDValue });
+ }
+
+ break;
+ default:
+ break;
+ }
+
+ throw new NotSupportedException();
}
///
diff --git a/src/Nncase.Evaluator/Tensors/Where.cs b/src/Nncase.Evaluator/Tensors/Where.cs
index d80c9fcd19..7258d271e4 100644
--- a/src/Nncase.Evaluator/Tensors/Where.cs
+++ b/src/Nncase.Evaluator/Tensors/Where.cs
@@ -23,24 +23,41 @@ public class WhereEvaluator : IEvaluator, ITypeInferencer, ICostEv
///
public IValue Visit(IEvaluateContext context, Where where)
{
- var xt = context.GetArgumentValueAsTensor(where, Where.X);
- var yt = context.GetArgumentValueAsTensor(where, Where.Y);
- if (where.IsTfWhere)
+ var condValue = context.GetArgumentValue(where, Where.Cond);
+ var xValue = context.GetArgumentValue(where, Where.X);
+ var yValue = context.GetArgumentValue(where, Where.Y);
+ switch (condValue, xValue, yValue)
{
- var condTensor = context.GetArgumentValueAsTensor(where, Where.Cond);
- if (condTensor.Rank > 1)
- {
- throw new NotImplementedException();
- }
-
- var result = condTensor.Select((b, i) => (b, i)).Where(t => t.b).Select(t => (long)t.i).ToArray();
- return Value.FromTensor(Tensor.From(result, new Shape(result.Length, condTensor.Rank)));
+ case (TensorValue condTV, TensorValue xTV, _):
+ if (yValue is TensorValue yTV)
+ {
+ if (where.IsTfWhere)
+ {
+ var condTensor = condTV.AsTensor().Cast();
+ if (condTensor.Rank > 1)
+ {
+ throw new NotImplementedException();
+ }
+
+ var result = condTensor.Select((b, i) => (b, i)).Where(t => t.b).Select(t => (long)t.i).ToArray();
+ return Value.FromTensor(Tensor.From(result, new Shape(result.Length, condTensor.Rank)));
+ }
+ else
+ {
+ return OrtKI.Where(condTV.AsTensor().ToOrtTensor(), xTV.AsTensor().ToOrtTensor(), yTV.AsTensor().ToOrtTensor()).ToValue();
+ }
+ }
+ else if (yValue is ShapeValue ySV)
+ {
+ return new ShapeValue(condTV.AsTensor().Cast().Zip(xValue.AsTensor().Cast().Zip(ySV.Dimensions.ToArray())).Select(tp => tp.First ? new Dimension(tp.Second.First) : tp.Second.Second).ToArray());
+ }
+
+ break;
+ default:
+ break;
}
- var cond = context.GetOrtArgumentValue(where, Where.Cond);
- var x = context.GetOrtArgumentValue(where, Where.X);
- var y = context.GetOrtArgumentValue(where, Where.Y);
- return OrtKI.Where(cond, x, y).ToValue();
+ throw new NotSupportedException();
}
///
diff --git a/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs b/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs
index 1d6f507d3e..bdcf65f8c7 100644
--- a/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs
+++ b/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs
@@ -48,10 +48,39 @@ private Const GetReplace(Call call, IReadOnlyList constArgs)
public partial class FoldShapeOf : RewriteRule
{
///
- public override CallPattern Pattern { get; } = IsShapeOf(IsWildcard("wc") with { TypePattern = HasFixedShape() });
+ public override CallPattern Pattern { get; } = IsShapeOf(IsWildcard("wc") with { TypePattern = HasRank() });
private Const GetReplace(Expr wc)
{
- return Const.FromTensor(wc.CheckedShape.ToValueArray().Select(x => (long)x).ToArray());
+ if (wc.CheckedShape.IsFixed)
+ {
+ return Const.FromTensor(wc.CheckedShape.ToValueArray().Select(x => (long)x).ToArray());
+ }
+ else
+ {
+ return new ShapeConst(wc.CheckedShape);
+ }
}
}
+
+///
+/// concat([1, seq_len, 2, 3], 0) => tuple([1,seq_len,2,3]).
+///
+// [RuleGenerator]
+// public partial class FoldConcatShape : RewriteRule
+// {
+// ///
+// public override CallPattern Pattern { get; } = IsConcat("concat", _ => true, IsTuple("tuple", IsVArgsRepeat("fileds", exprs =>
+// {
+// var patterns = new Pattern[exprs.Length];
+// for (int i = 0; i < exprs.Length; i++)
+// {
+// patterns[i] = IsWildcard($"input_{i}") with { TypePattern = IsIntegralScalar() };
+// }
+// return patterns;
+// })));
+// private Const GetReplace(Expr wc)
+// {
+// return Const.FromTensor(wc.CheckedShape.ToValueArray().Select(x => (long)x).ToArray());
+// }
+// }
diff --git a/src/Nncase.Tests/Core/UnitTestTypeInfer.cs b/src/Nncase.Tests/Core/UnitTestTypeInfer.cs
index 261543a43b..0e290ae13c 100644
--- a/src/Nncase.Tests/Core/UnitTestTypeInfer.cs
+++ b/src/Nncase.Tests/Core/UnitTestTypeInfer.cs
@@ -326,6 +326,23 @@ public void TestBroadcastInfer2()
Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimUnk1, 8192 }), result);
}
+ [Fact]
+ public void TestReshapeInfer()
+ {
+ var dimVar = new Var("seq_len", new TensorType(DataTypes.Int64, Shape.Scalar));
+ var dimC = new Dimension(dimVar);
+ var a = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 128 }));
+ var constShape = new ShapeConst(new[] { 1, dimC, 2, 64 });
+ var reshape = Reshape(a, constShape);
+ var result = reshape.CheckedType;
+ Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 2, 64 }), result);
+
+ var b = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 14, 64 }));
+ var reshapeb = Reshape(b, new ShapeConst(new[] { 1, dimC, -1 }));
+ var resultb = reshapeb.CheckedType;
+ Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 896 }), resultb);
+ }
+
private void CheckInferShape(Expr expr, params Dimension[] shapeDimensions)
{
CheckInferShape(expr, new Shape(shapeDimensions));