diff --git a/examples/simple/gather.py b/examples/simple/gather.py index e95f02a..3308117 100644 --- a/examples/simple/gather.py +++ b/examples/simple/gather.py @@ -26,22 +26,22 @@ # print(f"{y=}") -print('\n#4') -x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32) -w = slope.tensor([[[0,1]],[[1,0]]], dtype=slope.int32) -print(f"{x=}") -print(f"{w=}") -y = x.gather_nd(w) -print(f"{y=}") - -# print('\n#5') +# print('\n#4') # x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32) -# w = slope.tensor([[1],[0]], dtype=slope.int32) +# w = slope.tensor([[[0,1]],[[1,0]]], dtype=slope.int32) # print(f"{x=}") # print(f"{w=}") -# y = x.gather_nd(w, 1) +# y = x.gather_nd(w) # print(f"{y=}") +print('\n#5') +x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32) +w = slope.tensor([[1],[0]], dtype=slope.int32) +print(f"{x=}") +print(f"{w=}") +y = x.gather_nd(w, 1) +print(f"{y=}") + ''' func.func @main (%x0: tensor<2x2xf32>, %x1: tensor<2x2xi32>) -> (tensor<2xf32>) diff --git a/src/slope/backends/iree.py b/src/slope/backends/iree.py index f8685f2..83a15c2 100644 --- a/src/slope/backends/iree.py +++ b/src/slope/backends/iree.py @@ -633,23 +633,16 @@ def conv_impl(self, x, w, y, *, groups, stride, dilation, padding): } ''' -@backend.set_impl(backend.operator_set.gather_nd) -def gather_nd_impl(self, x, w, y, *, batch_dims): +@backend.set_impl(backend.operator_set.gather) +def gather_impl(self, x, w, y, *, axis): operand_shape = list(x.symval.shape) indices_shape = list(w.symval.shape) r = x.symval.ndim q = w.symval.ndim - b = batch_dims - if b == 0: - offset_dims = list(range(1, q)) - else: - offset_dims = list(range(b+1, q)) - # index_vector_dim = b+1 + offset_dims = list(range(1, q)) index_vector_dim = q-1 - # N*(q-b-1) dim tensors y_pre = None - # shape = (indices_shape[0:q-1]) + operand_shape[indices_shape[-1]:] - if indices_shape[-1] == r - b: + if indices_shape[-1] == r: # Each scalar value corresponding to data[0:b-1,indices_slice] slice_sizes = [1]*r start_index_map = list(range(q)) @@ -658,21 +651,21 @@ def gather_nd_impl(self, x, w, y, *, batch_dims): y_pre = SymbolicTensor(y.symval.shape + (1,), y.symval.dtype, y.symval.device) - elif indices_shape[-1] < r - b: + elif indices_shape[-1] < r: # Each tensor slice corresponding to data[0:b-1, indices_slice , :] + slice_sizes = [*[1]*(r-1), *operand_shape[-1:]] start_index_map = [i for i, s in enumerate(slice_sizes) if s==1 and i < q-b] collapsed_slice_dims = [] - # do_squeeze = False for i in range(len(slice_sizes)): if slice_sizes[i] == 1 and len(offset_dims)+len(collapsed_slice_dims) != r: collapsed_slice_dims += [i] - # collapsed_slice_dims = [i for i, s in enumerate(slice_sizes) if s==1] - + if (len(collapsed_slice_dims) != len(start_index_map)) and b==0: y_pre = SymbolicTensor(y.symval.shape + (1,), y.symval.dtype, y.symval.device) + else: raise ValueError return f"""%{y.name}{'' if y_pre is None else '_'} = "stablehlo.gather"(%{x.name}, %{w.name}) {{ diff --git a/src/slope/operators.py b/src/slope/operators.py index 5b5475b..1bb2ca2 100644 --- a/src/slope/operators.py +++ b/src/slope/operators.py @@ -942,61 +942,27 @@ def T(self, cotangents, x, w, *, groups, stride, dilation, padding): return [None, gL_w] -def gather(x, idx, dim: int): - assert idx.ndim == x.ndim, "x.ndim must equal idx.ndim" - assert all( - s >= i for s, i in zip(x.shape, idx.shape) - ), "all dim of idx.shape must be smaller than x.shape" - if dim < 0: - dim += x.ndim - idx = idx.transpose(ax=dim, aw=0).unsqueeze(-1) - permarg = list(range(x.ndim)) - permarg = ( - permarg[1:dim] + [permarg[0]] + permarg[dim + 1 :] + [permarg[dim]] - if dim != 0 - else permarg[1:] + [permarg[0]] - ) - return ( - ( - ( - idx - == slope.arange( - x.shape[dim], - dtype=slope.int32, - device=x.device, - ) - ) - * x.permute(*permarg) - .padslice(tuple(*[(0, sh) for sh in idx.shape[1:-1]], (0, x.shape[dim]))) - .unsqueeze(0) - ) - .sum(-1) - .transpose(0, dim) - ) - -@operator_set.register("gather_nd") -class GatherND(GeneralReduceOperator): - def args_fixer(self, x, w, *, batch_dims: int = 0): - return (x, w), dict(batch_dims=batch_dims) +@operator_set.register("gather") +class Gather(GeneralReduceOperator): + def args_fixer(self, x, w, *, axis: int = 0): + return (x, w), dict(axis=axis) - def typecheck(self, x, w, *, batch_dims: int): + def typecheck(self, x, w, *, axis: int): r = x.ndim q = w.ndim - b = batch_dims + b = axis assert r > 0 and q > 0 assert x.shape[:b] == w.shape[:b] assert b < min(q, r) assert 1 <= w.shape[-1] <= r - b assert w.shape[-1] <= r - b - # Step 1: Remove batch dimensions - data_shape = x.shape[b:] - indices_shape = w.shape[b:] - # Step 2: Calculate the dimensions of the output tensor - output_shape = [indices_shape[i] for i in range(len(indices_shape) - 1)] - # Step 3: Append the remaining dimensions from the data tensor - output_shape.extend(data_shape[indices_shape[-1]:]) - # Step 4: Prepend batch dimensions to the output shape - shape = data_shape[:batch_dims] + tuple(output_shape) + # output_shape = [w.shape[b:][i] for i in range(len(w.shape[b:]) - 1)] + # output_shape.extend(x.shape[b:][w.shape[b:][-1]:]) + # shape = x.shape[b:][:b] + tuple(output_shape) + bx = x.shape[b:] + bw = w.shape[b:] + shape = bx[:b] + bw[:len(bw) - 1] + bx[bw[-1]:] + return [SymbolicTensor(shape, x.dtype, x.device)] @@ -1007,9 +973,9 @@ def vmap(self, dim_size, vals_in, dims_in, **params): w = slope.core.VMapTrace.move_vmap_dim(w, dim_size, w_bdim, 0) return [self(x, w, **params)], [x_bdim, w_bdim] - def jvp(self, primals, tangents, *, batch_dims): + def jvp(self, primals, tangents, *, axis): (x, w), (x_dot, w_dot) = primals, tangents - return [self(x,w, batch_dims)], [self(x_dot,w, batch_dims)] + return [self(x,w, axis)], [self(x_dot,w, axis)] def T(self, cotangents, x, w): (gL_y,) = cotangents @@ -1019,15 +985,15 @@ def T(self, cotangents, x, w): @operator_set.register("scatter_nd") class ScatterND(GeneralReduceOperator): - def args_fixer(self, x, w, u, *, batch_dims: int = 0): - return (x, w, u), dict(batch_dims=batch_dims) + def args_fixer(self, x, w, u, *, axis: int = 0): + return (x, w, u), dict(axis=axis) - def typecheck(self, x, w, u, *, batch_dims: int): + def typecheck(self, x, w, u, *, axis: int): assert x.ndim > 0 and w.ndim > 0 assert u.ndim == w.ndim - 1 - assert x.shape[:batch_dims] == w.shape[:batch_dims] - assert batch_dims < min(x.ndim, w.ndim) - assert 1 <= w.shape[-1] <= x.ndim - batch_dims + assert x.shape[:axis] == w.shape[:axis] + assert axis < min(x.ndim, w.ndim) + assert 1 <= w.shape[-1] <= x.ndim - axis return [x] def vmap(self, dim_size, vals_in, dims_in, **params): diff --git a/src/slope/procedures.py b/src/slope/procedures.py index bb0c1b2..128011d 100644 --- a/src/slope/procedures.py +++ b/src/slope/procedures.py @@ -846,21 +846,21 @@ def log_softmax(x, dim=-1): logsumexp_x = x.exp().sum(dim, keepdim=True).log() return x - logsumexp_x -@procedure_set.register() -def gather_nd( - params, - indices, - batch_dims=0): - def _gather_nd_single(params, indices): - idx = indices.moveaxis(-1, 0) - return params[idx] - assert batch_dims > 0, ('Negative `batch_dims` is currently unsupported.') - assert batch_dims == 0 - gather_nd_ = functools.reduce( - lambda g, f: f(g), [slope.vmap] * int(batch_dims), - _gather_nd_single - ) if batch_dims > 0 else _gather_nd_single - return gather_nd_(params, indices) +# @procedure_set.register() +# def gather_nd( +# params, +# indices, +# batch_dims=0): +# def _gather_nd_single(params, indices): +# idx = indices.moveaxis(-1, 0) +# return params[idx] +# assert batch_dims > 0, ('Negative `batch_dims` is currently unsupported.') +# assert batch_dims == 0 +# gather_nd_ = functools.reduce( +# lambda g, f: f(g), [slope.vmap] * int(batch_dims), +# _gather_nd_single +# ) if batch_dims > 0 else _gather_nd_single +# return gather_nd_(params, indices) # TODO: