Skip to content

Commit

Permalink
use tf to translate gather and scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
radenmuaz committed Jan 20, 2024
1 parent e4e2686 commit 62dda37
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 49 deletions.
30 changes: 8 additions & 22 deletions examples/simple/gather.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import slope


# print('#1')
# x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
# w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
# # w = slope.tensor([[0,0],[1,1]], dtype=slope.int32)
# print(f"{x=}")
# print(f"{w=}")
# y = x.gather(w,1)
# print(f"{y=}")
print('#1')
x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
# w = slope.tensor([[0,0],[1,1]], dtype=slope.int32)
print(f"{x=}")
print(f"{w=}")
y = x.gather(w,1)
print(f"{y=}")

# print('\n#2')
# x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
Expand Down Expand Up @@ -52,16 +51,3 @@
# y = x.gather(w)
# breakpoint()
# print(f"{y=}")

#######################

x = slope.ones(8)
print(f"before: {x=}")
w = slope.tensor([[4], [3], [1], [7]], dtype=slope.int32)
u = slope.tensor([9., 10., 11., 12.])
y = slope.scatter(x,w,u)

print(f"{w=}")
print(f"{u=}")
print(f"{x=}")
print(f"{y=}")
63 changes: 63 additions & 0 deletions examples/simple/scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import slope

# Scatter demo: write four update values into an 8-element zero tensor.
# Index rows and update values are spelled out first so the pairing is easy
# to read: row i of `w` addresses the element that receives `u[i]`.
index_rows = [[position] for position in (4, 3, 1, 7)]
update_values = [9.0, 10.0, 11.0, 12.0]

x = slope.zeros(8)
print(f"before: {x=}")

w = slope.tensor(index_rows, dtype=slope.int32)
u = slope.tensor(update_values)
y = slope.scatter(x, w, u)

# Uncomment to inspect the inputs alongside the result:
# print(f"{w=}")
# print(f"{u=}")
# print(f"{x=}")
print(f"{y=}")

# Expected when scatter accumulates (adds) the updates into zeros(8):
# y_ans = slope.tensor([0., 11., 0., 10., 9., 0., 0., 12.], dtype=slope.float32)


# x = slope.tensor([
# [0.0, 0.0, 0.0],
# [0.0, 0.0, 0.0],
# [0.0, 0.0, 0.0],
# ], dtype=slope.float32)
# w = slope.tensor([
# [1, 0, 2],
# [0, 2, 1],
# ], dtype=slope.int32)

# u = slope.tensor([
# [1.0, 1.1, 1.2],
# [2.0, 2.1, 2.2],
# ], dtype=slope.float32)
# # u = u.unsqueeze(-1)
# y = x.scatter(w, u, axis=0)
# print(f"{y=}")
# print(f"{w=}")
# print(f"{u=}")
# print(f"{x=}")

# y_ans = slope.tensor([
# [2.0, 1.1, 0.0],
# [1.0, 0.0, 2.2],
# [0.0, 2.1, 1.2]
# ], dtype=slope.float32)

# print(f"{y_ans=}")

# x = slope.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=slope.float32)
# w = slope.tensor([[1, 3]], dtype=slope.int32)
# u = slope.tensor([[1.1, 2.1]], dtype=slope.float32)
# y = x.scatter(w, u, axis=1)
# y_ans = slope.tensor([[1.0, 1.1, 3.0, 2.1, 5.0]], dtype=slope.float32)


# x = slope.tensor(
# [0.0, 0.0, 0.0]
# , dtype=slope.float32)
# w = slope.tensor(
# [[1],[0]]
# , dtype=slope.int32)

# u = slope.tensor(
# [1.0, 2.0],
# dtype=slope.float32)
# y = x.scatter(w, u, axis=0)
# print(f"{y=}")
37 changes: 37 additions & 0 deletions experimental/tf_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import tensorflow as tf
import tensorflow.mlir.experimental as tme

# print('#1')
# Reference inputs: a 2x2 float operand and a 2x2 integer index tensor.
x = tf.constant([[0.0, 1.0], [2.0, 3.0]], dtype=tf.float32)
w = tf.constant([[1, 0], [0, 1]], dtype=tf.int32)
# print(f"{x=}")
# print(f"{w=}")
# y = tf.gather_nd(x, w,0)
# print(f"{y=}")

