Skip to content

Commit

Permalink
gather_nd
Browse files Browse the repository at this point in the history
  • Loading branch information
radenmuaz committed Jan 21, 2024
1 parent 61295c8 commit 035983e
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 46 deletions.
23 changes: 16 additions & 7 deletions examples/simple/gather_nd.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import slope

print('#1')
x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
# w = slope.tensor([[0,0],[1,1]], dtype=slope.int32)
print(f"{x=}")
print(f"{w=}")
y = x.gather_nd(w,0)
print(f"{y=}")
# x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
# w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
# # w = slope.tensor([[0,0],[1,1]], dtype=slope.int32)
# print(f"{x=}")
# print(f"{w=}")
# y = x.gather_nd(w,0)
# print(f"{y=}")

# print('\n#2')
# x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
Expand All @@ -34,6 +34,15 @@
# y = x.gather_nd(w)
# print(f"{y=}")


print('\n#5')
x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
w = slope.tensor([[1],[0]], dtype=slope.int32)
print(f"{x=}")
print(f"{w=}")
y = x.gather_nd(w,1)
print(f"{y=}")

###############

# x = slope.arange(10, dtype=slope.float32).reshape(2,5)
Expand Down
48 changes: 24 additions & 24 deletions experimental/tf_gather_nd.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
import tensorflow as tf
import tensorflow.mlir.experimental as tme

print('\n#1')
x = tf.constant([[0.,1.],[2.,3.]], dtype=tf.float32);print(f"{x=}")
w = tf.constant([[1,0],[0,1]], dtype=tf.int32);print(f"{w=}")
y = tf.gather_nd(x, w,0); print(f"{y=}")
f = tme.convert_function((tf.function(tf.gather_nd)).get_concrete_function(x, w, 0))
print(tme.run_pass_pipeline(f, "tf-lower-to-mlprogram-and-hlo"))
# print('\n#1')
# x = tf.constant([[0.,1.],[2.,3.]], dtype=tf.float32);print(f"{x=}")
# w = tf.constant([[1,0],[0,1]], dtype=tf.int32);print(f"{w=}")
# y = tf.gather_nd(x, w,0); print(f"{y=}")
# f = tme.convert_function((tf.function(tf.gather_nd)).get_concrete_function(x, w, 0))
# print(tme.run_pass_pipeline(f, "tf-lower-to-mlprogram-and-hlo"))


print('\n#2')
x = tf.constant([[0.,1.],[2.,3.]], dtype=tf.float32);print(f"{x=}")
w = tf.constant([[1],[0]], dtype=tf.int32);print(f"{w=}")
y = tf.gather_nd(x, w, 0); print(f"{y=}")
f = tme.convert_function((tf.function(tf.gather_nd)).get_concrete_function(x, w,0 ))
print(tme.run_pass_pipeline(f, "tf-lower-to-mlprogram-and-hlo"))
# print('\n#2')
# x = tf.constant([[0.,1.],[2.,3.]], dtype=tf.float32);print(f"{x=}")
# w = tf.constant([[1],[0]], dtype=tf.int32);print(f"{w=}")
# y = tf.gather_nd(x, w, 0); print(f"{y=}")
# f = tme.convert_function((tf.function(tf.gather_nd)).get_concrete_function(x, w,0 ))
# print(tme.run_pass_pipeline(f, "tf-lower-to-mlprogram-and-hlo"))

