zlow.add (::onnx_mlir::zlow::ZLowAddOp)

ZLow add operation

ZLow operation to perform an add.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
Y memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.avgpool2d (::onnx_mlir::zlow::ZLowAvgPool2DOp)

ZLow 2D average pooling operation

ZLow operation to perform 2D average pooling.

  • shape is a 1D MemRef (memref<6xi64>) whose items are:
    • 1st item: batch size
    • 2nd item: channel
    • 3rd item: height in
    • 4th item: width in
    • 5th item: height out
    • 6th item: width out
  • kernel_shape: 1D array of kernel height and width
  • strides: 1D array of stride height and width
  • padding_type: SAME_PADDING or VALID_PADDING.
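The six shape items and the two padding modes can be illustrated with a hypothetical helper, pool2d_shape, that assembles the shape vector from the input dimensions. The output-size formulas are an assumption (the common TensorFlow-style convention). This page does not state how height out and width out are derived.

```python
# Hypothetical helper: builds the 6-item shape vector described above.
# ASSUMPTION: output sizes follow the TensorFlow-style convention
#   SAME_PADDING:  out = ceil(in / stride)
#   VALID_PADDING: out = ceil((in - kernel + 1) / stride)
def _ceil_div(a: int, b: int) -> int:
    # Exact integer ceiling division (avoids float rounding).
    return -(-a // b)

def pool2d_shape(batch, channel, height_in, width_in,
                 kernel_shape, strides, padding_type):
    kh, kw = kernel_shape
    sh, sw = strides
    if padding_type == "SAME_PADDING":
        height_out = _ceil_div(height_in, sh)
        width_out = _ceil_div(width_in, sw)
    elif padding_type == "VALID_PADDING":
        height_out = _ceil_div(height_in - kh + 1, sh)
        width_out = _ceil_div(width_in - kw + 1, sw)
    else:
        raise ValueError(f"unknown padding_type: {padding_type}")
    return [batch, channel, height_in, width_in, height_out, width_out]
```

For example, a 5x5 input with a 3x3 kernel and stride 2 gives a 3x3 output under SAME_PADDING and a 2x2 output under VALID_PADDING with these formulas.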

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
kernel_shape ::mlir::ArrayAttr 64-bit integer array attribute
strides ::mlir::ArrayAttr 64-bit integer array attribute
padding_type ::mlir::StringAttr string attribute

Operands:

Operand Description
input memref of dlfloat16 type values
shape memref of 64-bit signless integer values
output memref of dlfloat16 type values

zlow.batchnorm (::onnx_mlir::zlow::ZLowBatchNormOp)

ZLow batchnorm operation

ZLow operation to perform batchnorm.

  • shape is a 1D MemRef (memref<4xi64>) whose items are:
    • 1st item: batch size
    • 2nd item: height
    • 3rd item: width
    • 4th item: channel
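This page does not spell out the batchnorm arithmetic. A minimal reference sketch follows, assuming that the op computes output = input * A + B, with one A value and one B value per channel, over a batch/height/width/channel layout (the order of the shape items). The function name batchnorm_ref is hypothetical.

```python
# Reference sketch of zlow.batchnorm on plain Python floats.
# ASSUMPTIONS: output = input * A + B, input is laid out as
# [batch][height][width][channel] (matching the shape item order),
# and A and B each hold one value per channel.
def batchnorm_ref(inp, a, b):
    return [[[[x * a[c] + b[c] for c, x in enumerate(pixel)]
              for pixel in row]
             for row in image]
            for image in inp]
```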

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Operands:

Operand Description
input memref of dlfloat16 type values
A memref of dlfloat16 type values
B memref of dlfloat16 type values
shape memref of 64-bit signless integer values
output memref of dlfloat16 type values

zlow.conv2d (::onnx_mlir::zlow::ZLowConv2DOp)

ZLow 2D convolution operation

ZLow operation to perform 2D convolution.

  • shape is a 1D MemRef (memref<7xi64>) whose items are:
    • 1st item: batch size
    • 2nd item: channel in
    • 3rd item: height in
    • 4th item: width in
    • 5th item: channel out
    • 6th item: height out
    • 7th item: width out
  • kernel_shape: 1D array of kernel height and width
  • strides: 1D array of stride height and width
  • padding_type: SAME_PADDING or VALID_PADDING.
  • act_func: ACT_NONE or ACT_RELU.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
kernel_shape ::mlir::ArrayAttr 64-bit integer array attribute
strides ::mlir::ArrayAttr 64-bit integer array attribute
padding_type ::mlir::StringAttr string attribute
act_func ::mlir::StringAttr string attribute

Operands:

Operand Description
input memref of dlfloat16 type values
input_kernel memref of dlfloat16 type values
input_bias memref of dlfloat16 type values
shape memref of 64-bit signless integer values
output memref of dlfloat16 type values

zlow.dlf16_to_f32 (::onnx_mlir::zlow::ZLowConvertDLF16ToF32Op)

Convert a dlfloat16 value to a float32 value

This operation converts a dlfloat16 value to a float32 value.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand Description
input dlfloat16 type

Results:

Result Description
output 32-bit float
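To make the conversion concrete, here is a software decode of a single dlfloat16 bit pattern. It assumes the IBM DLFloat16 layout: 1 sign bit, 6 exponent bits with bias 31, 9 fraction bits, no subnormals, and the all-ones magnitude reserved for NaN. It is an illustration, not the op's implementation, which is lowered to hardware or library code.

```python
import math

# Software sketch of a dlfloat16 -> float32 decode, for illustration only.
# ASSUMED bit layout: 1 sign bit, 6 exponent bits (bias 31), 9 fraction
# bits, no subnormals, all-ones magnitude pattern reserved for NaN.
def dlf16_to_f32(bits: int) -> float:
    sign = -1.0 if bits & 0x8000 else 1.0
    magnitude = bits & 0x7FFF
    if magnitude == 0:
        return sign * 0.0                  # signed zero
    if magnitude == 0x7FFF:
        return math.nan
    exponent = magnitude >> 9              # 6-bit biased exponent
    fraction = magnitude & 0x1FF           # 9-bit fraction
    return sign * (1.0 + fraction / 512.0) * 2.0 ** (exponent - 31)
```

Under this layout, 0x3E00 decodes to 1.0 and 0xBF00 decodes to -1.5.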

zlow.vec_dlf16_to_f32 (::onnx_mlir::zlow::ZLowConvertDLF16ToF32VectorOp)

Convert dlfloat16 values to float32 values

This operation converts dlfloat16 values to float32 values.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand Description
input vector of 16-bit float values of length 8

Results:

Result Description
output1 vector of 32-bit float values of length 4
output2 vector of 32-bit float values of length 4
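The vector form yields two 4-lane results from one 8-lane input. The sketch below shows one plausible split. It assumes that output1 receives lanes 0-3 and output2 receives lanes 4-7; the page does not state the lane order. The convert parameter is a hypothetical stand-in for the scalar dlfloat16-to-float32 conversion.

```python
from typing import Callable, List, Tuple

# Sketch of splitting one 8-lane input into two 4-lane float32 results.
# ASSUMPTION: output1 gets lanes 0-3, output2 gets lanes 4-7.
# `convert` is a hypothetical scalar dlfloat16 -> float32 converter.
def vec_dlf16_to_f32(lanes: List[int],
                     convert: Callable[[int], float]) -> Tuple[List[float], List[float]]:
    if len(lanes) != 8:
        raise ValueError("input vector must have exactly 8 lanes")
    converted = [convert(x) for x in lanes]
    return converted[:4], converted[4:]
```

The inverse op, zlow.vec_f32_to_dlf16, presumably concatenates input1 and input2 in the same lane order, which is likewise an assumption.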

zlow.f32_to_dlf16 (::onnx_mlir::zlow::ZLowConvertF32ToDLF16Op)

Convert a float32 value to a dlfloat16 value

This operation converts a float32 value to a dlfloat16 value.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand Description
input 32-bit float

Results:

Result Description
output dlfloat16 type
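The reverse direction can be sketched in software, assuming the same DLFloat16 layout: 1 sign bit, 6 exponent bits with bias 31, and 9 fraction bits. The sketch also assumes round-to-nearest-even on the 14 dropped fraction bits and rejects out-of-range inputs rather than saturating. The hardware's actual rounding and overflow behavior may differ.

```python
import struct

# Software sketch of a float32 -> dlfloat16 encode, for illustration only.
# ASSUMPTIONS: 1 sign bit, 6 exponent bits (bias 31), 9 fraction bits,
# round-to-nearest-even, and out-of-range inputs rejected (not saturated).
def f32_to_dlf16(value: float) -> int:
    (raw,) = struct.unpack("<I", struct.pack("<f", value))
    sign = (raw >> 31) & 0x1
    exp8 = (raw >> 23) & 0xFF
    frac23 = raw & 0x7FFFFF
    if exp8 == 0 and frac23 == 0:
        return sign << 15                     # signed zero
    if exp8 == 0xFF:
        raise ValueError("NaN/Inf not handled in this sketch")
    frac9 = frac23 >> 14                      # keep the top 9 fraction bits
    rem = frac23 & 0x3FFF                     # the 14 dropped bits
    half = 1 << 13
    if rem > half or (rem == half and frac9 & 1):
        frac9 += 1
        if frac9 == 0x200:                    # fraction carried into exponent
            frac9 = 0
            exp8 += 1
    exponent = exp8 - 127 + 31                # rebias from 127 to 31
    if not 0 <= exponent <= 63 or (exponent == 63 and frac9 == 0x1FF):
        raise ValueError("value outside the dlfloat16 range in this sketch")
    return (sign << 15) | (exponent << 9) | frac9
```

Under these assumptions, 1.0 encodes to 0x3E00. The halfway value 1 + 3 * 2**-10 rounds up to 0x3E02 because the kept fraction bit is odd.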

zlow.vec_f32_to_dlf16 (::onnx_mlir::zlow::ZLowConvertF32ToDLF16VectorOp)

Convert float32 values to dlfloat16 values

This operation converts float32 values to dlfloat16 values.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand Description
input1 vector of 32-bit float values of length 4
input2 vector of 32-bit float values of length 4

Results:

Result Description
output vector of 16-bit float values of length 8

zlow.div (::onnx_mlir::zlow::ZLowDivOp)

ZLow div operation

ZLow operation to perform a div.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
Y memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.dummy (::onnx_mlir::zlow::ZLowDummyOp)

ZLow dummy operation that behaves like identity

ZLow operation that forwards its input value to its output value. It is removed when canonicalization runs.

Traits: MemRefsNormalizable

Operands:

Operand Description
input any type

Results:

Result Description
output any type

zlow.exp (::onnx_mlir::zlow::ZLowExpOp)

ZLow exp operation

ZLow operation to perform an exp.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.gru (::onnx_mlir::zlow::ZLowGRUOp)

ZLow gru operation

ZLow operation to perform a gru.

  • work_area: a 4K-aligned buffer.
  • shape is a 1D MemRef (memref<5xi64>) whose items are:
    • 1st item: direction
    • 2nd item: timestep
    • 3rd item: batchSize
    • 4th item: featureSize
    • 5th item: hiddenSize
  • direction accepts "forward", "reverse", or "bidirectional"
  • return_all_steps: -1: returns all timesteps; 0: returns only the last timestep.
  • prev_layer: the layer the input comes from; one of "none", "uni", or "bidir".
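To show how direction and return_all_steps interact, the hypothetical helper below builds the 5-item shape vector and an expected hn_output shape. Both layouts are assumptions, not confirmed by this page: the direction item is taken to hold the number of directions (2 for "bidirectional", otherwise 1), and hn_output is taken to mirror the ONNX Y / Y_h outputs.

```python
# Hypothetical helper mirroring the bullets above. ASSUMPTIONS: the
# "direction" shape item stores the number of directions, and hn_output
# is laid out like ONNX Y / Y_h: [steps_returned, num_directions, batch, hidden].
def gru_shape_and_output_dims(direction, timestep, batch_size,
                              feature_size, hidden_size, return_all_steps):
    if direction not in ("forward", "reverse", "bidirectional"):
        raise ValueError(f"unknown direction: {direction}")
    num_directions = 2 if direction == "bidirectional" else 1
    shape = [num_directions, timestep, batch_size, feature_size, hidden_size]
    # -1 keeps every timestep; 0 keeps only the last one.
    steps_returned = timestep if return_all_steps == -1 else 1
    hn_dims = [steps_returned, num_directions, batch_size, hidden_size]
    return shape, hn_dims
```

The same shape items and return_all_steps semantics are listed for zlow.lstm below.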

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
direction ::mlir::StringAttr string attribute
return_all_steps ::mlir::IntegerAttr 64-bit signed integer attribute
prev_layer ::mlir::StringAttr string attribute

Operands:

Operand Description
input memref of dlfloat16 type values
h0 memref of dlfloat16 type values
input_weights memref of dlfloat16 type values
input_bias memref of dlfloat16 type values
hidden_weights memref of dlfloat16 type values
hidden_bias memref of dlfloat16 type values
work_area memref of 8-bit signless integer values
shape memref of 64-bit signless integer values
hn_output memref of dlfloat16 type values

zlow.gelu (::onnx_mlir::zlow::ZLowGeluOp)

ZLow gelu operation

ZLow operation to perform a gelu.

Traits: MemRefsNormalizable

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.invsqrt (::onnx_mlir::zlow::ZLowInvSqrtOp)

ZLow invsqrt operation

ZLow operation to perform an invsqrt.

Traits: MemRefsNormalizable

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.lstm (::onnx_mlir::zlow::ZLowLSTMOp)

ZLow lstm operation

ZLow operation to perform an LSTM.

  • work_area: a 4K-aligned buffer.
  • shape is a 1D MemRef (memref<5xi64>) whose items are:
    • 1st item: direction
    • 2nd item: timestep
    • 3rd item: batchSize
    • 4th item: featureSize
    • 5th item: hiddenSize
  • direction accepts "forward", "reverse", or "bidirectional"
  • return_all_steps: -1: returns all timesteps; 0: returns only the last timestep.
  • prev_layer: the layer the input comes from; one of "none", "uni", or "bidir".

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
direction ::mlir::StringAttr string attribute
return_all_steps ::mlir::IntegerAttr 64-bit signed integer attribute
prev_layer ::mlir::StringAttr string attribute

Operands:

Operand Description
input memref of dlfloat16 type values
h0 memref of dlfloat16 type values
c0 memref of dlfloat16 type values
input_weights memref of dlfloat16 type values
input_bias memref of dlfloat16 type values
hidden_weights memref of dlfloat16 type values
hidden_bias memref of dlfloat16 type values
work_area memref of 8-bit signless integer values
shape memref of 64-bit signless integer values
hn_output memref of dlfloat16 type values
cf_output memref of dlfloat16 type values

zlow.leakyrelu (::onnx_mlir::zlow::ZLowLeakyReluOp)

ZLow leakyrelu operation

ZLow operation to perform a leakyrelu.

Traits: MemRefsNormalizable

Attributes:

Attribute MLIR Type Description
alpha ::mlir::FloatAttr 32-bit float attribute
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values
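The role of the alpha attribute follows the standard leaky ReLU definition, which this minimal reference sketch states on plain Python floats. The function name leaky_relu_ref is hypothetical, and the sketch says nothing about the dlfloat16 precision of the real op.

```python
# Reference sketch of leaky ReLU: non-negative inputs pass through,
# negative inputs are scaled by alpha (standard definition, as in ONNX).
def leaky_relu_ref(xs, alpha):
    return [x if x >= 0 else alpha * x for x in xs]
```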

zlow.log (::onnx_mlir::zlow::ZLowLogOp)

ZLow log operation

ZLow operation to perform a log.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.matmul (::onnx_mlir::zlow::ZLowMatMulOp)

ZLow matmul operation

ZLow operation to perform a matmul.

  • In the unstacked case, X(m, n) * Y(n, p) + Bias(p), shape is a 1D MemRef (memref<3xi64>) whose items are:
    • 1st item: m
    • 2nd item: n
    • 3rd item: p
  • In the stacked case (X(s, m, n) * Y(s, n, p) + Bias(s, p)) and in the broadcasting cases (broadcasting1: X(m, n) * Y(s, n, p) + Bias(s, p); broadcasting23: X(s, m, n) * Y(n, p) + Bias(p)), shape is a 1D MemRef (memref<4xi64>) whose items are:
    • 1st item: s
    • 2nd item: m
    • 3rd item: n
    • 4th item: p
  • is_bcast1: -1: broadcasting1; 0: no broadcasting1.
  • is_bcast23: -1: broadcasting23; 0: no broadcasting23.
  • is_stacked: -1: stacked; 0: unstacked.
  • transposeA: nonzero: transpose A; 0: do not transpose A.
  • transposeB: nonzero: transpose B; 0: do not transpose B.
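The mapping from operand ranks to the shape vector and flags can be sketched with a hypothetical helper, matmul_shape_and_flags. It assumes transposeA and transposeB are 0. It also assumes that the two broadcasting cases leave is_stacked at 0 and set only their own flag; the bullets above do not say this explicitly.

```python
# Hypothetical helper: derives the shape vector and flags described above
# from operand dimensions. Flags use -1 for "set" and 0 for "not set".
# ASSUMPTIONS: no transposes; broadcasting cases leave is_stacked at 0.
def matmul_shape_and_flags(x_dims, y_dims):
    flags = {"is_stacked": 0, "is_bcast1": 0, "is_bcast23": 0}
    if len(x_dims) == 2 and len(y_dims) == 2:        # unstacked
        (m, n), (n2, p) = x_dims, y_dims
        shape = [m, n, p]
    elif len(x_dims) == 3 and len(y_dims) == 3:      # stacked
        (s, m, n), (s2, n2, p) = x_dims, y_dims
        if s != s2:
            raise ValueError("stack sizes differ")
        shape, flags["is_stacked"] = [s, m, n, p], -1
    elif len(x_dims) == 2 and len(y_dims) == 3:      # broadcasting1
        (m, n), (s, n2, p) = x_dims, y_dims
        shape, flags["is_bcast1"] = [s, m, n, p], -1
    elif len(x_dims) == 3 and len(y_dims) == 2:      # broadcasting23
        (s, m, n), (n2, p) = x_dims, y_dims
        shape, flags["is_bcast23"] = [s, m, n, p], -1
    else:
        raise ValueError("unsupported operand ranks")
    if n != n2:
        raise ValueError("inner dimensions differ")
    return shape, flags
```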

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
is_bcast1 ::mlir::IntegerAttr 64-bit signed integer attribute
is_bcast23 ::mlir::IntegerAttr 64-bit signed integer attribute
is_stacked ::mlir::IntegerAttr 64-bit signed integer attribute
transposeA ::mlir::IntegerAttr 64-bit signed integer attribute
transposeB ::mlir::IntegerAttr 64-bit signed integer attribute

Operands:

Operand Description
X memref of dlfloat16 type values
Y memref of dlfloat16 type values
Bias memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.max (::onnx_mlir::zlow::ZLowMaxOp)

ZLow max operation

ZLow operation to perform a max.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
Y memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.maxpool2d (::onnx_mlir::zlow::ZLowMaxPool2DOp)

ZLow 2D max pooling operation

ZLow operation to perform 2D max pooling.

  • shape is a 1D MemRef (memref<6xi64>) whose items are:
    • 1st item: batch size
    • 2nd item: channel
    • 3rd item: height in
    • 4th item: width in
    • 5th item: height out
    • 6th item: width out
  • kernel_shape: 1D array of kernel height and width
  • strides: 1D array of stride height and width
  • padding_type: SAME_PADDING or VALID_PADDING.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
kernel_shape ::mlir::ArrayAttr 64-bit integer array attribute
strides ::mlir::ArrayAttr 64-bit integer array attribute
padding_type ::mlir::StringAttr string attribute

Operands:

Operand Description
input memref of dlfloat16 type values
shape memref of 64-bit signless integer values
output memref of dlfloat16 type values

zlow.meanreduce2d (::onnx_mlir::zlow::ZLowMeanReduce2DOp)

ZLow 2D mean reduce operation

ZLow operation to perform 2D mean reduce.

  • shape is a 1D MemRef (memref<4xi64>) whose items are:
    • 1st item: batch size (1st dim of input)
    • 2nd item: height (2nd dim of input)
    • 3rd item: width (3rd dim of input)
    • 4th item: channel (4th dim of input)
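A 2D mean reduce averages over the height and width dimensions. This reference sketch on nested lists follows the batch/height/width/channel order of the shape items. It assumes the result keeps one value per (batch, channel) pair; the real output memref may keep singleton height and width dimensions. The name mean_reduce_2d is hypothetical.

```python
# Reference sketch: average over the height and width axes of an input
# laid out as [batch][height][width][channel]. ASSUMPTION: the result
# keeps one value per (batch, channel) pair.
def mean_reduce_2d(inp):
    result = []
    for image in inp:                        # batch
        height, width = len(image), len(image[0])
        channels = len(image[0][0])
        sums = [0.0] * channels
        for row in image:                    # height
            for pixel in row:                # width
                for c, x in enumerate(pixel):
                    sums[c] += x
        result.append([s / (height * width) for s in sums])
    return result
```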

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Operands:

Operand Description
input memref of dlfloat16 type values
shape memref of 64-bit signless integer values
output memref of dlfloat16 type values

zlow.min (::onnx_mlir::zlow::ZLowMinOp)

ZLow min operation

ZLow operation to perform a min.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
Y memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.mul (::onnx_mlir::zlow::ZLowMulOp)

ZLow mul operation

ZLow operation to perform a mul.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
Y memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.quantizedMatmul (::onnx_mlir::zlow::ZLowQuantizedMatMulOp)

ZLow quantized matmul operation

ZLow operation to perform a quantized matmul.

  • work_area: a 4K-aligned buffer with the same layout as Bias but of dlfloat16 type.
  • In the unstacked case, X(m, n) * Y(n, p) + Bias(p), shape is a 1D MemRef (memref<3xi64>) whose items are:
    • 1st item: m
    • 2nd item: n
    • 3rd item: p
  • In the stacked case (X(s, m, n) * Y(s, n, p) + Bias(s, p)) and in the broadcasting case (X(s, m, n) * Y(n, p) + Bias(p)), shape is a 1D MemRef (memref<4xi64>) whose items are:
    • 1st item: s
    • 2nd item: m
    • 3rd item: n
    • 4th item: p
  • is_bcast: -1: broadcasting; 0: no broadcasting.
  • is_stacked: -1: stacked; 0: unstacked.
  • dequantize_output: -1: output is dequantized; 0: output is not dequantized.
  • pre_computed_bias: -1: bias is pre-computed; 0: bias is not pre-computed.

Valid values for the x_q_type, y_q_type, bias_q_type, and out_q_type attributes are "DLFLOAT16", "INT8", "WEIGHTS", and "UNDEFINED".

Traits: MemRefsNormalizable

Attributes:

Attribute MLIR Type Description
x_q_type ::mlir::StringAttr string attribute
y_q_type ::mlir::StringAttr string attribute
bias_q_type ::mlir::StringAttr string attribute
out_q_type ::mlir::StringAttr string attribute
is_bcast ::mlir::IntegerAttr 64-bit signed integer attribute
is_stacked ::mlir::IntegerAttr 64-bit signed integer attribute
pre_computed_bias ::mlir::IntegerAttr 64-bit signed integer attribute
disable_clipping ::mlir::IntegerAttr 64-bit signed integer attribute
dequantize_output ::mlir::IntegerAttr 64-bit signed integer attribute

Operands:

Operand Description
X memref of dlfloat16 type or 8-bit signless integer values
x_rec_scale 0D memref of 32-bit float values
x_offset 0D memref of 32-bit float values
Y memref of dlfloat16 type or 8-bit signless integer values
y_rec_scale 0D memref of 32-bit float values
y_offset 0D memref of 32-bit float values
Bias memref of dlfloat16 type or 8-bit signless integer values
bias_rec_scale 0D memref of 32-bit float values
bias_offset 0D memref of 32-bit float values
work_area memref of dlfloat16 type or 8-bit signless integer values or none type
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type or 8-bit signless integer values
out_rec_scale 0D memref of 32-bit float values
out_offset 0D memref of 32-bit float values

zlow.quantizedStick (::onnx_mlir::zlow::ZLowQuantizedStickOp)

ZLow stick operation for quantization

ZLow operation to perform a quantization stick. The q_type attribute is one of: dlfloat16, int8, or weights.

Traits: MemRefsNormalizable

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute
q_type ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of 8-bit signless integer or 32-bit float values
rec_scale 0D memref of 32-bit float values
offset 0D memref of 32-bit float values
out memref of dlfloat16 type or 8-bit signless integer values
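The rec_scale and offset operands suggest an affine quantization scheme. The sketch below makes that concrete under stated assumptions. rec_scale is taken to be the reciprocal of the quantization scale, so values are multiplied by it. Rounding is half-to-even (Python's round), and results are clipped to the int8 range. The helper names are hypothetical, and the op's exact rounding and clipping may differ.

```python
# Sketch of affine int8 quantization. ASSUMPTIONS: rec_scale = 1 / scale,
# round-half-to-even, and clipping to the int8 range [-128, 127].
INT8_MIN, INT8_MAX = -128, 127

def quantize(x: float, rec_scale: float, offset: float) -> int:
    q = round(x * rec_scale + offset)
    return max(INT8_MIN, min(INT8_MAX, q))

def dequantize(q: int, rec_scale: float, offset: float) -> float:
    # Inverse mapping, ignoring the error introduced by rounding/clipping.
    return (q - offset) / rec_scale
```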

zlow.reducemax (::onnx_mlir::zlow::ZLowReduceMaxOp)

ZLow reducemax operation

ZLow operation to perform a reducemax.

Traits: MemRefsNormalizable

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute
op_type ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
work_area memref of 8-bit signless integer values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.reducemin (::onnx_mlir::zlow::ZLowReduceMinOp)

ZLow reducemin operation

ZLow operation to perform a reducemin.

Traits: MemRefsNormalizable

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute
op_type ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
work_area memref of 8-bit signless integer values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.relu (::onnx_mlir::zlow::ZLowReluOp)

ZLow relu operation

ZLow operation to perform a relu.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.sigmoid (::onnx_mlir::zlow::ZLowSigmoidOp)

ZLow sigmoid operation

ZLow operation to perform a sigmoid.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.softmax (::onnx_mlir::zlow::ZLowSoftmaxOp)

ZLow softmax operation

ZLow operation to perform a softmax.

  • work_area: a 4K-aligned buffer.
  • act_func: ACT_NONE or ACT_LOG.
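The two act_func settings can be illustrated with a reference sketch over a 1-D list of floats. It assumes that ACT_LOG selects log-softmax. The axis the op reduces over is not specified here, and softmax_ref is a hypothetical name.

```python
import math

# Reference sketch of softmax and log-softmax over a 1-D list.
# ASSUMPTION: act_func = "ACT_LOG" selects log-softmax.
def softmax_ref(xs, act_func="ACT_NONE"):
    m = max(xs)                              # subtract the max for stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    if act_func == "ACT_NONE":
        return [e / total for e in exps]
    if act_func == "ACT_LOG":
        log_total = math.log(total)
        return [x - m - log_total for x in xs]
    raise ValueError(f"unknown act_func: {act_func}")
```

Computing log-softmax directly as x - max - log(sum) avoids taking the log of very small probabilities.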

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
act_func ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
work_area memref of 8-bit signless integer values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.sqrt (::onnx_mlir::zlow::ZLowSqrtOp)

ZLow sqrt operation

ZLow operation to perform a sqrt.

Traits: MemRefsNormalizable

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.stickForGRU (::onnx_mlir::zlow::ZLowStickForGRUOp)

ZLow stick operation for GRU

ZLow operation to perform a stick for GRU. Variadic: list of pointers for input data to be transformed:

  • GRU concatenated: 3 data pointers, one for each input gate, in update (Z), reset (R), hidden (H) gate order (ZRH).
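ONNX concatenates GRU gate weights in z, r, h order along the rows (3 * hidden_size rows per direction), which already matches the ZRH order above. The hypothetical helper below splits such a matrix, represented as nested lists, into the three gate operands. This is an illustration only; the actual lowering may slice the data differently.

```python
# Sketch: split an ONNX-style GRU weight matrix (rows in z, r, h order)
# into the z_gate, r_gate, and h_gate operands of zlow.stickForGRU.
def split_gru_gates(w, hidden_size):
    if len(w) != 3 * hidden_size:
        raise ValueError("expected 3 * hidden_size rows")
    z_gate = w[0:hidden_size]
    r_gate = w[hidden_size:2 * hidden_size]
    h_gate = w[2 * hidden_size:3 * hidden_size]
    return z_gate, r_gate, h_gate
```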

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
prev_layer ::mlir::StringAttr string attribute

Operands:

Operand Description
z_gate memref of 16-bit float or 32-bit float values
r_gate memref of 16-bit float or 32-bit float values
h_gate memref of 16-bit float or 32-bit float values
out memref of dlfloat16 type values

zlow.stickForLSTM (::onnx_mlir::zlow::ZLowStickForLSTMOp)

ZLow stick operation for LSTM

ZLow operation to perform a stick for LSTM. Variadic: list of pointers for input data to be transformed:

  • LSTM concatenated: 4 data pointers, one for each input gate in Forget, Input, Cell, Output (FICO) order.
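Unlike GRU, ONNX stores LSTM gate weights in i, o, f, c order (4 * hidden_size rows per direction), so they must be reordered to match the FICO order above. The hypothetical helper below shows that reordering on nested lists. It is an illustration, not the onnx-mlir lowering itself.

```python
# Sketch: split an ONNX-style LSTM weight matrix (rows in i, o, f, c order)
# and return the gate slices in the F, I, C, O order of zlow.stickForLSTM.
def onnx_lstm_to_fico(w, hidden_size):
    if len(w) != 4 * hidden_size:
        raise ValueError("expected 4 * hidden_size rows")
    chunks = [w[k * hidden_size:(k + 1) * hidden_size] for k in range(4)]
    i_gate, o_gate, f_gate, c_gate = chunks   # ONNX order: i, o, f, c
    return f_gate, i_gate, c_gate, o_gate     # FICO order
```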

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
prev_layer ::mlir::StringAttr string attribute

Operands:

Operand Description
f_gate memref of 16-bit float or 32-bit float values
i_gate memref of 16-bit float or 32-bit float values
c_gate memref of 16-bit float or 32-bit float values
o_gate memref of 16-bit float or 32-bit float values
out memref of dlfloat16 type values

zlow.stick (::onnx_mlir::zlow::ZLowStickOp)

ZLow stick operation

ZLow operation to perform a stick.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute
saturation ::mlir::IntegerAttr 64-bit signed integer attribute

Operands:

Operand Description
X memref of 16-bit float or 32-bit float values
Out memref of dlfloat16 type values

zlow.sub (::onnx_mlir::zlow::ZLowSubOp)

ZLow sub operation

ZLow operation to perform a sub.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
Y memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.tanh (::onnx_mlir::zlow::ZLowTanhOp)

ZLow tanh operation

ZLow operation to perform a tanh.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
shape memref of 64-bit signless integer values
Out memref of dlfloat16 type values

zlow.unstick (::onnx_mlir::zlow::ZLowUnstickOp)

ZLow unstick operation

ZLow operation to perform an unstick.

Traits: MemRefsNormalizable

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute MLIR Type Description
layout ::mlir::StringAttr string attribute

Operands:

Operand Description
X memref of dlfloat16 type values
Out memref of 16-bit float or 32-bit float values