# Trace tf.gather_nd for these inputs, convert the concrete function to MLIR,
# then run the lowering pipeline and print the resulting module text.
traced_gather = tf.function(tf.gather_nd)
concrete_gather = traced_gather.get_concrete_function(x, w)
f = tme.convert_function(concrete_gather)
lowered_module = tme.run_pass_pipeline(f, "tf-lower-to-mlprogram-and-hlo")
print(lowered_module)
# print('\n#2')
# x = tf.constant([[0.,1.],[2.,3.]], dtype=tf.float32)
# w = tf.constant([[1],[0]], dtype=tf.int64)
# print(f"{x=}")
# print(f"{w=}")
# y = tf.gather_nd(x, w)
# print(f"{y=}")

# print('\n#3')
# x = tf.constant([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=tf.float32)
# w = tf.constant([[0,1],[1,0]], dtype=tf.int32)
# print(f"{x=}")
# print(f"{w=}")
# y = tf.gather_nd(x, w)
# print(f"{y=}")


# print('\n#4')
# x = tf.constant([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=tf.float32)
# w = tf.constant([[[0,1]],[[1,0]]], dtype=tf.int32)
# print(f"{x=}")
# print(f"{w=}")
# y = tf.gather_nd(x, w)
# print(f"{y=}")
92 changes: 70 additions & 22 deletions src/slope/backends/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,13 @@ def conv_impl(self, x, w, y, *, groups, stride, dilation, padding):
}
'''

# @backend.set_impl(backend.operator_set.gather)
# def gather_impl(self, x, w, y, *, axis):
# return f"""%{y.name} = "stablehlo.torch_index_select "(%{x.name}, %{w.name}) {{
# batch_dims = 0 : i64, dim = 0 : i64
# # }} {as_mlir_sig((x.symval, w.symval), y.symval)}
# # """

@backend.set_impl(backend.operator_set.gather)
def gather_impl(self, x, w, y, *, axis):
operand_shape = list(x.symval.shape)
Expand Down Expand Up @@ -662,16 +669,6 @@ def gather_impl(self, x, w, y, *, axis):

if (len(collapsed_slice_dims) != len(start_index_map)):
y_pre = SymbolicTensor(y.symval.shape + (1,), y.symval.dtype, y.symval.device)

# offset_dims = [1]
# index_vector_dim = 1
# start_index_map = [0]
# collapsed_slice_dims = [0,1]
# slice_sizes = [1, 1, 2]
# [[2,3],[4,5]]
# y_pre = SymbolicTensor((2,2,1), y.symval.dtype, y.symval.device)
# y_pre = SymbolicTensor((1,2,2), y.symval.dtype, y.symval.device)


else:
raise ValueError
Expand All @@ -687,6 +684,9 @@ def gather_impl(self, x, w, y, *, axis):
{f'%{y.name} = "stablehlo.reshape"(%{y.name}_) {as_mlir_sig((y_pre,), y.symval)}'
if y_pre is not None else ''}
"""



# <stdin>:3:11: error: start_index_map size (1) is not equal to size of index dimension (1) of start_indices (2)
# <stdin>:3:11: error: slice_sizes size (1) not equal to (implied) operand rank (2)

Expand All @@ -697,21 +697,16 @@ def scatter_impl(self, x, w, u, y, *, axis):
y_init_type = SymbolicTensor((), y.symval.dtype, y.symval.device)
y_mlir_type = as_mlir_shape(y_init_type)

lim = (
(len(w.symval.shape[1 :])) - len(x.symval.shape[:+ 1])
)
r = x.symval.ndim
q = w.symval.ndim
index_vector_dim = q-1
lim = (q-1 - x.symval.shape[0])
lim = None if lim == 0 else lim
update_window_dims = list(x.symval.shape[1 : lim])
inserted_window_dims = [0+axis]
scatter_dims_to_operand_dims = [0+axis]

index_vector_dim = w.symval.ndim-1
inserted_window_dims = [0]
scatter_dims_to_operand_dims = [0]

# TODO: Find cheaper way to copy if exists
return f"""%{x.name}_1 = "stablehlo.constant"(){{
value = dense<{1.0 if "f" in x.symval.dtype.mlir else 1}> : {as_mlir_shape(x.symval)} }} {as_mlir_sig((), x.symval)}
%{x.name}_ = "stablehlo.multiply"(%{x.name}, %{x.name}_1) {as_mlir_sig((x.symval, x.symval), x.symval)}
%{y.name} = "stablehlo.scatter"(%{x.name}_, %{w.name}, %{u.name}) ({{
return f"""%{y.name} = "stablehlo.scatter"(%{x.name}, %{w.name}, %{u.name}) ({{
^bb0(%arg0: {y_mlir_type}, %arg1: {y_mlir_type}):
%0 = "stablehlo.add"(%arg0, %arg1) {as_mlir_sig((y_init_type, y_init_type), y_init_type)}
"stablehlo.return"(%0) : ({y_mlir_type}) -> ()
Expand All @@ -725,3 +720,56 @@ def scatter_impl(self, x, w, u, y, *, axis):
unique_indices = false
}} {as_mlir_sig((x.symval, w.symval, u.symval), y.symval)}
"""

## x = slope.zeros(8,1), x = slope.zeros(8,2)
# update_window_dims = []
# inserted_window_dims = [0,1]
# scatter_dims_to_operand_dims = [0]
# index_vector_dim = 1

# operand_dims = list(range(r))
# update_window_dims = [axis]
# inserted_window_dims = [dim for dim in operand_dims if dim != axis]
# scatter_dims_to_operand_dims = [axis]
# index_vector_dim = axis

# update_window_dims = [1]
# inserted_window_dims = [1]
# scatter_dims_to_operand_dims = [0,1]
# index_vector_dim = 1

# update_window_dims = [1]
# inserted_window_dims = [1]
# scatter_dims_to_operand_dims = [0,1]
# index_vector_dim = 0

'''
// %input: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10], [11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
// [[1, 2], [5, 6], [7, 8], [7, 8]],
// [[10, 11], [12, 13], [14, 15], [16, 17]],
// [[18, 19], [20, 21], [21, 22], [23, 24]]
// ]
'''
10 changes: 5 additions & 5 deletions src/slope/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ def vmap(self, dim_size, vals_in, dims_in, **params):

def jvp(self, primals, tangents, *, axis):
    """Forward-mode rule: apply this operator to the primal and to the tangent of ``x``.

    Args:
        primals: pair ``(x, w)``.
        tangents: pair ``(x_dot, w_dot)``; ``w_dot`` is unpacked but not used.
        axis: forwarded unchanged as the third positional argument of ``self``.

    Returns:
        A pair of one-element lists:
        ``([self(x, w, axis)], [self(x_dot, w, axis)])``.
    """
    (x, w), (x_dot, _w_dot) = primals, tangents
    # The tangent output reuses the *primal* ``w``. A second, unreachable
    # ``return`` that passed ``w_dot`` instead has been removed.
    # NOTE(review): presumably ``w`` holds integer gather indices, whose
    # tangent carries no usable index values -- confirm against the operator.
    return [self(x, w, axis)], [self(x_dot, w, axis)]

def T(self, cotangents, x, w):
(gL_y,) = cotangents
Expand All @@ -980,10 +980,10 @@ def args_fixer(self, x, w, u, *, axis: int = 0):
return (x, w, u), dict(axis=axis)

def typecheck(self, x, w, u, *, axis: int):
    """Validate the ranks of the three operands and return the output spec.

    Args:
        x: first operand; must expose ``ndim``.
        w: second operand; must expose ``ndim`` and ``shape``.
        u: third operand; must expose ``ndim``.
        axis: must be smaller than both ``x.ndim`` and ``w.ndim``.

    Returns:
        ``[x]`` -- the output is described by the first operand itself.

    Raises:
        AssertionError: if any rank/shape relation below does not hold.
    """
    assert x.ndim > 0 and w.ndim > 0, (
        f"x and w must have rank >= 1, got x.ndim={x.ndim}, w.ndim={w.ndim}"
    )
    assert u.ndim == w.ndim - 1, (
        f"u.ndim must equal w.ndim - 1, got u.ndim={u.ndim}, w.ndim={w.ndim}"
    )
    assert axis < min(x.ndim, w.ndim), (
        f"axis={axis} must be < min(x.ndim, w.ndim)={min(x.ndim, w.ndim)}"
    )
    # The last dimension of w may not be longer than the rank of x.
    assert 1 <= w.shape[-1] <= x.ndim, (
        f"w.shape[-1]={w.shape[-1]} must be in [1, x.ndim={x.ndim}]"
    )
    return [x]

def vmap(self, dim_size, vals_in, dims_in, **params):
Expand Down

0 comments on commit 62dda37

Please sign in to comment.