print('\n#3')
x = tf.constant([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=tf.float32); print(f"{x=}")
w = tf.constant([[0,1],[1,0]], dtype=tf.int32); print(f"{w=}")
y = tf.gather_nd(x, w, 0); print(f"{y=}")
f = tme.convert_function((tf.function(tf.gather_nd)).get_concrete_function(x, w, 0))
print(tme.run_pass_pipeline(f, "tf-lower-to-mlprogram-and-hlo"))
# print('\n#3')
# x = tf.constant([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=tf.float32); print(f"{x=}")
# w = tf.constant([[0,1],[1,0]], dtype=tf.int32); print(f"{w=}")
# y = tf.gather_nd(x, w, 0); print(f"{y=}")
# f = tme.convert_function((tf.function(tf.gather_nd)).get_concrete_function(x, w, 0))
# print(tme.run_pass_pipeline(f, "tf-lower-to-mlprogram-and-hlo"))


print('\n#4')
x = tf.constant([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=tf.float32); print(f"{x=}")
w = tf.constant([[0,1],[1,0]], dtype=tf.int32); print(f"{w=}")
y = tf.gather_nd(x, w, 0); print(f"{y=}")
f = tme.convert_function((tf.function(tf.gather_nd)).get_concrete_function(x, w, 0))
print(tme.run_pass_pipeline(f, "tf-lower-to-mlprogram-and-hlo"))
# print('\n#4')
# x = tf.constant([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=tf.float32); print(f"{x=}")
# w = tf.constant([[0,1],[1,0]], dtype=tf.int32); print(f"{w=}")
# y = tf.gather_nd(x, w, 0); print(f"{y=}")
# f = tme.convert_function((tf.function(tf.gather_nd)).get_concrete_function(x, w, 0))
# print(tme.run_pass_pipeline(f, "tf-lower-to-mlprogram-and-hlo"))

print('\n#5')
x = tf.constant([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=tf.float32); print(f"{x=}")
Expand Down
21 changes: 8 additions & 13 deletions src/slope/backends/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,8 +595,6 @@ def conv_impl(self, x, w, y, *, groups, stride, dilation, padding):

@backend.set_impl(backend.operator_set.gather_nd)
def gather_nd_impl(self, x, w, y, batch_dims):
if batch_dims > 0:
raise NotImplementedError
operand_shape = list(x.symval.shape)
indices_shape = list(w.symval.shape)
r = x.symval.ndim
Expand All @@ -607,10 +605,9 @@ def gather_nd_impl(self, x, w, y, batch_dims):
y_reshape = None
w_arange = None
if b > 0:
w_arange = SymbolicTensor(tuple(range(operand_shape[i]) for i in range(b)),
w.symval.dtype,
w.symval.device
)
w_arange = w.symval
# breakpoint()
# w_arange = SymbolicTensor(w.symval.shape, w.symval.dtype, w.symval.device)

if indices_shape[-1] == r:
slice_sizes = [1]*r
Expand All @@ -623,7 +620,7 @@ def gather_nd_impl(self, x, w, y, batch_dims):
y_reshape = SymbolicTensor(y.symval.shape + (1,), y.symval.dtype, y.symval.device)
elif indices_shape[-1] < r:
slice_sizes = [*[1]*(r-1), *operand_shape[-1:]]
start_index_map = [i+b for i, s in enumerate(slice_sizes) if s==1 and i < q]
start_index_map = [i for i, s in enumerate(slice_sizes) if s==1 and i < q]

collapsed_slice_dims = []
for i in range(len(slice_sizes)):
Expand All @@ -635,15 +632,13 @@ def gather_nd_impl(self, x, w, y, batch_dims):

else:
raise ValueError
w_symval = w.symval if w_arange is None else SymbolicTensor(w_arange.shape + w.symval.shape,
w.symval.dtype,
w.symval.device
)
w_symval = w.symval if w_arange is None else SymbolicTensor(
tuple(d+b if i == b else d for i, d in enumerate(w.symval.shape)), w.symval.dtype, w.symval.device)
y_symval = y.symval if y_reshape is None else y_reshape
y_affix = "" if y_reshape is None else "_"
w_affix = "" if w_arange is None else "_"
return f"""{f'''%{w.name}_i = "stablehlo.iota"() {{ iota_dimension = 0 : i64}} {as_mlir_sig((), w_arange)}
%{w.name}_ = stablehlo.concatenate {w.name}_i, {w.name} dim = {b} {as_mlir_sig((w_arange, w.symval), w_symval)} '''
%{w.name}_ = "stablehlo.concatenate"(%{w.name}_i, %{w.name}) {{ dimension = {b} : i64}} {as_mlir_sig((w_arange, w.symval), w_symval)} '''
if w_arange is not None else ''}
%{y.name}{y_affix} = "stablehlo.gather"(%{x.name}, %{w.name}{w_affix}) {{
dimension_numbers = #stablehlo.gather<
Expand All @@ -653,7 +648,7 @@ def gather_nd_impl(self, x, w, y, batch_dims):
index_vector_dim = {index_vector_dim}>,
slice_sizes = dense<{slice_sizes}> : tensor<{len(slice_sizes)}xi64>,
indices_are_sorted = false
}} {as_mlir_sig((x.symval, w.symval), y_symval)}
}} {as_mlir_sig((x.symval, w_symval), y_symval)}
{f'%{y.name} = "stablehlo.reshape"(%{y.name}_) {as_mlir_sig((y_symval,), y.symval)}'
if y_reshape is not None else ''}
"""
Expand Down
8 changes: 6 additions & 2 deletions src/slope/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,11 +937,15 @@ def args_fixer(self, x, w, *, batch_dims: int = 0):
def typecheck(self, x, w, *, batch_dims: int):
r = x.ndim
q = w.ndim
b = batch_dims
assert r > 0 and q > 0
assert 1 <= w.shape[-1] <= r
assert w.shape[-1] <= r
assert batch_dims < min(x.ndim, w.ndim)
shape = w.shape[: q - 1] + x.shape[w.shape[-1] :]
assert b < min(x.ndim, w.ndim)
# shape = w.shape[: q - 1] + x.shape[w.shape[-1] :]
bx = x.shape[b:]
bw = w.shape[b:]
shape = bx[:b] + bw[:len(bw) - 1] + bx[bw[-1]:]
return [SymbolicTensor(shape, x.dtype, x.device)]

def vmap(self, dim_size, vals_in, dims_in, **params):
Expand Down

0 comments on commit 035983e

Please sign in to comment.