From dcb571aa0fa47bd4c97a4455e38c57bc0126d43e Mon Sep 17 00:00:00 2001 From: radenmuaz Date: Fri, 26 Jan 2024 01:03:10 +0800 Subject: [PATCH] comparison boolean --- examples/simple/comparison.py | 20 +++++ src/slope/backends/iree.py | 17 ++++ src/slope/core.py | 10 ++- src/slope/operators.py | 55 ++---------- src/slope/procedures.py | 162 +++------------------------------- 5 files changed, 65 insertions(+), 199 deletions(-) create mode 100644 examples/simple/comparison.py diff --git a/examples/simple/comparison.py b/examples/simple/comparison.py new file mode 100644 index 0000000..94304d6 --- /dev/null +++ b/examples/simple/comparison.py @@ -0,0 +1,20 @@ +import slope + +x = slope.ones(5) +w = slope.arange(5, dtype=x.dtype) +y_eq = (x == w).int() +y_ne = (x != w).int() +y_lt = (x < w).int() +y_le = (x <= w).int() +y_ge = (x >= w).int() +y_gt = (x > w).int() + + +print(f"{x=}") +print(f"{w=}") +print(f"x == w = {y_eq}") +print(f"x != w = {y_ne}") +print(f"x < w = {y_lt}") +print(f"x <= w = {y_le}") +print(f"x > w = {y_gt}") +print(f"x >= w = {y_ge}") diff --git a/src/slope/backends/iree.py b/src/slope/backends/iree.py index 58b1472..55639d0 100644 --- a/src/slope/backends/iree.py +++ b/src/slope/backends/iree.py @@ -373,6 +373,23 @@ def equal_impl(self, x, w, y): """ +@backend.set_impl(backend.operator_set.less) +def less_impl(self, x, w, y): + return f"""%{y.name} = "stablehlo.compare"(%{x.name}, %{w.name}) {{ + comparison_direction = #stablehlo, + compare_type = #stablehlo +}} {as_mlir_sig((x.symval, w.symval), y.symval)} +""" + +@backend.set_impl(backend.operator_set.greater) +def greater_impl(self, x, w, y): + return f"""%{y.name} = "stablehlo.compare"(%{x.name}, %{w.name}) {{ + comparison_direction = #stablehlo, + compare_type = #stablehlo +}} {as_mlir_sig((x.symval, w.symval), y.symval)} +""" + + @backend.set_impl(backend.operator_set.maximum) def maximum_impl(self, x, w, y): return f'%{y.name} = "stablehlo.maximum"(%{x.name}, %{w.name}) {as_mlir_sig((x.symval, w.symval), y.symval)}' diff --git a/src/slope/core.py b/src/slope/core.py index 1715236..96142e5 100644 --- a/src/slope/core.py +++ b/src/slope/core.py @@ -540,6 +540,7 @@ def jvp(self, primals, tangents, **params): class BinaryOperator(Operator): + boolean_output = False def args_fixer(self, x, w, **params): if type(x) is UndefPrimal or type(w) is UndefPrimal: assert x.shape == w.shape @@ -584,8 +585,8 @@ def typecheck(self, x: SymbolicTensor, y: SymbolicTensor, **params) -> List[Symb SymbolicTensor, ): raise TypeError - symx = SymbolicTensor.like(x) - symy = SymbolicTensor.like(y) + symx = SymbolicTensor.like(x, dtype=dtypes.bool if self.boolean_output else x.dtype) + symy = SymbolicTensor.like(y, dtype=dtypes.bool if self.boolean_output else y.dtype) if x.dtype != y.dtype: raise TypeError if symx == symy: @@ -612,6 +613,11 @@ def jvp(self, primals, tangents, **params): (x,), (x_dot,) = primals, tangents return [self(x, **params)], [self(x_dot, **params)] + def T(self, cotangents, x, w): + (gL_y,) = cotangents + if self.boolean_output: + gL_y = gL_y.cast(x.dtype) + return [gL_y, None] class ReduceOperator(Operator): def args_fixer(self, x, *, dim=None, keepdim=False): diff --git a/src/slope/operators.py b/src/slope/operators.py index dac85c1..a16e98c 100644 --- a/src/slope/operators.py +++ b/src/slope/operators.py @@ -224,54 +224,17 @@ def T(self, cotangents, x, w): @operator_set.register("equal") class Equal(BinaryOperator): - def typecheck(self, x: SymbolicTensor, y: SymbolicTensor, **params) -> List[SymbolicTensor]: - # difference with default binary typecheck: force dtype bool - if not type(x) in (Tensor, SymbolicTensor) or not type(x) in ( - Tensor, - SymbolicTensor, - ): - raise TypeError - if x.dtype != y.dtype: - raise TypeError - symx = SymbolicTensor.like(x) - symy = SymbolicTensor.like(y) - if symx == symy: - return [SymbolicTensor(symx.shape, dtypes.bool, x.device)] - shape_delta = len(symx.shape) - len(symy.shape) - if shape_delta > 0: - symy = SymbolicTensor( - (1,) * shape_delta + symy.shape, - dtypes.bool, - x.device, - ) - elif shape_delta < 0: - x = x.reshape((1,) * -shape_delta + symx.shape) - symx = SymbolicTensor( - (1,) * -shape_delta + symx.shape, - dtypes.bool, - x.device, - ) - if symx == symy: - return [symx] - else: - shape_ret = tuple([max(x, w) for x, w in zip(symx.shape, symy.shape)]) - if symx.shape != shape_ret: - symx = SymbolicTensor(shape_ret, dtypes.bool, x.device) - if symy.shape != shape_ret: - symy = SymbolicTensor(shape_ret, dtypes.bool, x.device) - if symx != symy: - raise TypeError - return [symx] + boolean_output = True + +@operator_set.register("less") +class Less(BinaryOperator): + boolean_output = True + +@operator_set.register("greater") +class Greater(BinaryOperator): + boolean_output = True - def jvp(self, primals, tangents): - (x, w), _ = primals, tangents - out_primal = x.equal(w) - return [out_primal], [slope.full(out_primal.shape, True, dtypes.bool, x.device)] - def T(self, cotangents, x, w): - (gL_y,) = cotangents - gL_y = gL_y.cast(x.dtype) - return [gL_y, None] @operator_set.register("max") diff --git a/src/slope/procedures.py b/src/slope/procedures.py index 6d7ebc6..89c1f32 100644 --- a/src/slope/procedures.py +++ b/src/slope/procedures.py @@ -69,13 +69,13 @@ def eye(dim: int, **kwargs): @procedure_set.register() def where(x, trueval, falseval): - cond = x != 0.0 + cond = x != x.zeros_like() if not isinstance(trueval, Tensor): trueval = slope.full((), trueval) if not isinstance(falseval, Tensor): falseval = slope.full((), falseval) cond = cond.cast(trueval.dtype) - return cond * trueval + (1.0 - cond) * falseval + return cond * trueval + (x.ones_like() - cond) * falseval @procedure_set.register() @@ -86,12 +86,12 @@ def mean(x, dim=None, keepdim=False): @procedure_set.register() def rsqrt(x): - return (slope.ones_like(x) / x).sqrt() + return (x.ones_like() / x).sqrt() @procedure_set.register() def cos(x): - return ((math.pi / 2) - x).sin() + return (x.full_like(math.pi / 2) - x).sin() @procedure_set.register() @@ -101,33 +101,20 @@ def tan(x): @procedure_set.register() def neg(x): - return full_like(x, -1) * x + return x.full_like(-1) * x @procedure_set.register() def not_equal(x, w): - return ~(x.equal(w)) - + return ~(x == w) @procedure_set.register() def greater_equal(x, w): - return x.maximum(w).equal(w) - + return ~(x < w) @procedure_set.register() def less_equal(x, w): - return x.minimum(w).equal(w) - - -@procedure_set.register() -def greater(x, w): - return 1.0 - (x <= w) - - -@procedure_set.register() -def less(x, w): - return 1.0 - (x >= w) - + return ~(x > w) @procedure_set.register() def minimum(x, w): @@ -228,131 +215,6 @@ def matmul(x, w): w = w.reshape((*w.shape[0:-2], 1, w.shape[-2], w.shape[-1])).transpose(-1, -2) return (x * w).sum(-1).reshape((*x.shape[0:-2], -1)) - -@procedure_set.register() -def old_getitem(x, val): - # Union[int, slice, Tensor, None, Ellipsis, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]] - def normalize_int(e, i, dim_sz): - if -dim_sz <= e < dim_sz: - return e if e != -1 else dim_sz - 1 - raise IndexError(f"index {e} is out of bounds for dimension {i} with size {x.shape[i]}") - - orig_slices = list(val) if isinstance(val, tuple) else [val] - count = defaultdict(list) - for i, v in enumerate(orig_slices): - count[type(v) if not isinstance(v, slope.core.Tensor) else "tensor"] += [i] - - if (num_slices := len(count[int]) + len(count[slice]) + len(count["tensor"])) > len(x.shape): - raise IndexError(f"too many indices for tensor of dimension {len(x.shape)}") - if len(ellipsis_found := count[type(Ellipsis)]) > 1: - raise IndexError("an index can only have a single ellipsis ('...')") - - ellipsis_idx = ellipsis_found[0] if ellipsis_found else len(orig_slices) - orig_slices[ellipsis_idx : ellipsis_idx + 1] = [slice(None)] * (len(x.shape) - num_slices) - - valid_slices = [v for v in orig_slices if v is not None] - valid_slices = [ - v - if isinstance(v, slice) - else slice(y_ := normalize_int(v, i, dim_sz), y_ + 1) - if isinstance(v, int) - else slice(None) - for i, (v, dim_sz) in enumerate(zip(valid_slices, x.shape)) - ] - - start, stop, strides = ( - zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, x.shape)]) else ((), (), ()) - ) - new_slice = tuple((s, e) if st > 0 else (e + 1, s + 1) for s, e, st in zip(start, stop, strides)) - sliced_tensor = x.padslice(new_slice).flip(dim=tuple([i for i, s in enumerate(strides) if s < 0])) - new_shape = sliced_tensor.shape - if any(abs_py(s) != 1 for s in strides): - strides = tuple(abs_py(s) for s in strides) - # Pad: add pad at the end: [dim_sz] -> [dim_sz_padded] - padded_tensor = sliced_tensor.pad( - tuple( - (0, s - (dim_sz % s) if dim_sz % s != 0 else 0) for s, dim_sz in zip(strides, sliced_tensor.shape)[::-1] - ) - ) - # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s] - reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides))) - new_shape = reshaped_tensor.shape[::2] - # Shrink: do [:, 0] - sliced_tensor = reshaped_tensor.padslice(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape))) - - final_shape, it_shape, dim, tensors, dim_collapsed = ( - [], - iter(new_shape), - [], - [], - 0, - ) - for i, s in enumerate(orig_slices): - if s is None: - final_shape.append(1) - else: # s is int or slice or Tensor - dim_shape = next(it_shape) - if isinstance(s, int): - dim_collapsed += 1 - else: - final_shape.append(dim_shape) - if isinstance(s, slope.core.Tensor): - tensors.append(s) - dim.append(i - dim_collapsed) - ret = sliced_tensor.reshape(tuple(final_shape)) - - if tensors: # Fancy/tensor indexing - # normalize idx - idx = [t.sign().neg().relu() * ret.shape[d] + t for d, t in zip(dim, tensors)] - max_dim = max(i.ndim for i in idx) - # compute sum_dim, arange, and idx - sum_dim = [d if n == 0 else d + max_dim - n for n, d in enumerate(dim)] - slice_arange = [ - slope.arange( - ret.shape[d], - dtype=slope.int32, - requires_grad=False, - device=x.device, - ).reshape( - *[1] * sd, - ret.shape[d], - *[1] * (ret.ndim + max_dim - n - sd - 1), - ) - for n, (sd, d) in enumerate(zip(sum_dim, dim)) - ] - first_idx = [ - idx[0].reshape( - *[1] * dim[0], - *[1] * (1 + max_dim - idx[0].ndim), - *idx[0].shape, - *[1] * (ret.ndim - dim[0] - 1), - ) - ] - rest_idx = [ - i.reshape( - *[1] * dim[0], - *[1] * (max_dim - i.ndim), - *i.shape, - *[1] * (ret.ndim - dim[0] - n), - ) - for n, i in enumerate(idx[1:], 1) - ] - idx = first_idx + rest_idx - ret = ret.reshape( - *ret.shape[: sum_dim[0] + 1], - *[1] * max_dim, - *ret.shape[sum_dim[0] + 1 :], - ) - # iteratively fancy index - for a, i, sd in zip(slice_arange, idx, sum_dim): - ret = (a == i).mul(ret).sum(sd) - # special permute case - if dim[0] != 0 and len(dim) != 1 and dim != list(range(dim[0], dim[-1] + 1)): - ret_dims = list(range(ret.ndim)) - ret = ret.permute(ret_dims[dim[0] : dim[0] + max_dim] + ret_dims[: dim[0]] + ret_dims[dim[0] + max_dim :]) - return ret - - @procedure_set.register() def getitem(x, indices) -> Tensor: # 1. indices normalization and validation @@ -458,17 +320,15 @@ def calc_dim(tensor_dim: int) -> int: reshaped_idx = [ i.reshape(i.shape + (1,) * (ret.ndim - first_dim - (n or 1))) for n, i in enumerate(idx.values()) ] + ret_ = ret ret = ret.reshape(ret.shape[: first_dim + 1] + (1,) * max_idx_dim + ret.shape[first_dim + 1 :]) # iteratively eq -> mul -> sum fancy index for i, sd in zip(reshaped_idx, sum_dim): breakpoint() ret = ret.gather(i) - # try: - # # for a, i, sd in zip(arange, reshaped_idx, sum_dim): - # # ret = (a == i).mul(ret).sum(sd) - # except AssertionError as exc: - # raise IndexError("cannot broadcast indices") from exc + # # for a, i, sd in zip(arange, reshaped_idx, sum_dim): + # # ret = (a == i).mul(ret).sum(sd) # special permute case if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim + 1)):