Commit
argmax
radenmuaz committed Jan 21, 2024
1 parent 2e5486e commit 82f79f4
Showing 8 changed files with 118 additions and 57 deletions.
16 changes: 8 additions & 8 deletions examples/nn/mnist_mlp.py
@@ -26,16 +26,16 @@ def train_step(model, batch, optimizer):
return loss, new_model, new_optimizer


# @slope.jit
@slope.jit
def test_all(model, x, y):
out = model(x)
# y_hat = out.argmax(-1)
# corrects = (y_hat == y).cast(slope.float32)
# accuracy = corrects.mean()
y_hat = out.argmax(-1)
corrects = (y_hat == y).cast(slope.float32)
accuracy = corrects.mean()

y_hat = np.argmax(out.numpy() ,-1)
corrects = (y_hat == y.numpy()).astype(np.float32)
accuracy = np.mean(corrects)
# y_hat = np.argmax(out.numpy() ,-1)
# corrects = (y_hat == y.numpy()).astype(np.float32)
# accuracy = np.mean(corrects)

return accuracy

@@ -93,4 +93,4 @@ def data_stream():

test_acc = test_all(model, x_test, y_test)
print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
print(f"Test set accuracy {test_acc:0.2f}")
print(f"Test set accuracy {test_acc.numpy():0.2f}")
8 changes: 8 additions & 0 deletions examples/simple/argmax.py
@@ -0,0 +1,8 @@
import slope
# y1 = slope.arange(5)
# y2 = slope.arange(0,10,2)
# y3 = slope.arange(10,0,-1)
# print(y1)
# print(y2)
y3 = slope.arange(10)
print(y3.argmax())
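Since slope.arange(10) is strictly increasing, its maximum sits at the last position, so this script should print 9.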
13 changes: 4 additions & 9 deletions examples/simple/symbolic_maths.py
@@ -3,14 +3,9 @@
x1 = slope.tensor((1,))
x2 = slope.tensor((2,))
y1 = x1+x2

with slope.symbolic_run():
sym_y = x1+x2
sym_y1 = x1+x2
y2 = x1+x2
print(f"{y1=}")
print(f"{sym_y=}")
print(f"{y2=}")
# x1 = slope.symbolic_tensor((1,))
# x2 = slope.symbolic_tensor((1,))
# with slope.symbolic_run():
# y = x1+x2
# breakpoint()

breakpoint()
10 changes: 10 additions & 0 deletions experimental/tf_range.py
@@ -0,0 +1,10 @@
import tensorflow as tf
import tensorflow.mlir.experimental as tme

print('\n#1')
start=tf.constant(0)
limit=tf.constant(10)
delta=tf.constant(2, dtype=tf.int32)
y = tf.range(start, limit, delta); print(f"{y=}")
f = tme.convert_function((tf.function(tf.range)).get_concrete_function(start, limit, delta))
print(tme.run_pass_pipeline(f, "tf-lower-to-mlprogram-and-hlo"))
48 changes: 42 additions & 6 deletions src/slope/backends/iree.py
@@ -458,14 +458,50 @@ def max_impl(self, x, y, *, dim, keepdim):

@backend.set_impl(backend.operator_set.arange)
def arange_impl(self, y, *, start, stop, stride, dtype, device):
ret = ""
if stride == 1 and start == 0:
ret += f"""%{y.name} = "stablehlo.iota"() {{iota_dimension = 0 : i64}} {as_mlir_sig((), y.symval)}"""
else:
ret += f"""%{y.name}_ = "stablehlo.iota"() {{iota_dimension = 0 : i64}} {as_mlir_sig((), y.symval)}"""
return ret

return f"""%{y.name} = "stablehlo.iota"() {{iota_dimension = 0 : i64}} {as_mlir_sig((), y.symval)}"""
normalized = math.ceil(abs(stop - start) / abs(stride))
iota_symval = y.symval.override(shape=tuple(normalized*i if i != 1 else i for i in y.symval.shape))
one_symval = y.symval.override(shape=(1,))
return f"""
%{y.name}_scale_ = stablehlo.constant dense<{stride}> : {as_mlir_shape(one_symval)}
%{y.name}_scale = "stablehlo.broadcast_in_dim"(%{y.name}_scale_) {{
broadcast_dimensions = dense<{repr(list(range(y.symval.ndim)))}>: tensor<{y.symval.ndim}xi64>
}} {as_mlir_sig(( one_symval,), y.symval)}
%{y.name}_shift_ = stablehlo.constant dense<{start}> : {as_mlir_shape(y.symval.override(shape=(1,)))}
%{y.name}_shift = "stablehlo.broadcast_in_dim"(%{y.name}_shift_) {{
broadcast_dimensions = dense<{repr(list(range(y.symval.ndim)))}>: tensor<{y.symval.ndim}xi64>
}} {as_mlir_sig(( one_symval,), y.symval)}
%{y.name}__ = "stablehlo.iota"() {{iota_dimension = 0 : i64}} {as_mlir_sig((), y.symval)}
%{y.name}_ = "stablehlo.multiply"(%{y.name}__, %{y.name}_scale) {as_mlir_sig((y.symval, y.symval), y.symval)}
%{y.name} = "stablehlo.add"(%{y.name}_, %{y.name}_shift) {as_mlir_sig((y.symval, y.symval), y.symval)}
"""

'''
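Reference: output of experimental/tf_range.py above, showing how TensorFlow lowers tf.range to StableHLO.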
#1
y=<tf.Tensor: shape=(5,), dtype=int32, numpy=array([0, 2, 4, 6, 8], dtype=int32)>
module {
func.func @__inference_range_11(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?xi32> attributes {allow_soft_placement = false} {
%0 = stablehlo.subtract %arg1, %arg0 : tensor<i32>
%1 = stablehlo.abs %0 : tensor<i32>
%2 = stablehlo.convert %1 : (tensor<i32>) -> tensor<f64>
%3 = stablehlo.convert %arg2 : (tensor<i32>) -> tensor<f64>
%4 = stablehlo.divide %2, %3 : tensor<f64>
%5 = stablehlo.ceil %4 : tensor<f64>
%6 = stablehlo.convert %5 : (tensor<f64>) -> tensor<i64>
%7 = stablehlo.reshape %6 : (tensor<i64>) -> tensor<1xi64>
%8 = stablehlo.dynamic_iota %7, dim = 0 : (tensor<1xi64>) -> tensor<?xi32>
%9 = shape.shape_of %8 : tensor<?xi32> -> tensor<1xindex>
%10 = stablehlo.dynamic_broadcast_in_dim %arg2, %9, dims = [] : (tensor<i32>, tensor<1xindex>) -> tensor<?xi32>
%11 = stablehlo.multiply %8, %10 : tensor<?xi32>
%12 = shape.shape_of %11 : tensor<?xi32> -> tensor<1xindex>
%13 = stablehlo.dynamic_broadcast_in_dim %arg0, %12, dims = [] : (tensor<i32>, tensor<1xindex>) -> tensor<?xi32>
%14 = stablehlo.add %11, %13 : tensor<?xi32>
return %14 : tensor<?xi32>
}
}
'''
@backend.set_impl(backend.operator_set.full)
def full_impl(self, y, *, shape, fill_value, dtype, device):
fill_value = float(fill_value) if "f" in dtype.mlir else int(fill_value)
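The new arange_impl lowers arange(start, stop, stride) to a static iota scaled by the stride and shifted by the start, a static-shape variant of the dynamic TF lowering quoted in the reference comment above. A NumPy sketch of the same arithmetic, assuming the output length ceil(|stop - start| / |stride|) from the typecheck in src/slope/operators.py below:

import math
import numpy as np

# iota * stride + start reproduces arange; np.arange(n) stands in for
# stablehlo.iota, and the multiply/add for the broadcasted scale and shift.
def arange_via_iota(start, stop, stride):
    n = math.ceil(abs(stop - start) / abs(stride))
    iota = np.arange(n, dtype=np.int32)  # stablehlo.iota, dimension 0
    return iota * stride + start         # stablehlo.multiply, stablehlo.add

assert np.array_equal(arange_via_iota(0, 10, 2), np.arange(0, 10, 2))
assert np.array_equal(arange_via_iota(10, 0, -1), np.arange(10, 0, -1))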
49 changes: 27 additions & 22 deletions src/slope/core.py
@@ -371,7 +371,7 @@ def symval(self):

@property
def val(self):
raise RuntimeError(f"this.val should not be accessed, as\n{trace_stack[-1]=}, ")
raise RuntimeError(f"this.val should not be accessed, as\n{trace_stack[-1]=}, ")

@property
def shape(self):
@@ -386,12 +386,15 @@ def device(self):
return self._device

@classmethod
def like(cls, maybe_tensor):
shape = maybe_tensor.shape
dtype = maybe_tensor.dtype
device = maybe_tensor.device
def like(cls, maybe_tensor, **but):
shape = but.get("shape", maybe_tensor.shape)
dtype = but.get("dtype", maybe_tensor.dtype)
device = but.get("device",maybe_tensor.device)
return cls(shape, dtype, device)

def override(self, **but):
return self.like(self, **but)

def str_short(self):
return f'{str(self.dtype)}[{",".join(str(d) for d in self.shape)}]'

@@ -407,14 +410,16 @@ def __repr__(self):
return f"<SymbolicTensor: shape={self.shape}, dtype={self.dtype.name}, device={self.device}>"

# def __getattr__(self, attr):
# if attr in vars(backend.operator_set).keys():
# op = getattr(backend.operator_set, attr)
# elif attr in vars(backend.procedure_set).keys():
# procedure = getattr(backend.procedure_set, attr)
# assert not isinstance(procedure, classmethod), f"use {attr} instead of self.{attr}"
# return partial(procedure, self)
# else:
# return self.__getattribute__(attr)
# with symbolic_run():
# if attr in vars(backend.operator_set).keys():
# op = getattr(backend.operator_set, attr)
# return partial(op, self)
# elif attr in vars(backend.procedure_set).keys():
# procedure = getattr(backend.procedure_set, attr)
# assert not isinstance(procedure, classmethod), f"use {attr} instead of self.{attr}"
# return partial(procedure, self)
# else:
# return self.__getattribute__(attr)


# ================
@@ -532,7 +537,7 @@ def vmap(self, x, *, dim_size, vals_in, dims_in, **params):
return [self(x, **params)], [x_bdim]

def typecheck(self, x, **params):
return [SymbolicTensor.like(x)]
return [x.override()]

def jvp(self, primals, tangents, **params):
(x,), (x_dot,) = primals, tangents
@@ -592,18 +597,18 @@ def typecheck(self, x: SymbolicTensor, y: SymbolicTensor, **params) -> List[SymbolicTensor]:
return [symx]
shape_delta = len(symx.shape) - len(symy.shape)
if shape_delta > 0:
symy = SymbolicTensor((1,) * shape_delta + symy.shape, symy.dtype)
symy = symy.override(shape=(1,) * shape_delta + symy.shape)
elif shape_delta < 0:
x = x.reshape((1,) * -shape_delta + symx.shape)
symx = SymbolicTensor((1,) * -shape_delta + symx.shape, symx.dtype)
symx = symx.override(shape=(1,) * -shape_delta + symx.shape)
if symx == symy:
return [symx]
else:
shape_ret = tuple([max(x, w) for x, w in zip(symx.shape, symy.shape)])
if symx.shape != shape_ret:
symx = SymbolicTensor(shape_ret, symx.dtype)
symx = symx.override(shape=shape_ret)
if symy.shape != shape_ret:
symy = SymbolicTensor(shape_ret, symy.dtype)
symy = symy.override(shape=shape_ret)
if symx != symy:
raise TypeError
return [symx]
@@ -635,7 +640,7 @@ def typecheck(self, x: SymbolicTensor, *, dim=None, keepdim=False) -> List[SymbolicTensor]:
new_shape = [d if i not in dim_ else 1 for i, d in enumerate(x.shape)]
else:
new_shape = [d for i, d in enumerate(x.shape) if i not in dim_]
return [SymbolicTensor(tuple(new_shape), x.dtype, x.device)]
return [x.override(shape=tuple(new_shape))]


class InitOperator(Operator):
@@ -1391,7 +1396,7 @@ def symval(self):
else:
shape = list(symval.shape)
del shape[self.vmap_dim]
return SymbolicTensor(tuple(shape), symval.dtype, symval.device)
return symval.override(shape=tuple(shape))

def full_lower(self):
if self.vmap_dim is None:
@@ -2052,6 +2057,7 @@ def stash_trace(main: MainTrace):

@contextmanager
def symbolic_run():
global trace_stack
level = len(trace_stack)
main = MainTrace(level, SymbolicRunTrace, global_data=None)
trace_stack += [main]
@@ -2087,7 +2093,7 @@ def unmapped_symval(axis_size: int, batch_dim, symval: SymbolicTensor) -> SymbolicTensor:
else:
shape = list(symval.shape)
shape.insert(batch_dim, axis_size)
return SymbolicTensor(tuple(shape), symval.dtype, symval.device)
return symval.override(shape=tuple(shape))

vmap_traceable = vmap(program_as_fun(program), tuple(dims_in))
in_symvals = [unmapped_symval(dim_size, d, v.symval) for v, d in zip(program.in_binders, dims_in)]
@@ -2581,7 +2587,6 @@ def get_program(self, *args, **static_args):
args = tuple([static_args[k] if k in static_args else arg for k, arg in zip(args_strs, args)])

symvals_in = tree_map(lambda x: SymbolicTensor.like(get_symval(x)), args)
# self.name = self.get_jit_name(symvals_in, static_args, prefix=self.name, short=True)
static_args = tuple(static_args.items())
if self.name is None:
self.name = f"jit_{str(hash((self.f, symvals_in, static_args)))[-5:]}"
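The new like(..., **but) overrides and the override shorthand let call sites rebuild a SymbolicTensor with one attribute changed, which is what the shape-only rewrites throughout this diff rely on. A hypothetical usage sketch; the constructor signature SymbolicTensor(shape, dtype, device) and the import paths are assumptions:

import slope
from slope.core import SymbolicTensor  # import path assumed
from slope.core import dtypes          # import path assumed

device = slope.core.backend.DEFAULT_DEVICE
sym = SymbolicTensor((3, 4), dtypes.int32, device)
same = sym.override()                    # identical copy, as in UnaryOperator.typecheck
batched = sym.override(shape=(8, 3, 4))  # same dtype and device, new shape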
12 changes: 11 additions & 1 deletion src/slope/operators.py
@@ -756,10 +756,20 @@ def args_fixer(self, *, start, stop=None, stride=None, dtype=None, device=None):
dtype = dtypes.int32
if device is None:
device = slope.core.backend.DEFAULT_DEVICE
if 'f' in dtype.mlir:
start, stop, stride = float(start), float(stop), float(stride)
elif 'i' in dtype.mlir:
start, stop, stride = int(start), int(stop), int(stride)
return (), dict(start=start, stop=stop, stride=stride, dtype=dtype, device=device)

def typecheck(self, *, start, stop, stride, dtype, device) -> List[SymbolicTensor]:
return [SymbolicTensor((((stop - start) * stride),), dtype, device)]

assert stride != 0
if stride > 0:
assert stop > start
else:
assert stop < start
return [SymbolicTensor((int(math.ceil((abs(stop - start) / abs(stride)))),), dtype, device)]


# -------------------
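Worked instances of the new shape rule, length = ceil(|stop - start| / |stride|), matching the arange calls exercised in examples/simple/argmax.py above:

import math

assert math.ceil(abs(10 - 0) / abs(2)) == 5    # arange(0, 10, 2) -> shape (5,)
assert math.ceil(abs(0 - 10) / abs(-1)) == 10  # arange(10, 0, -1) -> shape (10,)
assert math.ceil(abs(5 - 0) / abs(1)) == 5     # arange(5)        -> shape (5,)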
19 changes: 8 additions & 11 deletions src/slope/procedures.py
@@ -141,19 +141,16 @@ def min(x, dim=None, keepdim=False):
@procedure_set.register()
def argmax(x, dim=None, keepdim=False):
if dim is None:
idx = (x == x.max(dim)) * slope.arange(
math.prod(x.shape) - 1,
-1,
-1,
dtype=slope.int32,
).reshape(x.shape)
return math.prod(x.shape) - idx.max() - 1
ar = slope.arange(math.prod(x.shape) - 1, -1, -1, dtype=x.dtype)
ar = ar.reshape(x.shape)
idx = (x == x.max(dim)).cast(x.dtype) * ar
return math.prod(x.shape) - idx.max().cast(slope.int32) - 1
dim = dim + len(x.shape) if dim < 0 else dim
m = (x == x.max(dim=dim, keepdim=True)).cast(slope.int32)
idx = m * slope.arange(x.shape[dim] - 1, -1, -1, dtype=slope.int32)
idx = idx.reshape(x.shape[dim], *[1] * (x.ndim - dim - 1))
m = (x == x.max(dim=dim, keepdim=True)).cast(x.dtype)
idx = m * slope.arange(x.shape[dim] - 1, -1, -1, dtype=x.dtype)
# idx = idx.reshape(x.shape[dim], *[1] * (x.ndim - dim - 1))
ret = x.shape[dim] - idx.max(dim=dim, keepdim=keepdim) - 1
return ret
return ret.cast(slope.int32)


@procedure_set.register()
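The procedure encodes argmax as a masked reduction: equality with the maximum selects candidates, and a descending arange weights earlier positions more, so the first occurrence of the maximum wins. A NumPy sketch of the dim=None case, under assumed semantics:

import math
import numpy as np

# n - max(mask * descending_weights) - 1 recovers the flat index of the
# first maximum, since the earliest max receives the largest weight.
def argmax_flat(x):
    n = math.prod(x.shape)
    weights = np.arange(n - 1, -1, -1, dtype=x.dtype).reshape(x.shape)
    idx = (x == x.max()).astype(x.dtype) * weights
    return n - int(idx.max()) - 1

x = np.array([[3.0, 7.0], [7.0, 1.0]])
assert argmax_flat(x) == int(np.argmax(x)) == 1  # first max in flat order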
