Skip to content

Commit

Permalink
gather
Browse files Browse the repository at this point in the history
  • Loading branch information
radenmuaz committed Jan 19, 2024
1 parent c31d67d commit e4e2686
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 139 deletions.
116 changes: 17 additions & 99 deletions examples/simple/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,27 @@

# print('#1')
# x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
# w = slope.tensor([[0,0],[1,1]], dtype=slope.int32)
# w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
# # w = slope.tensor([[0,0],[1,1]], dtype=slope.int32)
# print(f"{x=}")
# print(f"{w=}")
# y = x.gather_nd(w)
# y = x.gather(w,1)
# print(f"{y=}")

# print('\n#2')
# x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
# w = slope.tensor([[1],[0]]).cast(slope.int64)
# print(f"{x=}")
# print(f"{w=}")
# y = x.gather_nd(w)
# y = x.gather(w)
# print(f"{y=}")

# print('\n#3')
# x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
# w = slope.tensor([[0,1],[1,0]], dtype=slope.int32)
# print(f"{x=}")
# print(f"{w=}")
# y = x.gather_nd(w)
# y = x.gather(w)
# print(f"{y=}")


Expand All @@ -31,98 +32,15 @@
# w = slope.tensor([[[0,1]],[[1,0]]], dtype=slope.int32)
# print(f"{x=}")
# print(f"{w=}")
# y = x.gather_nd(w)
# y = x.gather(w)
# print(f"{y=}")

print('\n#5')
x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
w = slope.tensor([[1],[0]], dtype=slope.int32)
print(f"{x=}")
print(f"{w=}")
y = x.gather_nd(w, 1)
print(f"{y=}")

'''
func.func @main (%x0: tensor<2x2xf32>, %x1: tensor<2x2xi32>) -> (tensor<2xf32>)
{
%y0_ = "stablehlo.gather"(%x0, %x1) {
dimension_numbers = #stablehlo.gather<
offset_dims = [1],
collapsed_slice_dims = [0],
start_index_map = [0, 1],
index_vector_dim = 1>,
slice_sizes = dense<[1, 1]> : tensor<2xi64>,
indices_are_sorted = false
} : (tensor<2x2xf32>,tensor<2x2xi32>) -> tensor<2x1xf32>
%y0 = "stablehlo.reshape"(%y0_) : (tensor<2x1xf32>) -> tensor<2xf32>
"func.return"(%y0): (tensor<2xf32>) -> ()
}
'''
'''
func.func @main (%x0: tensor<2x2xf32>, %x1: tensor<2x1xi32>) -> (tensor<2x2xf32>)
{
%y0 = "stablehlo.gather"(%x0, %x1) {
dimension_numbers = #stablehlo.gather<
offset_dims = [1],
collapsed_slice_dims = [0],
start_index_map = [0],
index_vector_dim = 1>,
slice_sizes = dense<[1, 2]> : tensor<2xi64>,
indices_are_sorted = false
} : (tensor<2x2xf32>,tensor<2x1xi32>) -> tensor<2x2xf32>
"func.return"(%y0): (tensor<2x2xf32>) -> ()
}
'''

'''
func.func @main (%x0: tensor<2x2x2xf32>, %x1: tensor<2x2xi32>) -> (tensor<2x2xf32>)
{
%y0 = "stablehlo.gather"(%x0, %x1) {
dimension_numbers = #stablehlo.gather<
offset_dims = [1],
collapsed_slice_dims = [0, 1],
start_index_map = [0, 1],
index_vector_dim = 1>,
slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>,
indices_are_sorted = false
} : (tensor<2x2x2xf32>,tensor<2x2xi32>) -> tensor<2x2xf32>
"func.return"(%y0): (tensor<2x2xf32>) -> ()
}
'''


'''
<stdin>:3:11: error: start_index_map size (2)
is not equal to size of index dimension (1) of start_indices (1)
func.func @main (%x0: tensor<2x2x2xf32>, %x1: tensor<2x1x2xi32>) -> (tensor<2x2x2xf32>)
{
%y0 = "stablehlo.gather"(%x0, %x1) {
dimension_numbers = #stablehlo.gather<
offset_dims = [1, 2],
collapsed_slice_dims = [0, 1],
start_index_map = [0, 1],
index_vector_dim = 1>,
slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>,
indices_are_sorted = false
} : (tensor<2x2x2xf32>,tensor<2x1x2xi32>) -> tensor<2x2x2xf32>
"func.return"(%y0): (tensor<2x2x2xf32>) -> ()
}
'''


###############

