Commit

flag for indices op
radenmuaz committed Jan 31, 2024
1 parent eb4f02e commit db2a000
Showing 5 changed files with 35 additions and 25 deletions.
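
In short: this commit introduces a per-backend dtype_for_indices flag. Backend in src/slope/core.py declares it (default None, to be overridden), IREEBackend and ONNXRuntimeBackend both set it to dtypes.int64, and the args_fixer methods of GatherND and ScatterND in src/slope/operators.py now cast index tensors to that dtype before lowering. The gather_nd.py example switches to int32 indices to exercise the cast.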
8 changes: 4 additions & 4 deletions examples/simple/gather_nd.py
@@ -2,12 +2,12 @@

 print('#1')
 x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
-w = slope.tensor([[1,0],[0,1]], dtype=slope.int64)
-# w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
+# w = slope.tensor([[1,0],[0,1]], dtype=slope.int64)
+w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
 print(f"{x=}")
 print(f"{w=}")
-# y = x.gather_nd(w,0)
-# print(f"{y=}")
+y = x.gather_nd(w,0)
+print(f"{y=}")

 @slope.jit
 def f(x, w):
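Note: the example now builds its indices as int32 and calls gather_nd directly; the cast to the backend's preferred index dtype happens inside GatherND.args_fixer (see src/slope/operators.py below).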
1 change: 1 addition & 0 deletions src/slope/backends/iree.py
@@ -55,6 +55,7 @@ def annotate_sig(in_symvals, out_symvals):


 class IREEBackend(Backend):
+    dtype_for_indices = dtypes.int64
     dtype_map = {
         dtypes.float32: np.dtypes.Float32DType(),
         dtypes.uint8: np.dtypes.UInt8DType(),
43 changes: 25 additions & 18 deletions src/slope/backends/onnxruntime.py
@@ -33,7 +33,7 @@
 import onnx
 import onnxruntime
 import tempfile
-
+import random

 def annotate_shape(symval):
     xdtype = symval.dtype.mlir
@@ -56,16 +56,7 @@ def annotate_sig(in_symvals, out_symvals):


 class ONNXRuntimeBackend(Backend):
-    sess_options = onnxruntime.SessionOptions()
-    # Disabling this flag easily produces NaN
-    sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
-
-    # Other flags:
-    # sess_options.log_severity_level = 3
-    # sess_options.use_deterministic_compute = True
-    # sess_options.intra_op_num_threads = 4
-    # sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
-    # sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
+    dtype_for_indices = dtypes.int64
     dtype_map = {
         slope.core.dtypes.float32: "float",
         dtypes.uint8: "uint8",
@@ -100,6 +91,17 @@ class ONNXRuntimeBackend(Backend):
     dtype_map_inv = {v: k for k, v in dtype_map.items()}
     device_map_inv = {v: k for k, v in device_map.items()}

+    sess_options = onnxruntime.SessionOptions()
+    # Disabling this flag easily produces NaN
+    sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
+    # Other flags:
+    # sess_options.log_severity_level = 3
+    # sess_options.use_deterministic_compute = True
+    # sess_options.intra_op_num_threads = 4
+    # sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
+    # sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
+
+
     def from_numpy(self, val, dtype=None, device=None):
         dtype = dtype or self.DEFAULT_DTYPE
         device = device or self.DEFAULT_DEVICE
@@ -439,12 +441,13 @@ def matmul_impl(self, x, w, y):

 @backend.set_impl(backend.operator_set.gather_nd)
 def gather_nd_impl(self, x, w, y, *, batch_dims):
-    return (f"{y.name} = GatherND<batch_dims={batch_dims}>({x.name}, {w.name})"
+    return f"{y.name} = GatherND<batch_dims={batch_dims}>({x.name}, {w.name})"
+    # return (f"{y.name} = GatherND<batch_dims={batch_dims}>({x.name}, {w.name})"
     # if w.symval.dtype is dtypes.int64 else
     # f"""{w.name}_ = Cast<to={self.onnx_dtype_enum_map[dtypes.int64]}>({w.name})
     # {y.name} = GatherND<batch_dims={batch_dims}>({x.name}, {w.name}_)
     # """
-    )
+    # )


 @backend.set_impl(backend.operator_set.scatter_nd)
@@ -455,12 +458,16 @@ def scatter_nd_impl(
     u,
     y,
 ):
-    return (f"{y.name} = ScatterND({x.name}, {w.name}, {u.name})"
-    # if w.symval.dtype is dtypes.int64 else
-    # f"""{w.name}_ = Cast<to={self.onnx_dtype_enum_map[dtypes.int64]}>({w.name})
-    # {y.name} = ScatterND({x.name}, {w.name}_, {u.name})
+    return f"{y.name} = ScatterND({x.name}, {w.name}, {u.name})"
+
+    # if w.symval.dtype is dtypes.int64:
+    #     return f"{y.name} = ScatterND({x.name}, {w.name}, {u.name})"
+    # else:
+    #     name = f"{w.name}_{random.randrange(100)}"
+    #     return f"""{name} = Cast<to={self.onnx_dtype_enum_map[dtypes.int64]}>({w.name})
+    # {y.name} = ScatterND({x.name}, {name}_, {u.name})
     # """
-    )
+
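
Note: the import random added at the top of this file is referenced only by the commented-out fallback above, where random.randrange(100) would give the Cast output a unique name if the non-int64 path were ever re-enabled.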




2 changes: 1 addition & 1 deletion src/slope/core.py
@@ -711,7 +711,7 @@ class Backend:
     LOG_INIT = int(os.environ.get("LOG_INIT", 1))
     DEFAULT_DEVICE = devices.name_idx_device_map[os.environ.get("DEFAULT_DEVICE", "cpu:0")]
     DEFAULT_DTYPE = dtypes.name_dtype_map[os.environ.get("DEFAULT_DTYPE", "float32")]
-
+    dtype_for_indices: DType = None  # each backend must override this
     def __init__(
         self,
         operator_set: OperatorSet,
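
To make the contract concrete, here is a minimal sketch of the override pattern, using a hypothetical MyBackend (not part of this commit; import paths assumed); it mirrors what IREEBackend and ONNXRuntimeBackend do above:

    from slope.core import Backend, dtypes

    # Hypothetical backend, for illustration only: declare which integer
    # dtype this backend's GatherND/ScatterND kernels expect for indices.
    class MyBackend(Backend):
        dtype_for_indices = dtypes.int64  # operators cast indices to this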
6 changes: 4 additions & 2 deletions src/slope/operators.py
@@ -933,8 +933,8 @@ def T(self, cotangents, x, w, *, groups, stride, dilation, padding):
@operator_set.register("gather_nd")
class GatherND(GeneralReduceOperator):
def args_fixer(self, x, w, *, batch_dims: int = 0):
# if w.dtype is not dtypes.int32:
# w = w.cast(dtypes.int32)
if w.dtype is not slope.backend.dtype_for_indices:
w = w.cast(slope.backend.dtype_for_indices)
return (x, w), dict(batch_dims=batch_dims)

def typecheck(self, x, w, *, batch_dims: int):
@@ -979,6 +979,8 @@ def T(self, cotangents, x, w, *, batch_dims: int):
@operator_set.register("scatter_nd")
class ScatterND(GeneralReduceOperator):
def args_fixer(self, x, w, u):
if w.dtype is not slope.backend.dtype_for_indices:
w = w.cast(slope.backend.dtype_for_indices)
return (x, w, u), dict()

def typecheck(self, x, w, u):
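
A short usage sketch of the new behavior, assuming the active backend sets dtype_for_indices to dtypes.int64 as both backends in this commit do; it follows examples/simple/gather_nd.py:

    import slope

    x = slope.tensor([[0., 1.], [2., 3.]], dtype=slope.float32)
    w = slope.tensor([[1, 0], [0, 1]], dtype=slope.int32)  # int32 indices
    y = x.gather_nd(w, 0)  # args_fixer casts w to int64 before lowering
    print(f"{y=}")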
