gather_nd
radenmuaz committed Jan 21, 2024
1 parent feffe1c commit 2e5486e
Showing 4 changed files with 92 additions and 65 deletions.
2 changes: 0 additions & 2 deletions examples/nn/mnist_mlp.py
@@ -10,8 +10,6 @@
from lib.datasets.mnist import get_mnist

import numpy as np


def loss_fn(model, batch):
x, y_onehot = batch
preds = model(x)
16 changes: 16 additions & 0 deletions examples/simple/symbolic_maths.py
@@ -0,0 +1,16 @@
import slope

x1 = slope.tensor((1,))
x2 = slope.tensor((2,))
y1 = x1+x2
with slope.symbolic_run():
sym_y = x1+x2
y2 = x1+x2
print(f"{y1=}")
print(f"{sym_y=}")
print(f"{y2=}")
# x1 = slope.symbolic_tensor((1,))
# x2 = slope.symbolic_tensor((1,))
# with slope.symbolic_run():
# y = x1+x2
# breakpoint()
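For context (not part of this diff): with the SymbolicRunTrace added in src/slope/core.py below, operations run inside slope.symbolic_run() are expected to only typecheck, producing SymbolicTensor metadata instead of concrete values. A minimal sketch of that assumption, reusing only calls that appear in this commit:

import slope

a = slope.tensor((1,))
b = slope.tensor((2,))
with slope.symbolic_run():
    sym = a + b            # dispatched to SymbolicRunTrace: op.typecheck only
print(type(sym).__name__)  # expected: SymbolicTensor (shape/dtype, no buffer)
print((a + b).numpy())     # outside the context, execution is eager again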
108 changes: 75 additions & 33 deletions src/slope/core.py
@@ -48,8 +48,8 @@ def dblog(*msg, enable=True):
def unzip2(pairs) -> Tuple[List[Any], List[Any]]:
lst1, lst2 = [], []
for i1, i2 in pairs:
- lst1.append(i1)
- lst2.append(i2)
+ lst1 += [i1]
+ lst2 += [i2]
return lst1, lst2


@@ -88,7 +88,7 @@ def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
lst2: List[Any] = []
lists = lst1, lst2
for b, x in list_zip(bs, l):
- lists[b].append(x)
+ lists[b] += [x]
return lst1, lst2


@@ -233,6 +233,10 @@ def __init__(self, val: TensorBuffer):
assert isinstance(val, TensorBuffer)
self.buf = val

@property
def symval(self):
return SymbolicTensor.like(self)

@property
def default_dtype(self):
return backend.default_dtype
@@ -350,7 +354,7 @@ def nbytes(self):

def __repr__(self):
return (
f"{self.numpy()}\n<Tensor: shape={self.shape}, dtype={self.dtype.name}, device={self.device.format_code}>"
f"<Tensor: shape={self.shape}, dtype={self.dtype.name}, device={self.device.format_code}, val=\n{self.numpy()}\n>"
)


@@ -360,6 +364,14 @@ def __init__(self, shape, dtype, device):
self._shape = tuple(int(i) for i in shape)
self._dtype = dtype
self._device = device

@property
def symval(self):
return self

@property
def val(self):
raise RuntimeError(f"this.val should not be accessed, as\n{trace_stack[-1]=}, ")

@property
def shape(self):
@@ -765,6 +777,16 @@ def tensor(
val = np.frombuffer(val, dtype=dtype)
return self.from_numpy(val, dtype, device)

def symbolic_tensor(
self,
shape: Union[list, tuple, np.ndarray, "TensorBuffer"] = None,
dtype: Optional[Any] = None,
device=None,
):
dtype = dtype or self.DEFAULT_DTYPE
device = device or self.DEFAULT_DEVICE
return SymbolicTensor(shape, dtype, device)

def seed(self, seed):
raise NotImplementedError

@@ -1307,6 +1329,15 @@ def fn(*args, **params):

return fn

class SymbolicRunTrace(Trace):
# pure = lambda self, x: x
def pure(self, val: Any) -> SymbolicTensor:
return val.symval

def run_op(self, op, tracers, params):
symvals_in = tree_map(lambda x: x.symval, tracers)
symvals_out = op.typecheck(*symvals_in, **params)
return symvals_out

class TraceTensor(Tensor):
PYTHON_TYPES = {
@@ -1443,9 +1474,6 @@ class ProgramTrace(Trace):
def builder(self):
return self.main.global_data

- def __init__(self, main: MainTrace) -> None:
-     self.main = main

def new_arg(self, symval) -> ProgramTraceTensor:
symval = SymbolicTensor.like(symval)
tracer = self.builder.new_tracer(self, symval)
@@ -1462,7 +1490,6 @@ def pure(self, val: Any) -> ProgramTraceTensor:
return tracer

def run_op(self, op, tracers, params):
- symvals_in = [t.symval for t in tracers]
symvals_in = tree_map(lambda x: x.symval, tracers)
symvals_out = op.typecheck(*symvals_in, **params)

@@ -1490,11 +1517,11 @@ def __init__(self):

def new_tracer(self, trace: ProgramTrace, symval: SymbolicTensor) -> ProgramTraceTensor:
tracer = ProgramTraceTensor(trace, symval)
- self.tracers.append(tracer)
+ self.tracers += [tracer]
return tracer

def add_instruction(self, instruction: Instruction) -> None:
- self.instructions.append(instruction)
+ self.instructions += [instruction]

def add_var(self, tracer: ProgramTraceTensor) -> Var:
assert id(tracer) not in self.tracer_to_var
@@ -1665,7 +1692,7 @@ def run_op(self, op, tracers, params):


trace_stack: List[MainTrace] = []
- dynamic_trace: Optional[MainTrace] = None
+ stashed_trace: Optional[MainTrace] = None
trace_stack += [MainTrace(0, RunTrace, None)]


@@ -1799,25 +1826,18 @@ def tree_map(f: Callable[..., Any], tree, *rest, out_leaf=False) -> Any:


@contextmanager
- def new_main(trace_type: Type["Trace"], global_data=None):
+ def new_main_trace(trace_type: Type["Trace"], global_data=None):
global trace_stack
level = len(trace_stack)
main = MainTrace(level, trace_type, global_data)
- trace_stack.append(main)
+ trace_stack += [main]

try:
yield main
finally:
trace_stack.pop()


- @contextmanager
- def new_dynamic(main: MainTrace):
- global dynamic_trace
- prev_dynamic_trace, dynamic_trace = dynamic_trace, main
- try:
- yield
- finally:
- dynamic_trace = prev_dynamic_trace


def bind(op, *args, **params):
@@ -1846,8 +1866,8 @@ def get_arr_from_seq(seq):
default=trace_stack[0],
key=operator_py.attrgetter("level"),
)
- if dynamic_trace and dynamic_trace.level > top_main.level:
- top_main = dynamic_trace
+ if stashed_trace and stashed_trace.level > top_main.level:
+ top_main = stashed_trace
return top_main.trace_type(top_main)


@@ -1936,7 +1956,7 @@ def vmap_flat(f, in_dim, out_dim, dim_size, *args):
dims = set([x.shape[d] for x, d in list_zip(args, in_dim) if d is not None])
assert len(dims) == 1
(dim_size,) = dims
- with new_main(VMapTrace, dim_size) as main:
+ with new_main_trace(VMapTrace, dim_size) as main:
trace = VMapTrace(main)
tracers_in = [VMapTraceTensor(trace, x, dim) if dim is not None else x for x, dim in list_zip(args, in_dim)]
outs = f(*tracers_in)
@@ -1969,7 +1989,7 @@ def batched_f(*args):


def jvp_flat(f, primals, tangents, *, has_aux, global_data, **static_args):
- with new_main(JVPTrace, global_data) as main:
+ with new_main_trace(JVPTrace, global_data) as main:
trace = JVPTrace(main)
tracers_in = [JVPTraceTensor(trace, x, t) for x, t in list_zip(primals, tangents)]
jvp_flat_ret = f(*tracers_in, **static_args)
@@ -2019,18 +2039,40 @@ def jacfwd(f, x):
return vmap(pushfwd, (0,))(vecs_in)



@contextmanager
def stash_trace(main: MainTrace):
global stashed_trace
prev_stashed_trace, stashed_trace = stashed_trace, main
try:
yield
finally:
stashed_trace = prev_stashed_trace


@contextmanager
def symbolic_run():
level = len(trace_stack)
main = MainTrace(level, SymbolicRunTrace, global_data=None)
trace_stack += [main]
global stashed_trace
prev_stashed_trace, stashed_trace = stashed_trace, main
try:
yield
finally:
stashed_trace = prev_stashed_trace
trace_stack.pop()

@lru_cache_verbose()
def make_program(f: Callable, *symvals_in: SymbolicTensor, static_args, name) -> Tuple[Program, List[Any], TreeDef]:
symvals_in, in_tree = tree_flatten(symvals_in)
f, out_tree_store = flatten_fn(f, in_tree)
builder = ProgramBuilder()
- with new_main(ProgramTrace, builder) as main:
- with new_dynamic(main):
+ with new_main_trace(ProgramTrace, builder) as main:
+ with stash_trace(main):
trace = ProgramTrace(main)
tracers_in = [trace.new_arg(symval) for symval in symvals_in]
outs = f(*tracers_in, **{k: v for k, v in static_args})
# tracers_out = [full_raise(trace, out) for out in outs]
# only full_raise outputs that are ProgramTraceTensor; aux outputs are plain values
tracers_out = [full_raise(trace, out) if isinstance(out, ProgramTraceTensor) else out.val for out in outs]
program, consts = builder.build(tracers_in, tracers_out, static_args, name)

@@ -2079,7 +2121,7 @@ def jvp_traceable(*primals_and_tangents):
def partial_run_flat(
f: Callable, pvals_in: List["PartialValue"], has_aux, global_data=None
) -> Tuple[Program, List["PartialValue"], List[Any]]:
- with new_main(PartialRunTrace, global_data) as main:
+ with new_main_trace(PartialRunTrace, global_data) as main:
trace = PartialRunTrace(main)
tracers_in = [trace.new_arg(pval) for pval in pvals_in]
outs = f(*tracers_in)
@@ -2273,7 +2315,7 @@ def draft_to_instruction(tracer_to_var: Dict[int, Var], draft: InstructionDraft)
tracer_to_var[id(t)] = var
elif isinstance(t.draft, InstructionDraft):
if id(t.draft) not in processed_instructions:
- instructions.append(draft_to_instruction(tracer_to_var, t.draft))
+ instructions += [draft_to_instruction(tracer_to_var, t.draft)]
processed_instructions.add(id(t.draft))
else:
raise TypeError(t.draft)
@@ -2317,10 +2359,10 @@ def remove_duplicates(lst):
childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
while childless_nodes:
node = childless_nodes.pop()
- sorted_nodes.append(node)
+ sorted_nodes += [node]
for parent in parents(node):
if child_counts[id(parent)] == 1:
- childless_nodes.append(parent)
+ childless_nodes += [parent]
else:
child_counts[id(parent)] -= 1

31 changes: 1 addition & 30 deletions src/slope/procedures.py
@@ -304,7 +304,6 @@ def normalize_int(e, i, dim_sz):
ret = sliced_tensor.reshape(tuple(final_shape))

if tensors: # Fancy/tensor indexing
# return x.gather_nd(val)
# normalize idx
idx = [t.sign().neg().relu() * ret.shape[d] + t for d, t in zip(dim, tensors)]
max_dim = max(i.ndim for i in idx)
@@ -370,42 +369,14 @@ def flatten_seq(l: Iterator):
return x


- # @procedure_set.register()
- # def padslice(x, pads: Sequence[Optional[Tuple[int, int]]], value: float = 0):
- # pads = tuple(a if a is not None else (0, s) for s, a in zip(x.shape, pads))
- # pads1 = tuple((max(0, -p[0]), max(0, p[1] - x.shape[i])) for i, p in enumerate(pads))
- # pads2 = tuple(item for sublist in pads1 for item in sublist)[::-1]
- # x = x.pad(pads1, value=value) # flatten
-
- # starts, limits, strides = tuple(
- # zip(*[(p[0] + p1[0], p[1] + p1[0], 1) for (p, p1) in zip(pads, pads1)])
- # )
- # x = x.slice(starts, limits, strides)
- # return x
- # # starts, limits, strides = tuple(
- # # zip(*[(p[0] + p_[i][0], p[1] + p_[i][0], 1) for i, p in enumerate(arg)])
- # # )


@procedure_set.register()
def pad2d(x, padding: Union[List[int], Tuple[int, ...]], value: float = 0):
# (padding_left, padding_right, padding_top, padding_bottom)
slc = [(-p0, s + p1) for p0, p1, s in zip(padding[::2], padding[1::2], x.shape[::-1])][::-1]
return x.padslice([(0, s) for s in x.shape[: -(len(padding) // 2)]] + slc, value=value)


- @procedure_set.register()
- def gather(x, dim, idx):
- if dim != 0:
- x, idx = x.transpose(0, idx), idx.transpose(0, idx)
- ret = x.gather_nd(idx, batch_dims=0)
- if dim != 0:
- ret = ret.transpose(0, idx)
- return ret


@procedure_set.register()
- def gather_arange(x, idx, dim: int):
+ def gather(x, idx, dim: int):
assert idx.ndim == x.ndim, "x.ndim must equal idx.ndim"
assert all(s >= i for s, i in zip(x.shape, idx.shape)), "all dim of idx.shape must be smaller than x.shape"
if dim < 0:

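Aside (not part of this diff): after this rename, the registered gather procedure takes (x, idx, dim) with idx.ndim equal to x.ndim, a torch.gather-style lookup along dim. A hypothetical call sketch, assuming registered procedures are exposed as Tensor methods (as other procedures in this file are called) and that the nested-list constructor and integer index handling work as written; the values are made up:

import slope

x = slope.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])   # shape (2, 3)
idx = slope.tensor([[2, 0, 1],
                    [0, 1, 0]])       # integer indices, same ndim as x
out = x.gather(idx, dim=1)            # expected: out[i, j] == x[i, idx[i, j]]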