From 5576ed6180cb5f1460da117028554a714329b5d3 Mon Sep 17 00:00:00 2001
From: radenmuaz
Date: Fri, 26 Jan 2024 02:09:11 +0800
Subject: [PATCH] gather getitem

---
 examples/simple/broadcast_binaryop.py |   6 ++
 examples/simple/gather_nd.py          |  56 +++++------
 examples/simple/getitem.py            |   8 ++
 src/slope/backends/iree.py            |   1 +
 src/slope/core.py                     |   4 +-
 src/slope/operators.py                |   4 +-
 src/slope/procedures.py               | 138 +++++++++++++-------
 7 files changed, 119 insertions(+), 98 deletions(-)
 create mode 100644 examples/simple/broadcast_binaryop.py
 create mode 100644 examples/simple/getitem.py

diff --git a/examples/simple/broadcast_binaryop.py b/examples/simple/broadcast_binaryop.py
new file mode 100644
index 0000000..9fec756
--- /dev/null
+++ b/examples/simple/broadcast_binaryop.py
@@ -0,0 +1,6 @@
+import slope
+
+x0 = slope.ones(2,1,1)
+x1 = slope.ones(2,1,2)
+y = x0*x1
+breakpoint()
\ No newline at end of file
diff --git a/examples/simple/gather_nd.py b/examples/simple/gather_nd.py
index a01027f..4d00291 100644
--- a/examples/simple/gather_nd.py
+++ b/examples/simple/gather_nd.py
@@ -1,38 +1,38 @@
 import slope
 
 print('#1')
-# x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
-# w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
-# # w = slope.tensor([[0,0],[1,1]], dtype=slope.int32)
-# print(f"{x=}")
-# print(f"{w=}")
-# y = x.gather_nd(w,0)
-# print(f"{y=}")
+x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
+w = slope.tensor([[1,0],[0,1]], dtype=slope.int32)
+# w = slope.tensor([[0,0],[1,1]], dtype=slope.int32)
+print(f"{x=}")
+print(f"{w=}")
+y = x.gather_nd(w,0)
+print(f"{y=}")
 
-# print('\n#2')
-# x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
-# w = slope.tensor([[1],[0]]).cast(slope.int64)
-# print(f"{x=}")
-# print(f"{w=}")
-# y = x.gather_nd(w)
-# print(f"{y=}")
+print('\n#2')
+x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
+w = slope.tensor([[1],[0]]).cast(slope.int64)
+print(f"{x=}")
+print(f"{w=}")
+y = x.gather_nd(w)
+print(f"{y=}")
 
-# print('\n#3')
-# x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
-# w = slope.tensor([[0,1],[1,0]], dtype=slope.int32)
-# print(f"{x=}")
-# print(f"{w=}")
-# y = x.gather_nd(w)
-# print(f"{y=}")
+print('\n#3')
+x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
+w = slope.tensor([[0,1],[1,0]], dtype=slope.int32)
+print(f"{x=}")
+print(f"{w=}")
+y = x.gather_nd(w)
+print(f"{y=}")
 
-# print('\n#4')
-# x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
-# w = slope.tensor([[[0,1]],[[1,0]]], dtype=slope.int32)
-# print(f"{x=}")
-# print(f"{w=}")
-# y = x.gather_nd(w)
-# print(f"{y=}")
+print('\n#4')
+x = slope.tensor([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=slope.float32)
+w = slope.tensor([[[0,1]],[[1,0]]], dtype=slope.int32)
+print(f"{x=}")
+print(f"{w=}")
+y = x.gather_nd(w)
+print(f"{y=}")
 
 print('\n#5')
diff --git a/examples/simple/getitem.py b/examples/simple/getitem.py
new file mode 100644
index 0000000..c7b82ad
--- /dev/null
+++ b/examples/simple/getitem.py
@@ -0,0 +1,8 @@
+import slope
+
+x = slope.tensor([[0.,1.],[2.,3.]], dtype=slope.float32)
+w = slope.tensor([1], dtype=slope.int32)
+
+# y = x.gather_nd(w)
+y = x[w]
+breakpoint()
\ No newline at end of file
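As a rough cross-check of what the new and re-enabled examples should compute, plain NumPy gives the expected values (this assumes slope's gather_nd follows ONNX GatherND semantics and that slope broadcasts like NumPy; neither is stated in the patch itself):

    import numpy as np

    x = np.array([[0., 1.], [2., 3.]], dtype=np.float32)

    # gather_nd #1: w = [[1, 0], [0, 1]] holds full index tuples, one per row.
    print(x[[1, 0], [0, 1]])   # [2. 1.]

    # gather_nd #2: w = [[1], [0]] holds partial tuples, gathering whole rows.
    print(x[[1, 0]])           # [[2. 3.] [0. 1.]]

    # getitem.py: x[w] with w = [1] should lower to the same row gather.
    print(x[[1]])              # [[2. 3.]]

    # broadcast_binaryop.py: (2, 1, 1) * (2, 1, 2) broadcasts to (2, 1, 2).
    print((np.ones((2, 1, 1)) * np.ones((2, 1, 2))).shape)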
diff --git a/src/slope/backends/iree.py b/src/slope/backends/iree.py
index 55639d0..548b909 100644
--- a/src/slope/backends/iree.py
+++ b/src/slope/backends/iree.py
@@ -381,6 +381,7 @@ def less_impl(self, x, w, y):
 }} {as_mlir_sig((x.symval, w.symval), y.symval)}
 """
 
+
 @backend.set_impl(backend.operator_set.greater)
 def greater_impl(self, x, w, y):
     return f"""%{y.name} = "stablehlo.compare"(%{x.name}, %{w.name}) {{
diff --git a/src/slope/core.py b/src/slope/core.py
index 96142e5..9dc3b23 100644
--- a/src/slope/core.py
+++ b/src/slope/core.py
@@ -196,7 +196,7 @@ class dtypes:
     name_dtype_map_inv = {v: k for k, v in name_dtype_map.items()}
     mlir_dtype_map = {k.mlir: k for k in all_dtypes}
     mlir_dtype_map_inv = {v: k for k, v in mlir_dtype_map.items()}
-    
+
     @classmethod
     def is_int(cls, dtype):
         return dtype in (cls.uint8, cls.int8, cls.int32, cls.int64)
@@ -541,6 +541,7 @@ def jvp(self, primals, tangents, **params):
 
 class BinaryOperator(Operator):
     boolean_output = False
+
     def args_fixer(self, x, w, **params):
         if type(x) is UndefPrimal or type(w) is UndefPrimal:
             assert x.shape == w.shape
@@ -619,6 +620,7 @@ def T(self, cotangents, x, w):
             gL_y = gL_y.cast(x.dtype)
         return [gL_y, None]
 
+
 class ReduceOperator(Operator):
     def args_fixer(self, x, *, dim=None, keepdim=False):
         if dim is None:
diff --git a/src/slope/operators.py b/src/slope/operators.py
index a16e98c..3e1be24 100644
--- a/src/slope/operators.py
+++ b/src/slope/operators.py
@@ -226,17 +226,17 @@ def T(self, cotangents, x, w):
 class Equal(BinaryOperator):
     boolean_output = True
 
+
 @operator_set.register("less")
 class Less(BinaryOperator):
     boolean_output = True
 
+
 @operator_set.register("greater")
 class Greater(BinaryOperator):
     boolean_output = True
 
-
-
 @operator_set.register("max")
 class Max(ReduceOperator):
     def jvp(self, primals, tangents, *, dim, keepdim):
diff --git a/src/slope/procedures.py b/src/slope/procedures.py
index 89c1f32..43ab610 100644
--- a/src/slope/procedures.py
+++ b/src/slope/procedures.py
@@ -2,24 +2,14 @@
 from slope.core import ProcedureSet, Tensor, dtypes
 import math
 import numpy as np
-from typing import (
-    Tuple,
-    List,
-    Dict,
-    Any,
-    Optional,
-    Sequence,
-    Union,
-    Iterator,
-    NamedTuple,
-    DefaultDict
-)
+from typing import Tuple, List, Dict, Any, Optional, Sequence, Union, Iterator, NamedTuple, DefaultDict
 from collections import defaultdict
 import functools
 
-abs_py = abs
-
-
+max_ = max
+abs_ = abs
+min_ = min
+sum_ = sum
 procedure_set = ProcedureSet()
@@ -69,13 +59,19 @@ def eye(dim: int, **kwargs):
 @procedure_set.register()
 def where(x, trueval, falseval):
-    cond = x != x.zeros_like()
     if not isinstance(trueval, Tensor):
-        trueval = slope.full((), trueval)
+        trueval = slope.full(trueval, device=x.device)
     if not isinstance(falseval, Tensor):
-        falseval = slope.full((), falseval)
-    cond = cond.cast(trueval.dtype)
-    return cond * trueval + (x.ones_like() - cond) * falseval
+        falseval = slope.full(falseval, device=x.device)
+    cond = x != x.zeros_like()
+    if trueval.dtype is not dtypes.bool:
+        cond = cond.cast(trueval.dtype)
+        return cond * trueval + (cond.ones_like() - cond) * falseval
+    else:
+        cond = cond.cast(slope.float32)
+        trueval = trueval.cast(slope.float32)
+        falseval = falseval.cast(slope.float32)
+        return (cond * trueval + (cond.ones_like() - cond) * falseval).cast(dtypes.bool)
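The rewritten where lowers select to arithmetic: cast the 0/1 mask to the payload dtype and compute cond * t + (1 - cond) * f. Bool payloads take the float32 round trip, presumably because that arithmetic form is not defined on bool tensors. A NumPy sketch mirroring the diff's logic (where_ref is a hypothetical stand-in, not slope's kernel):

    import numpy as np

    def where_ref(cond, trueval, falseval):
        if trueval.dtype != np.bool_:
            m = cond.astype(trueval.dtype)
            return m * trueval + (np.ones_like(m) - m) * falseval
        # bool arrays have no subtract, so detour through float32 and cast back
        m = cond.astype(np.float32)
        t, f = trueval.astype(np.float32), falseval.astype(np.float32)
        return (m * t + (np.ones_like(m) - m) * f).astype(np.bool_)

    cond = np.array([True, False, True])
    print(where_ref(cond, np.array([1., 2., 3.]), np.array([9., 9., 9.])))
    # [1. 9. 3.], matching np.where(cond, trueval, falseval)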
@@ -108,14 +104,17 @@ def neg(x):
 def not_equal(x, w):
     return ~(x == w)
 
+
 @procedure_set.register()
 def greater_equal(x, w):
     return ~(x < w)
 
+
 @procedure_set.register()
 def less_equal(x, w):
     return ~(x > w)
 
+
 @procedure_set.register()
 def minimum(x, w):
     return -x.maximum(-x, -w)
@@ -215,16 +214,20 @@ def matmul(x, w):
     w = w.reshape((*w.shape[0:-2], 1, w.shape[-2], w.shape[-1])).transpose(-1, -2)
     return (x * w).sum(-1).reshape((*x.shape[0:-2], -1))
 
+
 @procedure_set.register()
 def getitem(x, indices) -> Tensor:
     # 1. indices normalization and validation
     # treat internal tuples and lists as Tensors and standardize indices to list type
     if isinstance(indices, list) and all(isinstance(s, int) for s in indices):
-        indices = [slope.tensor(indices, x.device,)]
-    elif isinstance(indices, (tuple, list)):
         indices = [
-            slope.tensor(list(i), x.device) if isinstance(i, (tuple, list)) else i for i in indices
+            slope.tensor(
+                indices,
+                x.device,
+            )
         ]
+    elif isinstance(indices, (tuple, list)):
+        indices = [slope.tensor(list(i), x.device) if isinstance(i, (tuple, list)) else i for i in indices]
     else:
         indices = [indices]
@@ -273,9 +276,9 @@ def getitem(x, indices) -> Tensor:
     new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
     ret = x.padslice(new_slice).flip(tuple(i for i, s in enumerate(strides) if s < 0))
-    if any(abs_py(s) != 1 for s in strides):
-        strides = tuple(abs(s) for s in strides)
-        round_up = lambda num, amt: (num+amt-1)//amt * amt
+    if any(abs_(s) != 1 for s in strides):
+        strides = tuple(abs_(s) for s in strides)
+        round_up = lambda num, amt: (num + amt - 1) // amt * amt
         ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
         ret = ret.reshape(tuple(flatten((sh // s, s) for s, sh in zip(strides, ret.shape))))
         ret = ret.padslice(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])
@@ -285,7 +288,7 @@ def getitem(x, indices) -> Tensor:
     for dim in type_dim[None]:
         new_shape.insert(dim, 1)
     for dim in (
-        dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))
+        dims_collapsed := tuple(dim + sum_(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))
     ):
         new_shape.pop(dim)
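The strided-slice hunk (@@ -273) emulates step-s slicing without a native strided read: pad each axis up to a multiple of s, reshape it into (len/s, s) blocks, and keep the first element of each block via padslice. A 1-D NumPy sketch of the same trick (the diff applies it per axis through the flatten pairs; round_up is the diff's own helper):

    import numpy as np

    x = np.arange(7.0)
    s = 3                                                   # stride to emulate
    round_up = lambda num, amt: (num + amt - 1) // amt * amt

    padded = np.pad(x, (0, round_up(len(x), s) - len(x)))   # length 9
    blocks = padded.reshape(-1, s)                          # (3, 3) blocks
    print(blocks[:, 0])                                     # [0. 3. 6.]
    assert np.array_equal(blocks[:, 0], x[::3])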
@@ -293,52 +296,53 @@ def getitem(x, indices) -> Tensor:
     # 3. advanced indexing (copy)
     if type_dim[Tensor]:
-        # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
-        def calc_dim(tensor_dim: int) -> int:
-            return (
-                tensor_dim
-                - sum(1 for d in dims_collapsed if tensor_dim >= d)
-                + sum(1 for d in type_dim[None] if tensor_dim >= d)
-            )
-
-        # track tensor_dim and tensor_index using a dict
-        # calc_dim to get dim and use that to normalize the negative tensor indices
-        idx: Dict[int, Tensor] = {
-            (dim := calc_dim(td)): (tensor < 0).where(ret.shape[dim], 0) + tensor
-            for td, tensor in zip(type_dim[Tensor], tensor_index)
-        }
+        for i in tensor_index:
+            while i.ndim < ret.ndim:
+                i = i[None]
+            ret = ret.gather_nd(i)
+        return ret
+    ## impl like tinygrad tensor __getitem__:
+    # if type_dim[Tensor]:
+    #     # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
+    #     def calc_dim(tensor_dim: int) -> int:
+    #         return (
+    #             tensor_dim
+    #             - sum_(1 for d in dims_collapsed if tensor_dim >= d)
+    #             + sum_(1 for d in type_dim[None] if tensor_dim >= d)
+    #         )
 
-        # compute sum_dim, arange, and idx
-        max_idx_dim, first_dim, last_dim = max(i.ndim for i in idx.values()), min(idx.keys()), max(idx.keys())
-        sum_dim = tuple(d if n == 0 else d + max_idx_dim - n for n, d in enumerate(idx.keys()))
+    #     # track tensor_dim and tensor_index using a dict
+    #     # calc_dim to get dim and use that to normalize the negative tensor indices
+    #     idx: Dict[int, Tensor] = {
+    #         (dim := calc_dim(td)): (t < 0).where(t.full_like(ret.shape[dim]), t.zeros_like()) + t
+    #         for td, t in zip(type_dim[Tensor], tensor_index)
+    #     }
+    #     # compute sum_dim, arange, and idx
+    #     max_idx_dim, first_dim, last_dim = max_(i.ndim for i in idx.values()), min_(idx.keys()), max_(idx.keys())
+    #     sum_dim = tuple(d if n == 0 else d + max_idx_dim - n for n, d in enumerate(idx.keys()))
         # arange = [
-        #     Tensor.arange(ret.shape[d], requires_grad=False, device=x.device).reshape(
+        #     slope.arange(ret.shape[d], device=x.device).reshape(
         #         ret.shape[d : d + 1] + (1,) * (ret.ndim + max_idx_dim - n - sd - 1)
         #     )
         #     for n, (sd, d) in enumerate(zip(sum_dim, idx.keys()))
         # ]  # noqa: E501
-        reshaped_idx = [
-            i.reshape(i.shape + (1,) * (ret.ndim - first_dim - (n or 1))) for n, i in enumerate(idx.values())
-        ]
-        ret_ = ret
-        ret = ret.reshape(ret.shape[: first_dim + 1] + (1,) * max_idx_dim + ret.shape[first_dim + 1 :])
-
-        # iteratively eq -> mul -> sum fancy index
-        for i, sd in zip(reshaped_idx, sum_dim):
-            breakpoint()
-            ret = ret.gather(i)
-        # # for a, i, sd in zip(arange, reshaped_idx, sum_dim):
-        # #     ret = (a == i).mul(ret).sum(sd)
-
-        # special permute case
-        if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim + 1)):
-            ret_dims = list(range(ret.ndim))
-            ret = ret.permute(
-                ret_dims[first_dim : first_dim + max_idx_dim]
-                + ret_dims[:first_dim]
-                + ret_dims[first_dim + max_idx_dim :]
-            )
-    return ret
+    # reshaped_idx = [
+    #     i.reshape(i.shape + (1,) * (ret.ndim - first_dim - (n or 1))) for n, i in enumerate(idx.values())
+    # ]
+    # ret_ = ret
+    # ret = ret.reshape(ret.shape[: first_dim + 1] + (1,) * max_idx_dim + ret.shape[first_dim + 1 :])
+
+    # for a, i, sd in zip(arange, reshaped_idx, sum_dim):
+    #     ret = (a == i).cast(ret.dtype).mul(ret).sum(sd)
+
+    # # special permute case
+    # if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim + 1)):
+    #     ret_dims = list(range(ret.ndim))
+    #     ret = ret.permute(
+    #         ret_dims[first_dim : first_dim + max_idx_dim]
+    #         + ret_dims[:first_dim]
+    #         + ret_dims[first_dim + max_idx_dim :]
+    #     )
 
 
 @procedure_set.register()
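The new tensor-index branch lifts each index tensor to the rank of ret by prepending singleton axes, then folds it in with one gather_nd per index. A self-contained NumPy model of that loop; gather_nd_ref and tensor_index_ref are hypothetical stand-ins written to ONNX GatherND semantics, which slope's gather_nd is assumed to follow:

    import numpy as np

    def gather_nd_ref(x, w):
        # ONNX-style GatherND (batch_dims=0): the last axis of w holds
        # index tuples into the leading axes of x.
        out = np.stack([x[tuple(i)] for i in w.reshape(-1, w.shape[-1])])
        return out.reshape(w.shape[:-1] + x.shape[w.shape[-1]:])

    def tensor_index_ref(ret, tensor_index):
        # Mirror of the new branch: pad each index tensor with leading
        # singleton axes up to ret's rank, then gather.
        for i in tensor_index:
            while i.ndim < ret.ndim:
                i = i[None]
            ret = gather_nd_ref(ret, i)
        return ret

    x = np.arange(4.0).reshape(2, 2)
    print(tensor_index_ref(x, [np.array([1])]))   # [[2. 3.]], same as x[[1]]

Compared with the commented-out eq -> mul -> sum construction (compare an arange against the index tensor to build a one-hot mask, multiply it into the data, and reduce the indexed axis away), a direct gather avoids materializing the one-hot mask, at the cost of requiring a gather_nd primitive in every backend.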