simplify scan further
dpanici committed Sep 11, 2024
1 parent f088a06 commit 7c804a1
Showing 2 changed files with 40 additions and 125 deletions.
8 changes: 4 additions & 4 deletions desc/objectives/objective_funs.py
@@ -559,12 +559,12 @@ def _jvp(self, v, x, constants=None, op="compute_scaled"):
        if len(v) == 1:
            jvpfun = lambda dx: Derivative.compute_jvp(fun, 0, dx, x)
            return batched_vectorize(
-                jvpfun, signature="(n)->(k)", jac_chunk_size=self._jac_chunk_size
+                jvpfun, signature="(n)->(k)", chunk_size=self._jac_chunk_size
            )(v[0])
        elif len(v) == 2:
            jvpfun = lambda dx1, dx2: Derivative.compute_jvp2(fun, 0, 0, dx1, dx2, x)
            return batched_vectorize(
-                jvpfun, signature="(n),(n)->(k)", jac_chunk_size=self._jac_chunk_size
+                jvpfun, signature="(n),(n)->(k)", chunk_size=self._jac_chunk_size
            )(v[0], v[1])
        elif len(v) == 3:
            jvpfun = lambda dx1, dx2, dx3: Derivative.compute_jvp3(
@@ -573,7 +573,7 @@ def _jvp(self, v, x, constants=None, op="compute_scaled"):
            return batched_vectorize(
                jvpfun,
                signature="(n),(n),(n)->(k)",
-                jac_chunk_size=self._jac_chunk_size,
+                chunk_size=self._jac_chunk_size,
            )(v[0], v[1], v[2])
        else:
            raise NotImplementedError("Cannot compute JVP higher than 3rd order.")
@@ -1164,7 +1164,7 @@ def _jvp(self, v, x, constants=None, op="compute_scaled"):
        jvpfun = lambda *dx: Derivative.compute_jvp(fun, tuple(range(len(x))), dx, *x)
        sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)"
        return batched_vectorize(
-            jvpfun, signature=sig, jac_chunk_size=self._jac_chunk_size
+            jvpfun, signature=sig, chunk_size=self._jac_chunk_size
        )(*v)

    @jit
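For reference, a minimal sketch of the call pattern this file exercises. This is hypothetical, not repository code: a toy function and jax.jvp stand in for DESC's objective and Derivative.compute_jvp, and only the renamed chunk_size keyword of batched_vectorize comes from the diff.

    import jax
    import jax.numpy as jnp

    from desc.utils_batched_vectorize import batched_vectorize

    f = lambda x: jnp.array([jnp.sum(x**2), jnp.prod(x)])  # R^n -> R^k stand-in

    x0 = jnp.ones(3)
    jvpfun = lambda dx: jax.jvp(f, (x0,), (dx,))[1]  # J(x0) @ dx

    V = jnp.eye(3)  # three tangent vectors, one per row
    # at most 2 rows are pushed through vmap at a time;
    # row i of the result is J @ e_i, i.e. column i of the Jacobian
    J_cols = batched_vectorize(jvpfun, signature="(n)->(k)", chunk_size=2)(V)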
157 changes: 36 additions & 121 deletions desc/utils_batched_vectorize.py
@@ -38,29 +38,29 @@ def _unchunk(x):


@_treeify
-def _chunk(x, jac_chunk_size=None):
-    # jac_chunk_size=None -> add just a dummy chunk dimension,
+def _chunk(x, chunk_size=None):
+    # chunk_size=None -> add just a dummy chunk dimension,
    # same as np.expand_dims(x, 0)
    if x.ndim == 0:
        raise ValueError("x cannot be chunked as it has 0 dimensions.")
    n = x.shape[0]
-    if jac_chunk_size is None:
-        jac_chunk_size = n
+    if chunk_size is None:
+        chunk_size = n

-    n_chunks, residual = divmod(n, jac_chunk_size)
+    n_chunks, residual = divmod(n, chunk_size)
    if residual != 0:
        raise ValueError(
-            "The first dimension of x must be divisible by jac_chunk_size."
-            + f"\n Got x.shape={x.shape} but jac_chunk_size={jac_chunk_size}."
+            "The first dimension of x must be divisible by chunk_size."
+            + f"\n Got x.shape={x.shape} but chunk_size={chunk_size}."
        )
-    return x.reshape((n_chunks, jac_chunk_size) + x.shape[1:])
+    return x.reshape((n_chunks, chunk_size) + x.shape[1:])


-def _jac_chunk_size(x):
+def _chunk_size(x):
    b = set(map(lambda x: x.shape[:2], jax.tree_util.tree_leaves(x)))
    if len(b) != 1:
        raise ValueError(
-            "The arrays in x have inconsistent jac_chunk_size or number of chunks"
+            "The arrays in x have inconsistent chunk_size or number of chunks"
        )
    return b.pop()[1]
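A quick sketch of what the two private helpers above do, assuming the module path desc.utils_batched_vectorize:

    import jax.numpy as jnp

    from desc.utils_batched_vectorize import _chunk, _chunk_size

    x = jnp.arange(12.0).reshape(6, 2)
    xc = _chunk(x, chunk_size=3)   # shape (2, 3, 2): (n_chunks, chunk_size, ...)
    assert _chunk_size(xc) == 3    # recovered from the second axis of the leaves
    # _chunk(x, 4) would raise ValueError: 6 is not divisible by chunk_size=4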
@@ -80,27 +80,27 @@ def unchunk(x_chunked):
    """
    return _unchunk(x_chunked), functools.partial(
-        _chunk, jac_chunk_size=_jac_chunk_size(x_chunked)
+        _chunk, chunk_size=_chunk_size(x_chunked)
    )


-def chunk(x, jac_chunk_size=None):
+def chunk(x, chunk_size=None):
    """Split an array (or a pytree of arrays) into chunks along the first axis.

    Parameters
    ----------
    x: an array (or pytree of arrays)
-    jac_chunk_size: an integer or None (default)
-        The first axis in x must be a multiple of jac_chunk_size
+    chunk_size: an integer or None (default)
+        The first axis in x must be a multiple of chunk_size

    Returns
    -------
    (x_chunked, unchunk_fn): tuple
-        - x_chunked is x reshaped to (-1, jac_chunk_size)+x.shape[1:]
-          if jac_chunk_size is None then it defaults to x.shape[0], i.e. just one chunk
+        - x_chunked is x reshaped to (-1, chunk_size)+x.shape[1:]
+          if chunk_size is None then it defaults to x.shape[0], i.e. just one chunk
        - unchunk_fn is a function which restores x given x_chunked
    """
-    return _chunk(x, jac_chunk_size), _unchunk
+    return _chunk(x, chunk_size), _unchunk


####
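A round-trip sketch for the public chunk/unchunk pair (same module-path assumption; not part of the diff):

    import jax.numpy as jnp

    from desc.utils_batched_vectorize import chunk, unchunk

    x = jnp.arange(12.0).reshape(6, 2)
    x_chunked, unchunk_fn = chunk(x, chunk_size=3)   # shape (2, 3, 2)
    assert (unchunk_fn(x_chunked) == x).all()        # back to (6, 2)

    x_flat, chunk_fn = unchunk(x_chunked)            # (6, 2), plus a re-chunker
    assert chunk_fn(x_flat).shape == (2, 3, 2)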
@@ -113,115 +113,32 @@ def chunk(x, jac_chunk_size=None):
# Copyright 2021 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");

-_tree_add = functools.partial(jax.tree_util.tree_map, jax.lax.add)
-_tree_zeros_like = functools.partial(
-    jax.tree_util.tree_map, lambda x: jnp.zeros(x.shape, dtype=x.dtype)
-)


# TODO put it somewhere
def _multimap(f, *args):
    try:
        return tuple(map(lambda a: f(*a), zip(*args)))
    except TypeError:
        return f(*args)


def scan_append(f, x):
-    """Evaluate f element by element in x while appending and/or reducing the results.
+    """Evaluate f element by element in x while appending the results.

    Parameters
    ----------
    f: a function that takes elements of the leading dimension of x
    x: a pytree where each leaf array has the same leading dimension
-    append_cond: a bool (if f returns just one result) or a tuple of
-        bools (if f returns multiple values)
-        which indicates whether the individual result should
-        be appended or reduced
-    op: a function to (pairwise) reduce the specified results. Defaults to a sum.
-    zero_fun: a function which prepares the zero element of op for a given input
-        shape/dtype tree. Defaults to zeros.

    Returns
    -------
-    The (tuple of) results corresponding to the output of f
-    where each result is given by:
-    if append_cond is True:
-        a (pytree of) array(s) with leading dimension same as x,
-        containing the evaluation of f at each element in x
-    else (append_cond is False):
-        a (pytree of) array(s) with the same shape as the corresponding
-        output of f, containing the reduction over op of f evaluated at each x
-
-    Example:
-        import jax.numpy as jnp
-        from netket.jax import scan_append_reduce
-
-        def f(x):
-            y = jnp.sin(x)
-            return y, y, y**2
-
-        N = 100
-        x = jnp.linspace(0.,jnp.pi,N)
-        y, s, s2 = scan_append_reduce(f, x, (True, False, False))
-        mean = s/N
-        var = s2/N - mean**2
+    a (pytree of) array(s) with leading dimension same as x,
+    containing the evaluation of f at each element in x
    """
-    # TODO: different op for each result
-
    x0 = jax.tree_util.tree_map(lambda x: x[0], x)

    # special code path if there is only one element
    # to avoid having to rely on xla/llvm to optimize the overhead away
    if jax.tree_util.tree_leaves(x)[0].shape[0] == 1:
        return _multimap(lambda c, x: jnp.expand_dims(x, 0) if c else x, True, f(x0))

-    # the original idea was to use pytrees,
-    # however for now just operate on the return value tuple
-    _get_append_part = functools.partial(_multimap, lambda c, x: x if c else None, True)
-
    carry_init = True

    def f_(carry, x):
-        y = f(x)
-        y_append = _get_append_part(y)
-        return False, y_append
+        return False, f(x)

    _, res_append = jax.lax.scan(f_, carry_init, x, unroll=1)
-    # reconstruct the result from the reduced and appended parts in the two trees
-    return res_append  # _tree_select(res_append, res_op)
+    return res_append
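With the reduce machinery gone, scan_append simply evaluates f sequentially under jax.lax.scan and stacks the per-element results. A usage sketch, again assuming the module path:

    import jax.numpy as jnp

    from desc.utils_batched_vectorize import scan_append

    f = lambda x: (jnp.sin(x), jnp.cos(x) ** 2)
    x = jnp.linspace(0.0, jnp.pi, 100)

    y, c2 = scan_append(f, x)   # each of shape (100,); like vmap, but sequential,
                                # trading speed for lower peak memory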


# TODO in_axes a la vmap?
def _scanmap(fun, scan_fun, argnums=0):
-    """A helper function to wrap f with a scan_fun.
-
-    Example
-    -------
-    import jax.numpy as jnp
-    from functools import partial
-
-    from desc.utils import _scanmap, scan_append_reduce
-
-    scan_fun = partial(scan_append_reduce, append_cond=(True, False, False))
-
-    @partial(_scanmap, scan_fun=scan_fun, argnums=1)
-    def f(c, x):
-        y = jnp.sin(x) + c
-        return y, y, y**2
-
-    N = 100
-    x = jnp.linspace(0.,jnp.pi,N)
-    c = 1.
-
-    y, s, s2 = f(c, x)
-    mean = s/N
-    var = s2/N - mean**2
-    """
+    """A helper function to wrap f with a scan_fun."""

    def f_(*args, **kwargs):
        f = lu.wrap_init(fun, kwargs)
@@ -242,19 +159,19 @@ def f_(*args, **kwargs):
# Licensed under the Apache License, Version 2.0 (the "License");


-def _eval_fun_in_chunks(vmapped_fun, jac_chunk_size, argnums, *args, **kwargs):
+def _eval_fun_in_chunks(vmapped_fun, chunk_size, argnums, *args, **kwargs):
    n_elements = jax.tree_util.tree_leaves(args[argnums[0]])[0].shape[0]
-    n_chunks, n_rest = divmod(n_elements, jac_chunk_size)
+    n_chunks, n_rest = divmod(n_elements, chunk_size)

-    if n_chunks == 0 or jac_chunk_size >= n_elements:
+    if n_chunks == 0 or chunk_size >= n_elements:
        y = vmapped_fun(*args, **kwargs)
    else:
        # split inputs
        def _get_chunks(x):
            x_chunks = jax.tree_util.tree_map(
                lambda x_: x_[: n_elements - n_rest, ...], x
            )
-            x_chunks = _chunk(x_chunks, jac_chunk_size)
+            x_chunks = _chunk(x_chunks, chunk_size)
            return x_chunks

        def _get_rest(x):
@@ -284,16 +201,16 @@ def _get_rest(x):

def _chunk_vmapped_function(
    vmapped_fun: Callable,
-    jac_chunk_size: Optional[int],
+    chunk_size: Optional[int],
    argnums=0,
) -> Callable:
    """Takes a vmapped function and computes it in chunks."""
-    if jac_chunk_size is None:
+    if chunk_size is None:
        return vmapped_fun

    if isinstance(argnums, int):
        argnums = (argnums,)
-    return functools.partial(_eval_fun_in_chunks, vmapped_fun, jac_chunk_size, argnums)
+    return functools.partial(_eval_fun_in_chunks, vmapped_fun, chunk_size, argnums)


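Because the leftover rows are evaluated separately, chunk_size need not divide the batch size. A sketch of the private wrapper in isolation (module path assumed):

    import jax
    import jax.numpy as jnp

    from desc.utils_batched_vectorize import _chunk_vmapped_function

    vf = jax.vmap(lambda x: jnp.sum(x**2))            # maps over rows
    vf_chunked = _chunk_vmapped_function(vf, chunk_size=4)

    out = vf_chunked(jnp.ones((10, 3)))               # 10 rows = 2 chunks of 4 + rest of 2
    assert out.shape == (10,)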
def _parse_in_axes(in_axes):
@@ -313,15 +230,15 @@ def vmap_chunked(
    f: Callable,
    in_axes=0,
    *,
-    jac_chunk_size: Optional[int],
+    chunk_size: Optional[int],
) -> Callable:
    """Behaves like jax.vmap but uses scan to chunk the computations in smaller chunks.

    Parameters
    ----------
    f: The function to be vectorised.
    in_axes: The axes that should be scanned along. Only supports `0` or `None`
-    jac_chunk_size: The maximum size of the chunks to be used. If it is `None`,
+    chunk_size: The maximum size of the chunks to be used. If it is `None`,
        chunking is disabled
@@ -331,12 +248,10 @@
"""
in_axes, argnums = _parse_in_axes(in_axes)
vmapped_fun = jax.vmap(f, in_axes=in_axes)
return _chunk_vmapped_function(vmapped_fun, jac_chunk_size, argnums)
return _chunk_vmapped_function(vmapped_fun, chunk_size, argnums)


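A usage sketch for vmap_chunked (assumed import path), illustrating the in_axes restriction to 0 or None:

    import jax.numpy as jnp

    from desc.utils_batched_vectorize import vmap_chunked

    # scan over w (axis 0), broadcast b to every element, 3 rows per chunk
    f = lambda w, b: jnp.dot(w, b)
    f_batched = vmap_chunked(f, in_axes=(0, None), chunk_size=3)

    out = f_batched(jnp.ones((7, 5)), jnp.arange(5.0))   # shape (7,)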
-def batched_vectorize(
-    pyfunc, *, excluded=frozenset(), signature=None, jac_chunk_size=None
-):
+def batched_vectorize(pyfunc, *, excluded=frozenset(), signature=None, chunk_size=None):
"""Define a vectorized function with broadcasting and batching.
below is taken from JAX
@@ -366,7 +281,7 @@
        provided, ``pyfunc`` will be called with (and expected to return) arrays
        with shapes given by the size of corresponding core dimensions. By
        default, pyfunc is assumed to take scalars arrays as input and output.
-    jac_chunk_size: the size of the batches to pass to vmap. if 1, will only
+    chunk_size: the size of the batches to pass to vmap. if 1, will only

    Returns
    -------
@@ -448,7 +363,7 @@ def wrapped(*args, **kwargs):
        else:
            # change the vmap here to chunked_vmap
            vectorized_func = vmap_chunked(
-                vectorized_func, in_axes, jac_chunk_size=jac_chunk_size
+                vectorized_func, in_axes, chunk_size=chunk_size
            )
        result = vectorized_func(*squeezed_args)

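Taken together, batched_vectorize behaves like jnp.vectorize with an optional cap on how many broadcast elements are vmapped at once. A closing sketch, again hypothetical rather than repository code:

    import jax.numpy as jnp

    from desc.utils_batched_vectorize import batched_vectorize

    # batched matrix-vector product, at most 8 of the 20 pairs vmapped per chunk
    mv = batched_vectorize(
        lambda a, b: a @ b, signature="(m,n),(n)->(m)", chunk_size=8
    )

    A = jnp.ones((20, 4, 3))
    b = jnp.ones((20, 3))
    out = mv(A, b)              # shape (20, 4)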