Commit c31d67d
gather_nd changed to gather
radenmuaz committed Jan 19, 2024
1 parent a5f9e50 commit c31d67d
Showing 4 changed files with 55 additions and 96 deletions.
examples/simple/gather.py (22 changes: 11 additions & 11 deletions)
@@ -26,22 +26,22 @@
 # print(f"{y=}")


-print('\n#4')
-x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
-w = slope.tensor([[[0,1]],[[1,0]]], dtype=slope.int32)
-print(f"{x=}")
-print(f"{w=}")
-y = x.gather_nd(w)
-print(f"{y=}")
-
-# print('\n#5')
+# print('\n#4')
 # x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
-# w = slope.tensor([[1],[0]], dtype=slope.int32)
+# w = slope.tensor([[[0,1]],[[1,0]]], dtype=slope.int32)
 # print(f"{x=}")
 # print(f"{w=}")
-# y = x.gather_nd(w, 1)
+# y = x.gather_nd(w)
 # print(f"{y=}")

+print('\n#5')
+x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
+w = slope.tensor([[1],[0]], dtype=slope.int32)
+print(f"{x=}")
+print(f"{w=}")
+y = x.gather_nd(w, 1)
+print(f"{y=}")

 '''
 func.func @main (%x0: tensor<2x2xf32>, %x1: tensor<2x2xi32>) -> (tensor<2xf32>)
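For reference: example #4 exercises batch_dims = 0 (each length-2 index tuple addresses a rank-3 tensor), while example #5 exercises batch_dims = 1. Below is a minimal NumPy sketch of the ONNX-style GatherND semantics these examples assume; gather_nd_ref is an illustrative name, not slope API.

import numpy as np

def gather_nd_ref(data, indices, batch_dims=0):
    if batch_dims == 0:
        # Each vector along the last axis of `indices` addresses one slice of `data`.
        return data[tuple(np.moveaxis(indices, -1, 0))]
    # Peel one batch axis and recurse: output[i] = gather_nd(data[i], indices[i]).
    return np.stack([gather_nd_ref(d, i, batch_dims - 1) for d, i in zip(data, indices)])

x = np.arange(8, dtype=np.float32).reshape(2, 2, 2)
print(gather_nd_ref(x, np.array([[[0, 1]], [[1, 0]]])))      # example #4 -> [[[2. 3.]] [[4. 5.]]]
print(gather_nd_ref(x, np.array([[1], [0]]), batch_dims=1))  # example #5 -> [[2. 3.] [4. 5.]]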
src/slope/backends/iree.py (23 changes: 8 additions & 15 deletions)
@@ -633,23 +633,16 @@ def conv_impl(self, x, w, y, *, groups, stride, dilation, padding):
     }
     '''

-@backend.set_impl(backend.operator_set.gather_nd)
-def gather_nd_impl(self, x, w, y, *, batch_dims):
+@backend.set_impl(backend.operator_set.gather)
+def gather_impl(self, x, w, y, *, axis):
     operand_shape = list(x.symval.shape)
     indices_shape = list(w.symval.shape)
     r = x.symval.ndim
     q = w.symval.ndim
-    b = batch_dims
-    if b == 0:
-        offset_dims = list(range(1, q))
-    else:
-        offset_dims = list(range(b+1, q))
-    # index_vector_dim = b+1
+    offset_dims = list(range(1, q))
     index_vector_dim = q-1
-    # N*(q-b-1) dim tensors
     y_pre = None
-    # shape = (indices_shape[0:q-1]) + operand_shape[indices_shape[-1]:]
-    if indices_shape[-1] == r - b:
+    if indices_shape[-1] == r:
         # Each scalar value corresponding to data[0:b-1,indices_slice]
         slice_sizes = [1]*r
         start_index_map = list(range(q))
@@ -658,21 +651,21 @@

         y_pre = SymbolicTensor(y.symval.shape + (1,), y.symval.dtype, y.symval.device)

-    elif indices_shape[-1] < r - b:
+    elif indices_shape[-1] < r:
         # Each tensor slice corresponding to data[0:b-1, indices_slice , :]

         slice_sizes = [*[1]*(r-1), *operand_shape[-1:]]
         start_index_map = [i for i, s in enumerate(slice_sizes) if s==1 and i < q-b]

         collapsed_slice_dims = []
-        # do_squeeze = False
         for i in range(len(slice_sizes)):
             if slice_sizes[i] == 1 and len(offset_dims)+len(collapsed_slice_dims) != r:
                 collapsed_slice_dims += [i]
-        # collapsed_slice_dims = [i for i, s in enumerate(slice_sizes) if s==1]


         if (len(collapsed_slice_dims) != len(start_index_map)) and b==0:
             y_pre = SymbolicTensor(y.symval.shape + (1,), y.symval.dtype, y.symval.device)


     else:
         raise ValueError
     return f"""%{y.name}{'' if y_pre is None else '_'} = "stablehlo.gather"(%{x.name}, %{w.name}) {{
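To make the emitted gather attributes concrete, here is the indices_shape[-1] < r branch of gather_impl replayed for example #4 (x: tensor<2x2x2xf32>, w: tensor<2x1x2xi32>); treating b as 0 is an assumption, since the new signature no longer binds it.

# Replaying gather_impl's dimension-number computation for example #4.
operand_shape = [2, 2, 2]                             # x.symval.shape
r, q, b = 3, 3, 0                                     # x.ndim, w.ndim, assumed batch dims
offset_dims = list(range(1, q))                       # [1, 2]
index_vector_dim = q - 1                              # 2
slice_sizes = [*[1] * (r - 1), *operand_shape[-1:]]   # [1, 1, 2]
start_index_map = [i for i, s in enumerate(slice_sizes) if s == 1 and i < q - b]  # [0, 1]
collapsed_slice_dims = []
for i in range(len(slice_sizes)):
    if slice_sizes[i] == 1 and len(offset_dims) + len(collapsed_slice_dims) != r:
        collapsed_slice_dims += [i]                   # ends as [0]
# len(collapsed_slice_dims) != len(start_index_map), so y_pre carries a
# trailing unit dim that the emitted IR reshapes away (the `y.name + '_'` trick).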
src/slope/operators.py (76 changes: 21 additions & 55 deletions)
@@ -942,61 +942,27 @@ def T(self, cotangents, x, w, *, groups, stride, dilation, padding):
         return [None, gL_w]


-def gather(x, idx, dim: int):
-    assert idx.ndim == x.ndim, "x.ndim must equal idx.ndim"
-    assert all(
-        s >= i for s, i in zip(x.shape, idx.shape)
-    ), "all dim of idx.shape must be smaller than x.shape"
-    if dim < 0:
-        dim += x.ndim
-    idx = idx.transpose(ax=dim, aw=0).unsqueeze(-1)
-    permarg = list(range(x.ndim))
-    permarg = (
-        permarg[1:dim] + [permarg[0]] + permarg[dim + 1 :] + [permarg[dim]]
-        if dim != 0
-        else permarg[1:] + [permarg[0]]
-    )
-    return (
-        (
-            (
-                idx
-                == slope.arange(
-                    x.shape[dim],
-                    dtype=slope.int32,
-                    device=x.device,
-                )
-            )
-            * x.permute(*permarg)
-            .padslice(tuple(*[(0, sh) for sh in idx.shape[1:-1]], (0, x.shape[dim])))
-            .unsqueeze(0)
-        )
-        .sum(-1)
-        .transpose(0, dim)
-    )

-@operator_set.register("gather_nd")
-class GatherND(GeneralReduceOperator):
-    def args_fixer(self, x, w, *, batch_dims: int = 0):
-        return (x, w), dict(batch_dims=batch_dims)
+@operator_set.register("gather")
+class Gather(GeneralReduceOperator):
+    def args_fixer(self, x, w, *, axis: int = 0):
+        return (x, w), dict(axis=axis)

-    def typecheck(self, x, w, *, batch_dims: int):
+    def typecheck(self, x, w, *, axis: int):
         r = x.ndim
         q = w.ndim
-        b = batch_dims
+        b = axis
         assert r > 0 and q > 0
         assert x.shape[:b] == w.shape[:b]
         assert b < min(q, r)
-        assert 1 <= w.shape[-1] <= r - b
+        assert w.shape[-1] <= r - b
-        # Step 1: Remove batch dimensions
-        data_shape = x.shape[b:]
-        indices_shape = w.shape[b:]
-        # Step 2: Calculate the dimensions of the output tensor
-        output_shape = [indices_shape[i] for i in range(len(indices_shape) - 1)]
-        # Step 3: Append the remaining dimensions from the data tensor
-        output_shape.extend(data_shape[indices_shape[-1]:])
-        # Step 4: Prepend batch dimensions to the output shape
-        shape = data_shape[:batch_dims] + tuple(output_shape)
+        # output_shape = [w.shape[b:][i] for i in range(len(w.shape[b:]) - 1)]
+        # output_shape.extend(x.shape[b:][w.shape[b:][-1]:])
+        # shape = x.shape[b:][:b] + tuple(output_shape)
+        bx = x.shape[b:]
+        bw = w.shape[b:]
+        shape = bx[:b] + bw[:len(bw) - 1] + bx[bw[-1]:]


         return [SymbolicTensor(shape, x.dtype, x.device)]

@@ -1007,9 +973,9 @@ def vmap(self, dim_size, vals_in, dims_in, **params):
         w = slope.core.VMapTrace.move_vmap_dim(w, dim_size, w_bdim, 0)
         return [self(x, w, **params)], [x_bdim, w_bdim]

-    def jvp(self, primals, tangents, *, batch_dims):
+    def jvp(self, primals, tangents, *, axis):
         (x, w), (x_dot, w_dot) = primals, tangents
-        return [self(x,w, batch_dims)], [self(x_dot,w, batch_dims)]
+        return [self(x,w, axis)], [self(x_dot,w, axis)]

     def T(self, cotangents, x, w):
         (gL_y,) = cotangents
@@ -1019,15 +985,15 @@ def T(self, cotangents, x, w):

 @operator_set.register("scatter_nd")
 class ScatterND(GeneralReduceOperator):
-    def args_fixer(self, x, w, u, *, batch_dims: int = 0):
-        return (x, w, u), dict(batch_dims=batch_dims)
+    def args_fixer(self, x, w, u, *, axis: int = 0):
+        return (x, w, u), dict(axis=axis)

-    def typecheck(self, x, w, u, *, batch_dims: int):
+    def typecheck(self, x, w, u, *, axis: int):
         assert x.ndim > 0 and w.ndim > 0
         assert u.ndim == w.ndim - 1
-        assert x.shape[:batch_dims] == w.shape[:batch_dims]
-        assert batch_dims < min(x.ndim, w.ndim)
-        assert 1 <= w.shape[-1] <= x.ndim - batch_dims
+        assert x.shape[:axis] == w.shape[:axis]
+        assert axis < min(x.ndim, w.ndim)
+        assert 1 <= w.shape[-1] <= x.ndim - axis
         return [x]

     def vmap(self, dim_size, vals_in, dims_in, **params):
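The new typecheck folds the old step-by-step comments into one shape rule; with axis = 0 it is the familiar GatherND formula output.shape = indices.shape[:-1] + data.shape[indices.shape[-1]:]. A standalone sketch for sanity-checking it follows (gather_output_shape is an illustrative helper, not slope API). Note that bx[:b] slices a tuple that has already dropped the leading b dims; whether that is intended is not clear from the diff, though it reproduces the expected shape for example #5.

def gather_output_shape(data_shape, indices_shape, axis=0):
    # Mirror of Gather.typecheck's shape computation above.
    b = axis
    bx, bw = data_shape[b:], indices_shape[b:]
    return bx[:b] + bw[:len(bw) - 1] + bx[bw[-1]:]

assert gather_output_shape((2, 2, 2), (2, 1, 2)) == (2, 1, 2)    # example #4
assert gather_output_shape((2, 2, 2), (2, 1), axis=1) == (2, 2)  # example #5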
src/slope/procedures.py (30 changes: 15 additions & 15 deletions)
@@ -846,21 +846,21 @@ def log_softmax(x, dim=-1):
     logsumexp_x = x.exp().sum(dim, keepdim=True).log()
     return x - logsumexp_x

-@procedure_set.register()
-def gather_nd(
-    params,
-    indices,
-    batch_dims=0):
-    def _gather_nd_single(params, indices):
-        idx = indices.moveaxis(-1, 0)
-        return params[idx]
-    assert batch_dims > 0, ('Negative `batch_dims` is currently unsupported.')
-    assert batch_dims == 0
-    gather_nd_ = functools.reduce(
-        lambda g, f: f(g), [slope.vmap] * int(batch_dims),
-        _gather_nd_single
-    ) if batch_dims > 0 else _gather_nd_single
-    return gather_nd_(params, indices)
+# @procedure_set.register()
+# def gather_nd(
+#     params,
+#     indices,
+#     batch_dims=0):
+#     def _gather_nd_single(params, indices):
+#         idx = indices.moveaxis(-1, 0)
+#         return params[idx]
+#     assert batch_dims > 0, ('Negative `batch_dims` is currently unsupported.')
+#     assert batch_dims == 0
+#     gather_nd_ = functools.reduce(
+#         lambda g, f: f(g), [slope.vmap] * int(batch_dims),
+#         _gather_nd_single
+#     ) if batch_dims > 0 else _gather_nd_single
+#     return gather_nd_(params, indices)

 # TODO:
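The now-commented procedure built the batched case by applying vmap once per batch dimension through functools.reduce. Below is a runnable analogue using jax.vmap, on the assumption that slope.vmap likewise maps a function over the leading axis of every argument (JAX is used purely for illustration):

import functools
import jax
import jax.numpy as jnp

def _gather_nd_single(params, indices):
    # Index `params` with the vectors along the last axis of `indices`.
    return params[tuple(jnp.moveaxis(indices, -1, 0))]

batch_dims = 1
# vmap applied batch_dims times: fn -> vmap(fn) -> vmap(vmap(fn)) -> ...
gather_nd_batched = functools.reduce(
    lambda g, f: f(g), [jax.vmap] * batch_dims, _gather_nd_single
)

x = jnp.arange(8.0).reshape(2, 2, 2)
print(gather_nd_batched(x, jnp.array([[1], [0]])))  # [[2. 3.] [4. 5.]]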
