diff --git a/examples/simple/gather_nd.py b/examples/simple/gather_nd.py
index a24c051..c931f83 100644
--- a/examples/simple/gather_nd.py
+++ b/examples/simple/gather_nd.py
@@ -2,12 +2,12 @@
 print('#1')
 x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
-w = slope.tensor([[1,0],[0,1]], dtype=slope.int64)
-# w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
+# w = slope.tensor([[1,0],[0,1]], dtype=slope.int64)
+w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
 print(f"{x=}")
 print(f"{w=}")
-# y = x.gather_nd(w,0)
-# print(f"{y=}")
+y = x.gather_nd(w,0)
+print(f"{y=}")
 
 
 @slope.jit
 def f(x, w):
diff --git a/src/slope/backends/iree.py b/src/slope/backends/iree.py
index c26c35c..41b2969 100644
--- a/src/slope/backends/iree.py
+++ b/src/slope/backends/iree.py
@@ -55,6 +55,7 @@ def annotate_sig(in_symvals, out_symvals):
 
 
 class IREEBackend(Backend):
+    dtype_for_indices = dtypes.int64
     dtype_map = {
         dtypes.float32: np.dtypes.Float32DType(),
         dtypes.uint8: np.dtypes.UInt8DType(),
diff --git a/src/slope/backends/onnxruntime.py b/src/slope/backends/onnxruntime.py
index e532043..fc33c20 100644
--- a/src/slope/backends/onnxruntime.py
+++ b/src/slope/backends/onnxruntime.py
@@ -33,7 +33,7 @@
 import onnx
 import onnxruntime
 import tempfile
-
+import random
 
 def annotate_shape(symval):
     xdtype = symval.dtype.mlir
@@ -56,16 +56,7 @@ def annotate_sig(in_symvals, out_symvals):
 
 
 class ONNXRuntimeBackend(Backend):
-    sess_options = onnxruntime.SessionOptions()
-    # Disable this flags, easily get nan
-    sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
-
-    # Other flags
-    # sess_options.log_severity_level = 3
-    # sess_options.use_deterministic_compute = True
-    # sess_options.intra_op_num_threads = 4
-    # sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
-    # sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
+    dtype_for_indices = dtypes.int64
     dtype_map = {
         slope.core.dtypes.float32: "float",
         dtypes.uint8: "uint8",
@@ -100,6 +91,17 @@ class ONNXRuntimeBackend(Backend):
     dtype_map_inv = {v: k for k, v in dtype_map.items()}
     device_map_inv = {v: k for k, v in device_map.items()}
 
+    sess_options = onnxruntime.SessionOptions()
+    # Disable this flags, easily get nan
+    sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
+    # Other flags
+    # sess_options.log_severity_level = 3
+    # sess_options.use_deterministic_compute = True
+    # sess_options.intra_op_num_threads = 4
+    # sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
+    # sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
+
+
     def from_numpy(self, val, dtype=None, device=None):
         dtype = dtype or self.DEFAULT_DTYPE
         device = device or self.DEFAULT_DEVICE
@@ -439,12 +441,13 @@ def matmul_impl(self, x, w, y):
 
 @backend.set_impl(backend.operator_set.gather_nd)
 def gather_nd_impl(self, x, w, y, *, batch_dims):
-    return (f"{y.name} = GatherND({x.name}, {w.name})"
+    return f"{y.name} = GatherND({x.name}, {w.name})"
+# return (f"{y.name} = GatherND({x.name}, {w.name})"
 # if w.symval.dtype is dtypes.int64 else
 # f"""{w.name}_ = Cast({w.name})
 # {y.name} = GatherND({x.name}, {w.name}_)
 # """
-)
+# )
 
 
 @backend.set_impl(backend.operator_set.scatter_nd)
@@ -455,12 +458,16 @@ def scatter_nd_impl(
     u,
     y,
 ):
-    return (f"{y.name} = ScatterND({x.name}, {w.name}, {u.name})"
-# if w.symval.dtype is dtypes.int64 else
-# f"""{w.name}_ = Cast({w.name})
-# {y.name} = ScatterND({x.name}, {w.name}_, {u.name})
+    return f"{y.name} = ScatterND({x.name}, {w.name}, {u.name})"
+
+# if w.symval.dtype is dtypes.int64:
+# return f"{y.name} = ScatterND({x.name}, {w.name}, {u.name})"
+# else:
+# name = f"{w.name}_{random.randrange(100)}"
+# return f"""{name} = Cast({w.name})
+# {y.name} = ScatterND({x.name}, {name}_, {u.name})
 # """
-    )
+
 
 
 
diff --git a/src/slope/core.py b/src/slope/core.py
index c13ee85..e5f5fb5 100644
--- a/src/slope/core.py
+++ b/src/slope/core.py
@@ -711,7 +711,7 @@ class Backend:
     LOG_INIT = int(os.environ.get("LOG_INIT", 1))
     DEFAULT_DEVICE = devices.name_idx_device_map[os.environ.get("DEFAULT_DEVICE", "cpu:0")]
     DEFAULT_DTYPE = dtypes.name_dtype_map[os.environ.get("DEFAULT_DTYPE", "float32")]
-
+    dtype_for_indices: DType = None  # need to override
     def __init__(
         self,
         operator_set: OperatorSet,
diff --git a/src/slope/operators.py b/src/slope/operators.py
index ed7f227..17ea3e2 100644
--- a/src/slope/operators.py
+++ b/src/slope/operators.py
@@ -933,8 +933,8 @@ def T(self, cotangents, x, w, *, groups, stride, dilation, padding):
 @operator_set.register("gather_nd")
 class GatherND(GeneralReduceOperator):
     def args_fixer(self, x, w, *, batch_dims: int = 0):
-        # if w.dtype is not dtypes.int32:
-        #     w = w.cast(dtypes.int32)
+        if w.dtype is not slope.backend.dtype_for_indices:
+            w = w.cast(slope.backend.dtype_for_indices)
         return (x, w), dict(batch_dims=batch_dims)
 
     def typecheck(self, x, w, *, batch_dims: int):
@@ -979,6 +979,8 @@ def T(self, cotangents, x, w, *, batch_dims: int):
 @operator_set.register("scatter_nd")
 class ScatterND(GeneralReduceOperator):
     def args_fixer(self, x, w, u):
+        if w.dtype is not slope.backend.dtype_for_indices:
+            w = w.cast(slope.backend.dtype_for_indices)
         return (x, w, u), dict()
 
     def typecheck(self, x, w, u):
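
For illustration only: a minimal, self-contained sketch of the dtype_for_indices pattern the patch introduces. The real code lives in slope's Backend subclasses and in GatherND.args_fixer / ScatterND.args_fixer; the tiny Backend classes and fix_gather_nd_args below are hypothetical stand-ins using NumPy dtypes, assuming only what the diff shows (a per-backend index dtype, plus a cast of the index tensor before lowering).

# Illustrative sketch -- hypothetical stand-ins, not slope's actual classes.
import numpy as np


class Backend:
    dtype_for_indices = None  # each concrete backend must override this


class ONNXRuntimeLikeBackend(Backend):
    # ONNX Runtime's GatherND/ScatterND take int64 indices, so this backend picks int64.
    dtype_for_indices = np.int64


def fix_gather_nd_args(backend, x, w):
    # Mirrors what args_fixer does in the diff: cast the index tensor w
    # to the backend's preferred index dtype before the op is lowered.
    if w.dtype != np.dtype(backend.dtype_for_indices):
        w = w.astype(backend.dtype_for_indices)
    return x, w


x = np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32)
w = np.array([[1, 0], [0, 1]], dtype=np.int32)  # int32 indices, as in the example script
_, w_fixed = fix_gather_nd_args(ONNXRuntimeLikeBackend(), x, w)
print(w_fixed.dtype)  # int64: indices are promoted before reaching gather_nd_impl

Routing the cast through args_fixer rather than through each backend impl means every backend receives indices in its declared dtype, which is why the Cast fallbacks in gather_nd_impl and scatter_nd_impl are commented out in the patch.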