Skip to content

Commit

Permalink
comparison boolean
Browse files (browse the repository at this point in the history)
  • Loading branch information
radenmuaz committed Jan 25, 2024
1 parent 0481a95 commit dcb571a
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 199 deletions.
20 changes: 20 additions & 0 deletions examples/simple/comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import slope

# Operands for the demo; w is created with x's dtype so both sides of every
# comparison share a dtype.
x = slope.ones(5)
w = slope.arange(5, dtype=x.dtype)

# Apply each comparison operator once and convert the result with .int()
# before it is printed.
y_eq, y_ne = (x == w).int(), (x != w).int()
y_lt, y_le = (x < w).int(), (x <= w).int()
y_ge, y_gt = (x >= w).int(), (x > w).int()


print(f"{x=}")
print(f"{w=}")
# One line per operator, in the order ==, !=, <, <=, >, >=.
for label, result in (
    ("x == w", y_eq),
    ("x != w", y_ne),
    ("x < w", y_lt),
    ("x <= w", y_le),
    ("x > w", y_gt),
    ("x >= w", y_ge),
):
    print(f"{label} = {result}")
17 changes: 17 additions & 0 deletions src/slope/backends/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,23 @@ def equal_impl(self, x, w, y):
"""


@backend.set_impl(backend.operator_set.less)
def less_impl(self, x, w, y):
    """Return the text of a ``stablehlo.compare`` op with direction ``LT``.

    The result is bound to ``%<y.name>``, the operands are ``%<x.name>`` and
    ``%<w.name>``, and the trailing type signature is produced by
    ``as_mlir_sig`` from the three ``symval`` attributes.
    """
    # NOTE(review): compare_type is fixed to FLOAT; confirm this is intended
    # when the operands are not floating point.
    signature = as_mlir_sig((x.symval, w.symval), y.symval)
    lhs, rhs, out = x.name, w.name, y.name
    return f"""%{out} = "stablehlo.compare"(%{lhs}, %{rhs}) {{
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
}} {signature}
"""

@backend.set_impl(backend.operator_set.greater)
def greater_impl(self, x, w, y):
    """Return the text of a ``stablehlo.compare`` op with direction ``GT``.

    The result is bound to ``%<y.name>``, the operands are ``%<x.name>`` and
    ``%<w.name>``, and the trailing type signature is produced by
    ``as_mlir_sig`` from the three ``symval`` attributes.
    """
    # NOTE(review): compare_type is fixed to FLOAT; confirm this is intended
    # when the operands are not floating point.
    signature = as_mlir_sig((x.symval, w.symval), y.symval)
    lhs, rhs, out = x.name, w.name, y.name
    return f"""%{out} = "stablehlo.compare"(%{lhs}, %{rhs}) {{
comparison_direction = #stablehlo<comparison_direction GT>,
compare_type = #stablehlo<comparison_type FLOAT>
}} {signature}
"""


@backend.set_impl(backend.operator_set.maximum)
def maximum_impl(self, x, w, y):
    """Return the text of a ``stablehlo.maximum`` op.

    The result is bound to ``%<y.name>`` and the operands are ``%<x.name>``
    and ``%<w.name>``; ``as_mlir_sig`` supplies the trailing type signature.
    """
    signature = as_mlir_sig((x.symval, w.symval), y.symval)
    lhs, rhs, out = x.name, w.name, y.name
    return f'%{out} = "stablehlo.maximum"(%{lhs}, %{rhs}) {signature}'
Expand Down
10 changes: 8 additions & 2 deletions src/slope/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def jvp(self, primals, tangents, **params):


class BinaryOperator(Operator):
boolean_output = False
def args_fixer(self, x, w, **params):
if type(x) is UndefPrimal or type(w) is UndefPrimal:
assert x.shape == w.shape
Expand Down Expand Up @@ -584,8 +585,8 @@ def typecheck(self, x: SymbolicTensor, y: SymbolicTensor, **params) -> List[Symb
SymbolicTensor,
):
raise TypeError
symx = SymbolicTensor.like(x)
symy = SymbolicTensor.like(y)
symx = SymbolicTensor.like(x, dtype=dtypes.bool if self.boolean_output else x.dtype)
symy = SymbolicTensor.like(y, dtype=dtypes.bool if self.boolean_output else y.dtype)
if x.dtype != y.dtype:
raise TypeError
if symx == symy:
Expand All @@ -612,6 +613,11 @@ def jvp(self, primals, tangents, **params):
(x,), (x_dot,) = primals, tangents
return [self(x, **params)], [self(x_dot, **params)]

def T(self, cotangents, x, w):
    """Transpose rule: pass the single output cotangent to the first operand.

    Args:
        cotangents: one-element sequence holding the output cotangent.
        x: first operand; its ``dtype`` is the cast target when
            ``self.boolean_output`` is truthy.
        w: second operand (not read here).

    Returns:
        ``[cotangent_for_x, None]``; the second slot is always ``None``.
    """
    (cotangent,) = cotangents
    if not self.boolean_output:
        return [cotangent, None]
    # Operators flagged boolean_output cast the cotangent to x's dtype.
    return [cotangent.cast(x.dtype), None]

class ReduceOperator(Operator):
def args_fixer(self, x, *, dim=None, keepdim=False):
Expand Down
55 changes: 9 additions & 46 deletions src/slope/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,54 +224,17 @@ def T(self, cotangents, x, w):

@operator_set.register("equal")
class Equal(BinaryOperator):
def typecheck(self, x: SymbolicTensor, y: SymbolicTensor, **params) -> List[SymbolicTensor]:
# difference with default binary typecheck: force dtype bool
if not type(x) in (Tensor, SymbolicTensor) or not type(x) in (
Tensor,
SymbolicTensor,
):
raise TypeError
if x.dtype != y.dtype:
raise TypeError
symx = SymbolicTensor.like(x)
symy = SymbolicTensor.like(y)
if symx == symy:
return [SymbolicTensor(symx.shape, dtypes.bool, x.device)]
shape_delta = len(symx.shape) - len(symy.shape)
if shape_delta > 0:
symy = SymbolicTensor(
(1,) * shape_delta + symy.shape,
dtypes.bool,
x.device,
)
elif shape_delta < 0:
x = x.reshape((1,) * -shape_delta + symx.shape)
symx = SymbolicTensor(
(1,) * -shape_delta + symx.shape,
dtypes.bool,
x.device,
)
if symx == symy:
return [symx]
else:
shape_ret = tuple([max(x, w) for x, w in zip(symx.shape, symy.shape)])
if symx.shape != shape_ret:
symx = SymbolicTensor(shape_ret, dtypes.bool, x.device)
if symy.shape != shape_ret:
symy = SymbolicTensor(shape_ret, dtypes.bool, x.device)
if symx != symy:
raise TypeError
return [symx]
boolean_output = True

@operator_set.register("less")
class Less(BinaryOperator):
# Binary operator registered under the name "less"; it defines no methods of
# its own and only sets the boolean_output class attribute.
# NOTE(review): presumably this flag is what makes the inherited typecheck/T
# treat the result as a bool-dtype tensor; confirm against BinaryOperator.
boolean_output = True

@operator_set.register("greater")
class Greater(BinaryOperator):
boolean_output = True

def jvp(self, primals, tangents):
(x, w), _ = primals, tangents
out_primal = x.equal(w)
return [out_primal], [slope.full(out_primal.shape, True, dtypes.bool, x.device)]

def T(self, cotangents, x, w):
(gL_y,) = cotangents
gL_y = gL_y.cast(x.dtype)
return [gL_y, None]


@operator_set.register("max")
Expand Down
162 changes: 11 additions & 151 deletions src/slope/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ def eye(dim: int, **kwargs):

@procedure_set.register()
def where(x, trueval, falseval):
    """Blend ``trueval`` and ``falseval`` using ``x != 0`` as the mask.

    Args:
        x: tensor whose non-zero entries select ``trueval``.
        trueval: tensor or plain value; a non-Tensor is wrapped with
            ``slope.full((), trueval)``.
        falseval: tensor or plain value; wrapped the same way.

    Returns:
        ``mask * trueval + (1 - mask) * falseval`` where ``mask`` is the
        comparison result cast to ``trueval.dtype``.
    """
    # Compare against a zeros tensor like x instead of a Python scalar.
    cond = x != x.zeros_like()
    if not isinstance(trueval, Tensor):
        trueval = slope.full((), trueval)
    if not isinstance(falseval, Tensor):
        falseval = slope.full((), falseval)
    cond = cond.cast(trueval.dtype)
    # Take the "ones" from the cast mask so both operands of the subtraction
    # share a dtype even when x's dtype differs from trueval's.
    return cond * trueval + (cond.ones_like() - cond) * falseval


@procedure_set.register()
Expand All @@ -86,12 +86,12 @@ def mean(x, dim=None, keepdim=False):

@procedure_set.register()
def rsqrt(x):
    """Return the reciprocal square root of ``x``, computed as ``sqrt(1 / x)``.

    The numerator is a ones tensor obtained from ``x.ones_like()``.
    """
    return (x.ones_like() / x).sqrt()


@procedure_set.register()
def cos(x):
    """Return the cosine of ``x`` via the identity ``cos(x) = sin(pi/2 - x)``.

    The ``pi/2`` constant is built with ``x.full_like`` so the subtraction is
    between two tensors rather than a Python float and a tensor.
    """
    return (x.full_like(math.pi / 2) - x).sin()


@procedure_set.register()
Expand All @@ -101,33 +101,20 @@ def tan(x):

@procedure_set.register()
def neg(x):
    """Return ``-x``, computed as a ``-1`` tensor like ``x`` times ``x``."""
    # Use the tensor method; the earlier form called a bare ``full_like`` name.
    return x.full_like(-1) * x


@procedure_set.register()
def not_equal(x, w):
    """Return the inversion (``~``) of the comparison ``x == w``."""
    return ~(x == w)

@procedure_set.register()
def greater_equal(x, w):
return x.maximum(w).equal(w)

return ~(x < w)

@procedure_set.register()
def less_equal(x, w):
return x.minimum(w).equal(w)


@procedure_set.register()
def greater(x, w):
return 1.0 - (x <= w)


@procedure_set.register()
def less(x, w):
return 1.0 - (x >= w)

return ~(x > w)

@procedure_set.register()
def minimum(x, w):
Expand Down Expand Up @@ -228,131 +215,6 @@ def matmul(x, w):
w = w.reshape((*w.shape[0:-2], 1, w.shape[-2], w.shape[-1])).transpose(-1, -2)
return (x * w).sum(-1).reshape((*x.shape[0:-2], -1))


@procedure_set.register()
def old_getitem(x, val):
    """Index ``x`` with ints, slices, ``None``, ``Ellipsis`` and/or tensors.

    The index is processed in stages: the ellipsis is expanded, ints and
    slices are turned into a ``padslice`` (plus ``flip`` for negative
    strides), non-unit strides are emulated with pad -> reshape -> padslice,
    the result is reshaped to drop int-indexed dims and insert ``None`` dims,
    and finally any tensor indices are applied with an
    ``(arange == index) * ret`` followed by ``sum``.

    Args:
        x: tensor being indexed.
        val: a single index or a tuple of indices.

    Returns:
        The indexed tensor.

    Raises:
        IndexError: if an int index is out of bounds, there are more
            int/slice/tensor indices than dimensions, or more than one
            ellipsis is given.
    """
    # Union[int, slice, Tensor, None, Ellipsis, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]
    def normalize_int(e, i, dim_sz):
        # -1 is mapped to dim_sz - 1; other in-range values pass through unchanged.
        if -dim_sz <= e < dim_sz:
            return e if e != -1 else dim_sz - 1
        raise IndexError(f"index {e} is out of bounds for dimension {i} with size {x.shape[i]}")

    orig_slices = list(val) if isinstance(val, tuple) else [val]
    # Positions of each kind of index, keyed by type (tensors under "tensor").
    count = defaultdict(list)
    for i, v in enumerate(orig_slices):
        count[type(v) if not isinstance(v, slope.core.Tensor) else "tensor"] += [i]

    if (num_slices := len(count[int]) + len(count[slice]) + len(count["tensor"])) > len(x.shape):
        raise IndexError(f"too many indices for tensor of dimension {len(x.shape)}")
    if len(ellipsis_found := count[type(Ellipsis)]) > 1:
        raise IndexError("an index can only have a single ellipsis ('...')")

    # Replace the ellipsis (or the implicit one at the end) with full slices.
    ellipsis_idx = ellipsis_found[0] if ellipsis_found else len(orig_slices)
    orig_slices[ellipsis_idx : ellipsis_idx + 1] = [slice(None)] * (len(x.shape) - num_slices)

    valid_slices = [v for v in orig_slices if v is not None]
    # ints become one-element slices; anything else (tensors) becomes a full slice.
    valid_slices = [
        v
        if isinstance(v, slice)
        else slice(y_ := normalize_int(v, i, dim_sz), y_ + 1)
        if isinstance(v, int)
        else slice(None)
        for i, (v, dim_sz) in enumerate(zip(valid_slices, x.shape))
    ]

    start, stop, strides = (
        zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, x.shape)]) else ((), (), ())
    )
    new_slice = tuple((s, e) if st > 0 else (e + 1, s + 1) for s, e, st in zip(start, stop, strides))
    sliced_tensor = x.padslice(new_slice).flip(dim=tuple([i for i, s in enumerate(strides) if s < 0]))
    new_shape = sliced_tensor.shape
    if any(abs_py(s) != 1 for s in strides):
        strides = tuple(abs_py(s) for s in strides)
        # Pad: add pad at the end: [dim_sz] -> [dim_sz_padded]
        # zip() returns an iterator and cannot be subscripted, so materialise
        # the pairs before reversing them.
        # NOTE(review): the reversal suggests pad() takes per-dim padding
        # last-dim-first; confirm against pad's signature.
        padded_tensor = sliced_tensor.pad(
            tuple(
                (0, s - (dim_sz % s) if dim_sz % s != 0 else 0)
                for s, dim_sz in tuple(zip(strides, sliced_tensor.shape))[::-1]
            )
        )
        # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
        reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides)))
        new_shape = reshaped_tensor.shape[::2]
        # Shrink: do [:, 0]
        sliced_tensor = reshaped_tensor.padslice(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape)))

    final_shape, it_shape, dim, tensors, dim_collapsed = (
        [],
        iter(new_shape),
        [],
        [],
        0,
    )
    for i, s in enumerate(orig_slices):
        if s is None:
            final_shape.append(1)
        else:  # s is int or slice or Tensor
            dim_shape = next(it_shape)
            if isinstance(s, int):
                dim_collapsed += 1
            else:
                final_shape.append(dim_shape)
                if isinstance(s, slope.core.Tensor):
                    tensors.append(s)
                    dim.append(i - dim_collapsed)
    ret = sliced_tensor.reshape(tuple(final_shape))

    if tensors:  # Fancy/tensor indexing
        # normalize idx
        idx = [t.sign().neg().relu() * ret.shape[d] + t for d, t in zip(dim, tensors)]
        max_dim = max(i.ndim for i in idx)
        # compute sum_dim, arange, and idx
        sum_dim = [d if n == 0 else d + max_dim - n for n, d in enumerate(dim)]
        slice_arange = [
            slope.arange(
                ret.shape[d],
                dtype=slope.int32,
                requires_grad=False,
                device=x.device,
            ).reshape(
                *[1] * sd,
                ret.shape[d],
                *[1] * (ret.ndim + max_dim - n - sd - 1),
            )
            for n, (sd, d) in enumerate(zip(sum_dim, dim))
        ]
        first_idx = [
            idx[0].reshape(
                *[1] * dim[0],
                *[1] * (1 + max_dim - idx[0].ndim),
                *idx[0].shape,
                *[1] * (ret.ndim - dim[0] - 1),
            )
        ]
        rest_idx = [
            i.reshape(
                *[1] * dim[0],
                *[1] * (max_dim - i.ndim),
                *i.shape,
                *[1] * (ret.ndim - dim[0] - n),
            )
            for n, i in enumerate(idx[1:], 1)
        ]
        idx = first_idx + rest_idx
        ret = ret.reshape(
            *ret.shape[: sum_dim[0] + 1],
            *[1] * max_dim,
            *ret.shape[sum_dim[0] + 1 :],
        )
        # iteratively fancy index
        for a, i, sd in zip(slice_arange, idx, sum_dim):
            ret = (a == i).mul(ret).sum(sd)
        # special permute case
        if dim[0] != 0 and len(dim) != 1 and dim != list(range(dim[0], dim[-1] + 1)):
            ret_dims = list(range(ret.ndim))
            ret = ret.permute(ret_dims[dim[0] : dim[0] + max_dim] + ret_dims[: dim[0]] + ret_dims[dim[0] + max_dim :])
    return ret


@procedure_set.register()
def getitem(x, indices) -> Tensor:
# 1. indices normalization and validation
Expand Down Expand Up @@ -458,17 +320,15 @@ def calc_dim(tensor_dim: int) -> int:
reshaped_idx = [
i.reshape(i.shape + (1,) * (ret.ndim - first_dim - (n or 1))) for n, i in enumerate(idx.values())
]
ret_ = ret
ret = ret.reshape(ret.shape[: first_dim + 1] + (1,) * max_idx_dim + ret.shape[first_dim + 1 :])

# iteratively eq -> mul -> sum fancy index
for i, sd in zip(reshaped_idx, sum_dim):
breakpoint()
ret = ret.gather(i)
# try:
# # for a, i, sd in zip(arange, reshaped_idx, sum_dim):
# # ret = (a == i).mul(ret).sum(sd)
# except AssertionError as exc:
# raise IndexError("cannot broadcast indices") from exc
# # for a, i, sd in zip(arange, reshaped_idx, sum_dim):
# # ret = (a == i).mul(ret).sum(sd)

# special permute case
if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim + 1)):
Expand Down

0 comments on commit dcb571a

Please sign in to comment.