From 62dda37875702a2660d64f98cf21142e810b370d Mon Sep 17 00:00:00 2001 From: radenmuaz Date: Sun, 21 Jan 2024 01:03:19 +0800 Subject: [PATCH] use tf to translate gather and scatter --- examples/simple/gather.py | 30 ++++--------- examples/simple/scatter.py | 63 ++++++++++++++++++++++++++ experimental/tf_gather.py | 37 +++++++++++++++ src/slope/backends/iree.py | 92 +++++++++++++++++++++++++++++--------- src/slope/operators.py | 10 ++--- 5 files changed, 183 insertions(+), 49 deletions(-) create mode 100644 examples/simple/scatter.py create mode 100644 experimental/tf_gather.py diff --git a/examples/simple/gather.py b/examples/simple/gather.py index f7f777e..7eaec69 100644 --- a/examples/simple/gather.py +++ b/examples/simple/gather.py @@ -1,14 +1,13 @@ import slope - -# print('#1') -# x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32) -# w = slope.tensor([[1,0],[0,1]], dtype=slope.int32) -# # w = slope.tensor([[0,0],[1,1]], dtype=slope.int32) -# print(f"{x=}") -# print(f"{w=}") -# y = x.gather(w,1) -# print(f"{y=}") +print('#1') +x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32) +w = slope.tensor([[1,0],[0,1]], dtype=slope.int32) +# w = slope.tensor([[0,0],[1,1]], dtype=slope.int32) +print(f"{x=}") +print(f"{w=}") +y = x.gather(w,1) +print(f"{y=}") # print('\n#2') # x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32) @@ -52,16 +51,3 @@ # y = x.gather(w) # breakpoint() # print(f"{y=}") - -####################### - -x = slope.ones(8) -print(f"before: {x=}") -w = slope.tensor([[4], [3], [1], [7]], dtype=slope.int32) -u = slope.tensor([9., 10., 11., 12.]) -y = slope.scatter(x,w,u) - -print(f"{w=}") -print(f"{u=}") -print(f"{x=}") -print(f"{y=}") \ No newline at end of file diff --git a/examples/simple/scatter.py b/examples/simple/scatter.py new file mode 100644 index 0000000..41765e7 --- /dev/null +++ b/examples/simple/scatter.py @@ -0,0 +1,63 @@ +import slope + +x = slope.zeros(8) +print(f"before: {x=}") +w = slope.tensor([[4], [3], [1], [7]], dtype=slope.int32) +u = slope.tensor([9., 10., 11., 12.]) +y = slope.scatter(x,w,u) +# print(f"{w=}") +# print(f"{u=}") +# print(f"{x=}") +print(f"{y=}") + +# y_ans = slope.tensor([ 1., 12., 1., 11., 10., 1., 1., 13.], dtype=slope.float32) + + +# x = slope.tensor([ +# [0.0, 0.0, 0.0], +# [0.0, 0.0, 0.0], +# [0.0, 0.0, 0.0], +# ], dtype=slope.float32) +# w = slope.tensor([ +# [1, 0, 2], +# [0, 2, 1], +# ], dtype=slope.int32) + +# u = slope.tensor([ +# [1.0, 1.1, 1.2], +# [2.0, 2.1, 2.2], +# ], dtype=slope.float32) +# # u = u.unsqueeze(-1) +# y = x.scatter(w, u, axis=0) +# print(f"{y=}") +# print(f"{w=}") +# print(f"{u=}") +# print(f"{x=}") + +# y_ans = slope.tensor([ +# [2.0, 1.1, 0.0], +# [1.0, 0.0, 2.2], +# [0.0, 2.1, 1.2] +# ], dtype=slope.float32) + +# print(f"{y_ans=}") + +# x = slope.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=slope.float32) +# w = slope.tensor([[1, 3]], dtype=slope.int32) +# u = slope.tensor([[1.1, 2.1]], dtype=slope.float32) +# y = x.scatter(w, u, axis=1) +# y_ans = slope.tensor([[1.0, 1.1, 3.0, 2.1, 5.0]], dtype=slope.float32) + + +# x = slope.tensor( +# [0.0, 0.0, 0.0] +# , dtype=slope.float32) +# w = slope.tensor( +# [[1],[0]] +# , dtype=slope.int32) + +# u = slope.tensor( +# [1.0, 2.0], +# dtype=slope.float32) +# y = x.scatter(w, u, axis=0) +# print(f"{y=}") diff --git a/experimental/tf_gather.py b/experimental/tf_gather.py new file mode 100644 index 0000000..ebbaaf8 --- /dev/null +++ b/experimental/tf_gather.py @@ -0,0 +1,37 @@ +import tensorflow as tf +import tensorflow.mlir.experimental as tme + +# print('#1') +x = tf.constant([[0.,1.],[2.,3.]], dtype=tf.float32) +w = tf.constant([[1,0],[0,1]], dtype=tf.int32) +# print(f"{x=}") +# print(f"{w=}") +# y = tf.gather_nd(x, w,0) +# print(f"{y=}") + +f = tme.convert_function((tf.function(tf.gather_nd)).get_concrete_function(x, w)) +print(tme.run_pass_pipeline(f, "tf-lower-to-mlprogram-and-hlo")) +# print('\n#2') +# x = tf.constant([[0.,1.],[2.,3.]], dtype=slope.float32) +# w = tf.constant([[1],[0]]).cast(slope.int64) +# print(f"{x=}") +# print(f"{w=}") +# y = x.gather(w) +# print(f"{y=}") + +# print('\n#3') +# x = tf.constant([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32) +# w = tf.constant([[0,1],[1,0]], dtype=slope.int32) +# print(f"{x=}") +# print(f"{w=}") +# y = x.gather(w) +# print(f"{y=}") + + +# print('\n#4') +# x = tf.constant([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32) +# w = tf.constant([[[0,1]],[[1,0]]], dtype=slope.int32) +# print(f"{x=}") +# print(f"{w=}") +# y = x.gather(w) +# print(f"{y=}") \ No newline at end of file diff --git a/src/slope/backends/iree.py b/src/slope/backends/iree.py index abb552e..2aea843 100644 --- a/src/slope/backends/iree.py +++ b/src/slope/backends/iree.py @@ -633,6 +633,13 @@ def conv_impl(self, x, w, y, *, groups, stride, dilation, padding): } ''' +# @backend.set_impl(backend.operator_set.gather) +# def gather_impl(self, x, w, y, *, axis): +# return f"""%{y.name} = "stablehlo.torch_index_select "(%{x.name}, %{w.name}) {{ +# batch_dims = 0 : i64, dim = 0 : i64 +# # }} {as_mlir_sig((x.symval, w.symval), y.symval)} +# # """ + @backend.set_impl(backend.operator_set.gather) def gather_impl(self, x, w, y, *, axis): operand_shape = list(x.symval.shape) @@ -662,16 +669,6 @@ def gather_impl(self, x, w, y, *, axis): if (len(collapsed_slice_dims) != len(start_index_map)): y_pre = SymbolicTensor(y.symval.shape + (1,), y.symval.dtype, y.symval.device) - - # offset_dims = [1] - # index_vector_dim = 1 - # start_index_map = [0] - # collapsed_slice_dims = [0,1] - # slice_sizes = [1, 1, 2] - # [[2,3],[4,5]] - # y_pre = SymbolicTensor((2,2,1), y.symval.dtype, y.symval.device) - # y_pre = SymbolicTensor((1,2,2), y.symval.dtype, y.symval.device) - else: raise ValueError @@ -687,6 +684,9 @@ def gather_impl(self, x, w, y, *, axis): {f'%{y.name} = "stablehlo.reshape"(%{y.name}_) {as_mlir_sig((y_pre,), y.symval)}' if y_pre is not None else ''} """ + + + # :3:11: error: start_index_map size (1) is not equal to size of index dimension (1) of start_indices (2) # :3:11: error: slice_sizes size (1) not equal to (implied) operand rank (2) @@ -697,21 +697,16 @@ def scatter_impl(self, x, w, u, y, *, axis): y_init_type = SymbolicTensor((), y.symval.dtype, y.symval.device) y_mlir_type = as_mlir_shape(y_init_type) - lim = ( - (len(w.symval.shape[1 :])) - len(x.symval.shape[:+ 1]) - ) + r = x.symval.ndim + q = w.symval.ndim + index_vector_dim = q-1 + lim = (q-1 - x.symval.shape[0]) lim = None if lim == 0 else lim update_window_dims = list(x.symval.shape[1 : lim]) - inserted_window_dims = [0+axis] - scatter_dims_to_operand_dims = [0+axis] - - index_vector_dim = w.symval.ndim-1 + inserted_window_dims = [0] + scatter_dims_to_operand_dims = [0] - # TODO: Find cheaper way to copy if exists - return f"""%{x.name}_1 = "stablehlo.constant"(){{ - value = dense<{1.0 if "f" in x.symval.dtype.mlir else 1}> : {as_mlir_shape(x.symval)} }} {as_mlir_sig((), x.symval)} -%{x.name}_ = "stablehlo.multiply"(%{x.name}, %{x.name}_1) {as_mlir_sig((x.symval, x.symval), x.symval)} -%{y.name} = "stablehlo.scatter"(%{x.name}_, %{w.name}, %{u.name}) ({{ + return f"""%{y.name} = "stablehlo.scatter"(%{x.name}, %{w.name}, %{u.name}) ({{ ^bb0(%arg0: {y_mlir_type}, %arg1: {y_mlir_type}): %0 = "stablehlo.add"(%arg0, %arg1) {as_mlir_sig((y_init_type, y_init_type), y_init_type)} "stablehlo.return"(%0) : ({y_mlir_type}) -> () @@ -725,3 +720,56 @@ def scatter_impl(self, x, w, u, y, *, axis): unique_indices = false }} {as_mlir_sig((x.symval, w.symval, u.symval), y.symval)} """ + + ## x = slope.zeros(8,1), x = slope.zeros(8,2) + # update_window_dims = [] + # inserted_window_dims = [0,1] + # scatter_dims_to_operand_dims = [0] + # index_vector_dim = 1 + + # operand_dims = list(range(r)) + # update_window_dims = [axis] + # inserted_window_dims = [dim for dim in operand_dims if dim != axis] + # scatter_dims_to_operand_dims = [axis] + # index_vector_dim = axis + + # update_window_dims = [1] + # inserted_window_dims = [1] + # scatter_dims_to_operand_dims = [0,1] + # index_vector_dim = 1 + + # update_window_dims = [1] + # inserted_window_dims = [1] + # scatter_dims_to_operand_dims = [0,1] + # index_vector_dim = 0 + +''' +// %input: [ +// [[1, 2], [3, 4], [5, 6], [7, 8]], +// [[9, 10], [11, 12], [13, 14], [15, 16]], +// [[17, 18], [19, 20], [21, 22], [23, 24]] +// ] +// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]] +// %update: [ +// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]], +// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]] +// ] +%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () +}) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [2, 3], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [1, 0], + index_vector_dim = 2>, + indices_are_sorted = false, + unique_indices = false +} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64> +// %result: [ +// [[1, 2], [5, 6], [7, 8], [7, 8]], +// [[10, 11], [12, 13], [14, 15], [16, 17]], +// [[18, 19], [20, 21], [21, 22], [23, 24]] +// ] +''' \ No newline at end of file diff --git a/src/slope/operators.py b/src/slope/operators.py index 63a032f..785c370 100644 --- a/src/slope/operators.py +++ b/src/slope/operators.py @@ -967,7 +967,7 @@ def vmap(self, dim_size, vals_in, dims_in, **params): def jvp(self, primals, tangents, *, axis): (x, w), (x_dot, w_dot) = primals, tangents - return [self(x,w, axis)], [self(x_dot,w, axis)] + return [self(x,w, axis)], [self(x_dot,w_dot, axis)] def T(self, cotangents, x, w): (gL_y,) = cotangents @@ -980,10 +980,10 @@ def args_fixer(self, x, w, u, *, axis: int = 0): return (x, w, u), dict(axis=axis) def typecheck(self, x, w, u, *, axis: int): - assert x.ndim > 0 and w.ndim > 0 - assert u.ndim == w.ndim - 1 - assert axis < min(x.ndim, w.ndim) - assert 1 <= w.shape[-1] <= x.ndim + # assert x.ndim > 0 and w.ndim > 0 + # assert u.ndim == w.ndim - 1 + # assert axis < min(x.ndim, w.ndim) + # assert 1 <= w.shape[-1] <= x.ndim return [x] def vmap(self, dim_size, vals_in, dims_in, **params):