# x = slope.arange(10, dtype=slope.float32).reshape(2,5)
# w = slope.tensor([1,0])[..., None]
# w = w.cast(slope.int64)
# y = x.gather_nd(w)
# y = x.gather(w)
# print(f"{x=}")
# print(f"{w=}")
# print(f"{y=}")
Expand All @@ -131,19 +49,19 @@
# w = x
# print(f"{x=}")
# print(f"{w=}")
# y = x.gather_nd(w)
# y = x.gather(w)
# breakpoint()
# print(f"{y=}")

#######################

# x = slope.ones(8)
# print(f"before: {x=}")
# w = slope.tensor([[4], [3], [1], [7]], dtype=slope.int32)
# u = slope.tensor([9., 10., 11., 12.])
# y = slope.scatter_nd(x,w,u)
x = slope.ones(8)
print(f"before: {x=}")
w = slope.tensor([[4], [3], [1], [7]], dtype=slope.int32)
u = slope.tensor([9., 10., 11., 12.])
y = slope.scatter(x,w,u)

# print(f"{w=}")
# print(f"{u=}")
# print(f"{x=}")
# print(f"{y=}")
print(f"{w=}")
print(f"{u=}")
print(f"{x=}")
print(f"{y=}")
45 changes: 28 additions & 17 deletions src/slope/backends/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,27 +643,34 @@ def gather_impl(self, x, w, y, *, axis):
index_vector_dim = q-1
y_pre = None
if indices_shape[-1] == r:
# Each scalar value corresponding to data[0:b-1,indices_slice]
slice_sizes = [1]*r
start_index_map = list(range(q))
collapsed_slice_dims = [i for i in range(b+1,len(slice_sizes))
start_index_map = [i for i in range(q)]
if axis != 0:
start_index_map = start_index_map[axis:] + start_index_map[:axis]
collapsed_slice_dims = [i for i in range(1-axis,len(slice_sizes)-axis)
if slice_sizes[i] == 1 and operand_shape[i] == indices_shape[i]]

y_pre = SymbolicTensor(y.symval.shape + (1,), y.symval.dtype, y.symval.device)

elif indices_shape[-1] < r:
# Each tensor slice corresponding to data[0:b-1, indices_slice , :]

slice_sizes = [*[1]*(r-1), *operand_shape[-1:]]
start_index_map = [i for i, s in enumerate(slice_sizes) if s==1 and i < q-b]
start_index_map = [i+axis for i, s in enumerate(slice_sizes) if s==1 and i < q]

collapsed_slice_dims = []
for i in range(len(slice_sizes)):
if slice_sizes[i] == 1 and len(offset_dims)+len(collapsed_slice_dims) != r:
collapsed_slice_dims += [i]

if (len(collapsed_slice_dims) != len(start_index_map)) and b==0:
if (len(collapsed_slice_dims) != len(start_index_map)):
y_pre = SymbolicTensor(y.symval.shape + (1,), y.symval.dtype, y.symval.device)

# offset_dims = [1]
# index_vector_dim = 1
# start_index_map = [0]
# collapsed_slice_dims = [0,1]
# slice_sizes = [1, 1, 2]
# [[2,3],[4,5]]
# y_pre = SymbolicTensor((2,2,1), y.symval.dtype, y.symval.device)
# y_pre = SymbolicTensor((1,2,2), y.symval.dtype, y.symval.device)


else:
Expand All @@ -685,20 +692,24 @@ def gather_impl(self, x, w, y, *, axis):



@backend.set_impl(backend.operator_set.scatter_nd)
def scatter_nd_impl(self, x, w, u, y, *, batch_dims):
@backend.set_impl(backend.operator_set.scatter)
def scatter_impl(self, x, w, u, y, *, axis):
y_init_type = SymbolicTensor((), y.symval.dtype, y.symval.device)
y_mlir_type = as_mlir_shape(y_init_type)

lim = (
batch_dims + (len(w.symval.shape[batch_dims + 1 :])) - len(x.symval.shape[: batch_dims + 1])
(len(w.symval.shape[1 :])) - len(x.symval.shape[:+ 1])
)
lim = None if lim == 0 else lim
update_window_dims = list(x.symval.shape[(batch_dims + 1) : lim])
inserted_window_dims = [0]
scatter_dims_to_operand_dims = [0]
one = 1.0 if "f" in x.symval.dtype.mlir else 1
update_window_dims = list(x.symval.shape[1 : lim])
inserted_window_dims = [0+axis]
scatter_dims_to_operand_dims = [0+axis]

index_vector_dim = w.symval.ndim-1

# TODO: Find cheaper way to copy if exists
return f"""%{x.name}_1 = "stablehlo.constant"(){{ value = dense<{one}> : {as_mlir_shape(x.symval)} }} {as_mlir_sig((), x.symval)}
return f"""%{x.name}_1 = "stablehlo.constant"(){{
value = dense<{1.0 if "f" in x.symval.dtype.mlir else 1}> : {as_mlir_shape(x.symval)} }} {as_mlir_sig((), x.symval)}
%{x.name}_ = "stablehlo.multiply"(%{x.name}, %{x.name}_1) {as_mlir_sig((x.symval, x.symval), x.symval)}
%{y.name} = "stablehlo.scatter"(%{x.name}_, %{w.name}, %{u.name}) ({{
^bb0(%arg0: {y_mlir_type}, %arg1: {y_mlir_type}):
Expand All @@ -709,7 +720,7 @@ def scatter_nd_impl(self, x, w, u, y, *, batch_dims):
update_window_dims = {update_window_dims},
inserted_window_dims = {inserted_window_dims},
scatter_dims_to_operand_dims = {scatter_dims_to_operand_dims},
index_vector_dim = {batch_dims+1}>,
index_vector_dim = {index_vector_dim}>,
indices_are_sorted = false,
unique_indices = false
}} {as_mlir_sig((x.symval, w.symval, u.symval), y.symval)}
Expand Down
40 changes: 17 additions & 23 deletions src/slope/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ def T(self, cotangents, x, w, *, groups, stride, dilation, padding):
return [None, gL_w]



@operator_set.register("gather")
class Gather(GeneralReduceOperator):
def args_fixer(self, x, w, *, axis: int = 0):
Expand All @@ -950,20 +951,11 @@ def args_fixer(self, x, w, *, axis: int = 0):
def typecheck(self, x, w, *, axis: int):
r = x.ndim
q = w.ndim
b = axis
assert r > 0 and q > 0
assert x.shape[:b] == w.shape[:b]
assert b < min(q, r)
assert 1 <= w.shape[-1] <= r - b
assert w.shape[-1] <= r - b
# output_shape = [w.shape[b:][i] for i in range(len(w.shape[b:]) - 1)]
# output_shape.extend(x.shape[b:][w.shape[b:][-1]:])
# shape = x.shape[b:][:b] + tuple(output_shape)
bx = x.shape[b:]
bw = w.shape[b:]
shape = bx[:b] + bw[:len(bw) - 1] + bx[bw[-1]:]


assert 1 <= w.shape[-1] <= r
assert w.shape[-1] <= r
assert axis < min(x.ndim, w.ndim)
shape = w.shape[:q - 1] + x.shape[w.shape[-1]:]
return [SymbolicTensor(shape, x.dtype, x.device)]


Expand All @@ -979,21 +971,19 @@ def jvp(self, primals, tangents, *, axis):

def T(self, cotangents, x, w):
(gL_y,) = cotangents
return [x.scatter_nd(w, gL_y)
, None]
return [x.scatter(w, gL_y), None]


@operator_set.register("scatter_nd")
class ScatterND(GeneralReduceOperator):
@operator_set.register("scatter")
class Scatter(GeneralReduceOperator):
def args_fixer(self, x, w, u, *, axis: int = 0):
return (x, w, u), dict(axis=axis)

def typecheck(self, x, w, u, *, axis: int):
assert x.ndim > 0 and w.ndim > 0
assert u.ndim == w.ndim - 1
assert x.shape[:axis] == w.shape[:axis]
assert axis < min(x.ndim, w.ndim)
assert 1 <= w.shape[-1] <= x.ndim - axis
assert 1 <= w.shape[-1] <= x.ndim
return [x]

def vmap(self, dim_size, vals_in, dims_in, **params):
Expand All @@ -1006,10 +996,14 @@ def jvp(self, primals, tangents):
(x, w), (x_dot, w_dot) = primals, tangents
return [x @ w], [(x_dot @ w) + (x @ w_dot)]

def T(self, cotangents, x, w):
def jvp(self, primals, tangents, *, axis):
(x, w, u), (x_dot, w_dot, u_dot) = primals, tangents
return [self(x,w,u, axis)], [self(x_dot,w_dot, u_dot, axis)]

def T(self, cotangents, x, w, u):
assert (type(x) is UndefPrimal) ^ (type(w) is UndefPrimal) ^ (type(u) is UndefPrimal)
(gL_y,) = cotangents
assert (type(x) is UndefPrimal) ^ (type(w) is UndefPrimal)
if type(x) is UndefPrimal:
return [gL_y @ w.transpose(-1, -2), None]
return [gL_y.gather(u), None]
elif type(w) is UndefPrimal:
return [None, x.transpose(-1, -2) @ gL_y]
return [self, None]

0 comments on commit e4e2686

Please sign in to comment.