Commit: gather getitem

radenmuaz committed Jan 25, 2024
1 parent dcb571a · commit 5576ed6
Showing 7 changed files with 119 additions and 98 deletions.
6 changes: 6 additions & 0 deletions examples/simple/broadcast_binaryop.py
@@ -0,0 +1,6 @@
import slope

x0 = slope.ones(2,1,1)
x1 = slope.ones(2,1,2)
y = x0*x1
breakpoint()
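The new example exercises elementwise broadcasting: a (2,1,1) operand multiplied with a (2,1,2) operand should broadcast to (2,1,2). A NumPy sketch of the expected rule (NumPy is used here only as a reference point, not slope API):

import numpy as np

x0 = np.ones((2, 1, 1))
x1 = np.ones((2, 1, 2))
y = x0 * x1                     # the size-1 trailing axis stretches to match
print(y.shape)                  # (2, 1, 2)
assert np.broadcast_shapes((2, 1, 1), (2, 1, 2)) == (2, 1, 2)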
56 changes: 28 additions & 28 deletions examples/simple/gather_nd.py
@@ -1,38 +1,38 @@
import slope

print('#1')
-# x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
-# w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
-# # w = slope.tensor([[0,0],[1,1]], dtype=slope.int32)
-# print(f"{x=}")
-# print(f"{w=}")
-# y = x.gather_nd(w,0)
-# print(f"{y=}")
+x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
+w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
+# w = slope.tensor([[0,0],[1,1]], dtype=slope.int32)
+print(f"{x=}")
+print(f"{w=}")
+y = x.gather_nd(w,0)
+print(f"{y=}")

-# print('\n#2')
-# x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
-# w = slope.tensor([[1],[0]]).cast(slope.int64)
-# print(f"{x=}")
-# print(f"{w=}")
-# y = x.gather_nd(w)
-# print(f"{y=}")
+print('\n#2')
+x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
+w = slope.tensor([[1],[0]]).cast(slope.int64)
+print(f"{x=}")
+print(f"{w=}")
+y = x.gather_nd(w)
+print(f"{y=}")

-# print('\n#3')
-# x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
-# w = slope.tensor([[0,1],[1,0]], dtype=slope.int32)
-# print(f"{x=}")
-# print(f"{w=}")
-# y = x.gather_nd(w)
-# print(f"{y=}")
+print('\n#3')
+x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
+w = slope.tensor([[0,1],[1,0]], dtype=slope.int32)
+print(f"{x=}")
+print(f"{w=}")
+y = x.gather_nd(w)
+print(f"{y=}")


-# print('\n#4')
-# x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
-# w = slope.tensor([[[0,1]],[[1,0]]], dtype=slope.int32)
-# print(f"{x=}")
-# print(f"{w=}")
-# y = x.gather_nd(w)
-# print(f"{y=}")
+print('\n#4')
+x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
+w = slope.tensor([[[0,1]],[[1,0]]], dtype=slope.int32)
+print(f"{x=}")
+print(f"{w=}")
+y = x.gather_nd(w)
+print(f"{y=}")


print('\n#5')
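The enabled cases match ONNX-style GatherND semantics with batch_dims=0: the last axis of the index tensor addresses leading axes of x, and any unaddressed trailing axes are returned whole. A minimal NumPy reference sketch (gather_nd_ref is a hypothetical helper, not slope API):

import numpy as np

def gather_nd_ref(x, indices):
    # Each row of `indices` is a (possibly partial) coordinate into x;
    # trailing axes of x not covered by the coordinate are sliced through whole.
    x, indices = np.asarray(x), np.asarray(indices)
    out_shape = indices.shape[:-1] + x.shape[indices.shape[-1]:]
    flat = indices.reshape(-1, indices.shape[-1])
    return np.stack([x[tuple(i)] for i in flat]).reshape(out_shape)

x = np.array([[0., 1.], [2., 3.]])
print(gather_nd_ref(x, [[1, 0], [0, 1]]))  # [2. 1.]            (case #1)
print(gather_nd_ref(x, [[1], [0]]))        # [[2. 3.] [0. 1.]]  (case #2)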
8 changes: 8 additions & 0 deletions examples/simple/getitem.py
@@ -0,0 +1,8 @@
import slope

x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
w = slope.tensor([1], dtype=slope.int32)

# y = x.gather_nd(w)
y = x[w]
breakpoint()
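For comparison, the same integer-tensor indexing in NumPy, which the new getitem path lowers to gather_nd (per the commented line above):

import numpy as np

x = np.array([[0., 1.], [2., 3.]], dtype=np.float32)
w = np.array([1], dtype=np.int32)
print(x[w])  # [[2. 3.]] -- integer-array indexing keeps an outer axis of len(w)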
1 change: 1 addition & 0 deletions src/slope/backends/iree.py
@@ -381,6 +381,7 @@ def less_impl(self, x, w, y):
}} {as_mlir_sig((x.symval, w.symval), y.symval)}
"""


@backend.set_impl(backend.operator_set.greater)
def greater_impl(self, x, w, y):
return f"""%{y.name} = "stablehlo.compare"(%{x.name}, %{w.name}) {{
4 changes: 3 additions & 1 deletion src/slope/core.py
Expand Up @@ -196,7 +196,7 @@ class dtypes:
name_dtype_map_inv = {v: k for k, v in name_dtype_map.items()}
mlir_dtype_map = {k.mlir: k for k in all_dtypes}
mlir_dtype_map_inv = {v: k for k, v in mlir_dtype_map.items()}

@classmethod
def is_int(cls, dtype):
return dtype in (cls.uint8, cls.int8, cls.int32, cls.int64)
@@ -541,6 +541,7 @@ def jvp(self, primals, tangents, **params):

class BinaryOperator(Operator):
    boolean_output = False

    def args_fixer(self, x, w, **params):
        if type(x) is UndefPrimal or type(w) is UndefPrimal:
            assert x.shape == w.shape
@@ -619,6 +620,7 @@ def T(self, cotangents, x, w):
        gL_y = gL_y.cast(x.dtype)
        return [gL_y, None]


class ReduceOperator(Operator):
    def args_fixer(self, x, *, dim=None, keepdim=False):
        if dim is None:
4 changes: 2 additions & 2 deletions src/slope/operators.py
@@ -226,17 +226,17 @@ def T(self, cotangents, x, w):
class Equal(BinaryOperator):
    boolean_output = True


@operator_set.register("less")
class Less(BinaryOperator):
    boolean_output = True


@operator_set.register("greater")
class Greater(BinaryOperator):
    boolean_output = True


@operator_set.register("max")
class Max(ReduceOperator):
    def jvp(self, primals, tangents, *, dim, keepdim):
138 changes: 71 additions & 67 deletions src/slope/procedures.py
Expand Up @@ -2,24 +2,14 @@
from slope.core import ProcedureSet, Tensor, dtypes
import math
import numpy as np
-from typing import (
-    Tuple,
-    List,
-    Dict,
-    Any,
-    Optional,
-    Sequence,
-    Union,
-    Iterator,
-    NamedTuple,
-    DefaultDict
-)
+from typing import Tuple, List, Dict, Any, Optional, Sequence, Union, Iterator, NamedTuple, DefaultDict
from collections import defaultdict
import functools

-abs_py = abs
+max_ = max
+abs_ = abs
+min_ = min
+sum_ = sum

procedure_set = ProcedureSet()
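A plausible reading of the new aliases: they keep handles to the Python builtins in a module where the bare names also denote tensor reductions, so helper code can still reach the builtin. A hypothetical illustration (the shadowing `sum` here is illustrative, not slope's):

sum_ = sum                       # keep builtins.sum reachable
sum = lambda t, dim=None: t      # name reused for a tensor-level op
assert sum_([1, 2, 3]) == 6      # the builtin remains available via the alias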


@@ -69,13 +59,19 @@ def eye(dim: int, **kwargs):

@procedure_set.register()
def where(x, trueval, falseval):
-    cond = x != x.zeros_like()
     if not isinstance(trueval, Tensor):
-        trueval = slope.full((), trueval)
+        trueval = slope.full(trueval, device=x.device)
     if not isinstance(falseval, Tensor):
-        falseval = slope.full((), falseval)
-    cond = cond.cast(trueval.dtype)
-    return cond * trueval + (x.ones_like() - cond) * falseval
+        falseval = slope.full(falseval, device=x.device)
+    cond = x != x.zeros_like()
+    if not trueval.dtype is dtypes.bool:
+        cond = cond.cast(trueval.dtype)
+        return cond * trueval + (cond.ones_like() - cond) * falseval
+    else:
+        cond = cond.cast(slope.float32)
+        trueval = trueval.cast(slope.float32)
+        falseval = falseval.cast(slope.float32)
+        return (cond * trueval + (cond.ones_like() - cond) * falseval).cast(dtypes.bool)
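The rewritten where selects arithmetically: cond is cast to the value dtype and the result is cond * trueval + (1 - cond) * falseval, with a float32 round-trip when the values are bool. A NumPy sketch of the underlying identity:

import numpy as np

cond = np.array([1., 0., 1.])                  # 1.0 where the predicate held
trueval = np.array([10., 20., 30.])
falseval = np.array([-1., -2., -3.])
y = cond * trueval + (1.0 - cond) * falseval   # branch-free select
assert np.array_equal(y, np.where(cond != 0, trueval, falseval))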


@procedure_set.register()
@@ -108,14 +104,17 @@ def neg(x):
def not_equal(x, w):
    return ~(x == w)


@procedure_set.register()
def greater_equal(x, w):
    return ~(x < w)


@procedure_set.register()
def less_equal(x, w):
    return ~(x > w)


@procedure_set.register()
def minimum(x, w):
    return -(-x).maximum(-w)
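minimum is derived from maximum through the identity min(a, b) = -max(-a, -b); a quick NumPy check of that identity:

import numpy as np

a, b = np.array([1., 5., -2.]), np.array([3., 2., -4.])
assert np.array_equal(np.minimum(a, b), -np.maximum(-a, -b))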
@@ -215,16 +214,20 @@ def matmul(x, w):
    w = w.reshape((*w.shape[0:-2], 1, w.shape[-2], w.shape[-1])).transpose(-1, -2)
    return (x * w).sum(-1).reshape((*x.shape[0:-2], -1))


@procedure_set.register()
def getitem(x, indices) -> Tensor:
    # 1. indices normalization and validation
    # treat internal tuples and lists as Tensors and standardize indices to list type
    if isinstance(indices, list) and all(isinstance(s, int) for s in indices):
-        indices = [slope.tensor(indices, x.device,)]
+        indices = [
+            slope.tensor(
+                indices,
+                x.device,
+            )
+        ]
     elif isinstance(indices, (tuple, list)):
-        indices = [
-            slope.tensor(list(i), x.device) if isinstance(i, (tuple, list)) else i for i in indices
-        ]
+        indices = [slope.tensor(list(i), x.device) if isinstance(i, (tuple, list)) else i for i in indices]
     else:
         indices = [indices]
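A standalone sketch of the normalization rule above, with plain Python lists standing in for slope tensors (the helper name is hypothetical):

def normalize(indices):
    if isinstance(indices, list) and all(isinstance(s, int) for s in indices):
        return [list(indices)]          # a bare int list becomes one index tensor
    elif isinstance(indices, (tuple, list)):
        return [list(i) if isinstance(i, (tuple, list)) else i for i in indices]
    else:
        return [indices]                # single int/slice/Tensor

print(normalize([0, 1]))       # [[0, 1]]    -> one gather tensor for dim 0
print(normalize((2, [1, 0])))  # [2, [1, 0]] -> int for dim 0, tensor for dim 1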

@@ -273,9 +276,9 @@ def getitem(x, indices) -> Tensor:

     new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
     ret = x.padslice(new_slice).flip(tuple(i for i, s in enumerate(strides) if s < 0))
-    if any(abs_py(s) != 1 for s in strides):
-        strides = tuple(abs(s) for s in strides)
-        round_up = lambda num, amt: (num+amt-1)//amt * amt
+    if any(abs_(s) != 1 for s in strides):
+        strides = tuple(abs_(s) for s in strides)
+        round_up = lambda num, amt: (num + amt - 1) // amt * amt
         ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
         ret = ret.reshape(tuple(flatten((sh // s, s) for s, sh in zip(strides, ret.shape))))
         ret = ret.padslice(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])
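The pad → reshape → padslice sequence implements strided slicing without a native stride primitive: round each axis up to a multiple of the stride, split it into (groups, stride), then keep the first element of each group. The same trick in NumPy:

import numpy as np

x = np.arange(7.)              # goal: x[::3] via pad/reshape/slice only
s = 3
pad = (-len(x)) % s            # round length up to a multiple of s
xp = np.pad(x, (0, pad))       # [0 1 2 3 4 5 6 0 0]
y = xp.reshape(-1, s)[:, 0]    # first element of each stride-s group
assert np.array_equal(y, x[::3])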
@@ -285,60 +288,61 @@ def getitem(x, indices) -> Tensor:
     for dim in type_dim[None]:
         new_shape.insert(dim, 1)
     for dim in (
-        dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))
+        dims_collapsed := tuple(dim + sum_(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))
     ):
         new_shape.pop(dim)

     ret = ret.reshape(new_shape)
     # 3. advanced indexing (copy)
     if type_dim[Tensor]:
         # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
         def calc_dim(tensor_dim: int) -> int:
             return (
                 tensor_dim
                 - sum(1 for d in dims_collapsed if tensor_dim >= d)
                 + sum(1 for d in type_dim[None] if tensor_dim >= d)
             )

         # track tensor_dim and tensor_index using a dict
         # calc_dim to get dim and use that to normalize the negative tensor indices
         idx: Dict[int, Tensor] = {
             (dim := calc_dim(td)): (tensor < 0).where(ret.shape[dim], 0) + tensor
             for td, tensor in zip(type_dim[Tensor], tensor_index)
         }
+        for i in tensor_index:
+            while i.ndim < ret.ndim:
+                i = i[None]
+            ret = ret.gather_nd(i)
+    return ret
-        # compute sum_dim, arange, and idx
-        max_idx_dim, first_dim, last_dim = max(i.ndim for i in idx.values()), min(idx.keys()), max(idx.keys())
-        sum_dim = tuple(d if n == 0 else d + max_idx_dim - n for n, d in enumerate(idx.keys()))
-        reshaped_idx = [
-            i.reshape(i.shape + (1,) * (ret.ndim - first_dim - (n or 1))) for n, i in enumerate(idx.values())
-        ]
-        ret_ = ret
-        ret = ret.reshape(ret.shape[: first_dim + 1] + (1,) * max_idx_dim + ret.shape[first_dim + 1 :])
-        # iteratively eq -> mul -> sum fancy index
-        for i, sd in zip(reshaped_idx, sum_dim):
-            breakpoint()
-            ret = ret.gather(i)
-        # special permute case
-        if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim + 1)):
-            ret_dims = list(range(ret.ndim))
-            ret = ret.permute(
-                ret_dims[first_dim : first_dim + max_idx_dim]
-                + ret_dims[:first_dim]
-                + ret_dims[first_dim + max_idx_dim :]
-            )
-    return ret
+    ## impl like tinygrad tensor __getitem__:
+    # if type_dim[Tensor]:
+    #     # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
+    #     def calc_dim(tensor_dim: int) -> int:
+    #         return (
+    #             tensor_dim
+    #             - sum_(1 for d in dims_collapsed if tensor_dim >= d)
+    #             + sum_(1 for d in type_dim[None] if tensor_dim >= d)
+    #         )
+    #     # track tensor_dim and tensor_index using a dict
+    #     # calc_dim to get dim and use that to normalize the negative tensor indices
+    #     idx: Dict[int, Tensor] = {
+    #         (dim := calc_dim(td)): (t < 0).where(t.full_like(ret.shape[dim]), t.zeros_like()) + t
+    #         for td, t in zip(type_dim[Tensor], tensor_index)
+    #     }
+    #     # compute sum_dim, arange, and idx
+    #     max_idx_dim, first_dim, last_dim = max_(i.ndim for i in idx.values()), min_(idx.keys()), max_(idx.keys())
+    #     sum_dim = tuple(d if n == 0 else d + max_idx_dim - n for n, d in enumerate(idx.keys()))
+    #     arange = [
+    #         slope.arange(ret.shape[d], device=x.device).reshape(
+    #             ret.shape[d : d + 1] + (1,) * (ret.ndim + max_idx_dim - n - sd - 1)
+    #         )
+    #         for n, (sd, d) in enumerate(zip(sum_dim, idx.keys()))
+    #     ]  # noqa: E501
+    #     reshaped_idx = [
+    #         i.reshape(i.shape + (1,) * (ret.ndim - first_dim - (n or 1))) for n, i in enumerate(idx.values())
+    #     ]
+    #     ret = ret.reshape(ret.shape[: first_dim + 1] + (1,) * max_idx_dim + ret.shape[first_dim + 1 :])
+    #     for a, i, sd in zip(arange, reshaped_idx, sum_dim):
+    #         ret = (a == i).cast(ret.dtype).mul(ret).sum(sd)
+    #     # special permute case
+    #     if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim + 1)):
+    #         ret_dims = list(range(ret.ndim))
+    #         ret = ret.permute(
+    #             ret_dims[first_dim : first_dim + max_idx_dim]
+    #             + ret_dims[:first_dim]
+    #             + ret_dims[first_dim + max_idx_dim :]
+    #         )

@procedure_set.register()
