Skip to content

Commit

Permalink
Merge pull request #19819 from mattjj:scan-dont-traverse-body-jaxpr-i…
Browse files Browse the repository at this point in the history
…n-lowering

PiperOrigin-RevId: 607570860
  • Loading branch information
jax authors committed Feb 16, 2024
2 parents ff3247e + 5ead7a6 commit 330afdc
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 146 deletions.
4 changes: 3 additions & 1 deletion jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2381,6 +2381,8 @@ def get_bind_params(self, params):

closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
closed_call_p.def_impl(call_impl)
closed_call_p.def_effectful_abstract_eval(
lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects))


outfeed_primitives: set[Primitive] = set()
Expand Down Expand Up @@ -2788,7 +2790,7 @@ class JaxprTypeError(TypeError): pass

# Typecheck rule for a closed-call equation: compares the avals of the equation's
# input atoms with call_jaxpr.in_avals, raises JaxprTypeError on a mismatch, and
# otherwise returns (call_jaxpr.out_avals, call_jaxpr.effects).
# NOTE(review): this is a flattened diff view -- indentation is lost and the two
# consecutive `if` lines below look like the removed (list equality) and the added
# (typecompat-based) form of the same check; confirm against the repository.
def _check_closed_call(_, *in_atoms, call_jaxpr):
in_avals = [x.aval for x in in_atoms]
if list(in_avals) != list(call_jaxpr.in_avals):
if not all(map(typecompat, call_jaxpr.in_avals, in_avals)):
raise JaxprTypeError("Closed call in_avals mismatch")
return call_jaxpr.out_avals, call_jaxpr.effects
custom_typechecks[closed_call_p] = _check_closed_call
Expand Down
7 changes: 5 additions & 2 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,7 +1403,10 @@ def aval_to_types(aval):
args.append([hlo.create_token()])
else:
args.append(arg)
callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
if name is not None:
callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
else:
callee_name_stack = ctx.name_stack
consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
out_vals, tokens_out = jaxpr_subcomp(
ctx.replace(name_stack=callee_name_stack), jaxpr.jaxpr, tokens_in,
Expand Down Expand Up @@ -1862,7 +1865,7 @@ def core_call_lowering(ctx: LoweringRuleContext,

register_lowering(core.call_p, partial(core_call_lowering, name="core_call"))
register_lowering(core.closed_call_p,
partial(core_call_lowering, name="core_closed_call"))
partial(core_call_lowering, name=None))

def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
broadcast_dimensions) -> ir.Value:
Expand Down
177 changes: 36 additions & 141 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,142 +361,39 @@ def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
'the shapes do not match' * shape_mismatch)
return ''


def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
                        f_impl, x_avals, y_avals):
  """Run a scan as `length` explicit Python-level calls of `f_impl`.

  Args:
    *args: flat operands, laid out as `num_consts` constants, then `num_carry`
      initial carry values, then the scanned-over inputs `xs`.
    reverse: if true, step `i` reads index `length - i - 1` and the collected
      per-step outputs are put back into index order before stacking.
    length: number of steps to run.
    num_consts: number of leading constant operands in `args`.
    num_carry: number of carry values in `args` and in `f_impl`'s outputs.
    linear: accepted for signature parity with the other scan impls; not read
      here.
    f_impl: callable invoked as `f_impl(*consts, *carry, *x)`; its first
      `num_carry` outputs become the next carry, the rest are per-step outputs.
    x_avals: one aval per entry of `xs`, passed to `_index_array`.
    y_avals: one aval per per-step output, passed to `_stack`.

  Returns:
    A tuple of the final carry values followed by the stacked outputs.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])

  carry = init
  ys = []

  for i in range(length):
    i_ = length - i - 1 if reverse else i
    x = _map(partial(_index_array, i_), x_avals, xs)
    out = f_impl(*consts, *carry, *x)
    carry, y = split_list(out, [num_carry])
    ys.append(y)

  # Outputs were appended in execution order; restore index order.
  if reverse:
    ys.reverse()
  # Transpose from per-step rows to per-output columns, then stack each column.
  ys = list(zip(*ys))
  ys = _map(_stack, y_avals, ys)
  return (*carry, *ys)

def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
                    f_impl, x_avals, y_avals):
  """Run a scan as a single `while_loop` with a step counter.

  The loop state is `[i] + carry + ys`. Each trip reads index `i_` of every
  entry of `xs` with `_dynamic_index_array`, calls
  `f_impl(*consts, *carry, *x)`, and writes the per-step outputs into the `ys`
  buffers with `_update_array`. With `reverse`, trip `i` uses index
  `length - i - 1`.

  Args:
    *args: flat operands, laid out as `num_consts` constants, then `num_carry`
      initial carry values, then the scanned-over inputs `xs`.
    reverse: whether to visit indices from last to first.
    length: number of trips.
    num_consts: number of leading constant operands in `args`.
    num_carry: number of carry values.
    linear: accepted for signature parity with the other scan impls; not read
      here.
    f_impl: callable implementing one scan step.
    x_avals: one aval per entry of `xs`.
    y_avals: one aval per per-step output; used to create the output buffers.

  Returns:
    The carry values followed by the output buffers: `init + ys_init` when
    `length == 0`, otherwise the final loop state without the counter.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])

  def cond_fun(vals):
    i, *_ = vals
    return i < length

  def body_fun(vals):
    [i], carry, ys = split_list(vals, [1, num_carry])
    i_ = length - i - 1 if reverse else i
    # TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right,
    # because the scan body may consume any keys within it.
    # Import here to avoid circular imports
    from jax.experimental import key_reuse
    xs_unconsumed = _map(key_reuse.reuse_key, xs)
    x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed)
    out_flat = f_impl(*consts, *carry, *x)
    carry_out, y_updates = split_list(out_flat, [num_carry])
    ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
    return [i + 1] + carry_out + ys_out

  # TODO(jakevdp)[key-reuse]: mark xs consumed here if f_impl consumes them.

  ys_init = _map(partial(_empty_array, length), y_avals)
  if length == 0:
    # Nothing to iterate: skip building the while_loop entirely.
    return init + ys_init
  else:
    init_val = [lax._const(length, 0)] + init + ys_init
    _, *outs = while_loop(cond_fun, body_fun, init_val)
    return outs

def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
                              linear, block_length, f_impl, x_avals, y_avals):
  """Run a scan as a loop over blocks, each block unrolled `block_length` times.

  Every entry of `xs` is reshaped by `_partition_leading` so its leading axis
  becomes `(num_blocks, block_length)`. `_scan_impl_loop` then iterates over
  the `num_blocks` axis, using `_scan_impl_unrolled` (with `length` set to
  `block_length`) as the per-trip body. The blocked outputs are merged back
  into a single leading axis with `_combine_leading`.

  Args:
    *args: flat operands, laid out as `num_consts` constants, then `num_carry`
      initial carry values, then the scanned-over inputs `xs`.
    reverse: forwarded to both the outer loop and the unrolled inner body.
    length: total number of steps; must be a multiple of `block_length`.
    num_consts: number of leading constant operands in `args`.
    num_carry: number of carry values.
    linear: forwarded unchanged to the inner implementations.
    block_length: number of steps unrolled inside each loop trip.
    f_impl: callable implementing one scan step.
    x_avals: one aval per entry of `xs` (per-step, i.e. unblocked).
    y_avals: one aval per per-step output (unblocked).

  Returns:
    A tuple of the final carry values followed by the recombined outputs.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])

  num_blocks, rem = divmod(length, block_length)
  assert rem == 0

  partition = partial(_partition_leading, num_blocks, block_length)
  xs_block = _map(partition, x_avals, xs)

  # The outer loop sees one whole block per trip, so its avals gain a leading
  # dimension of size block_length.
  prepend_aval = partial(_prepend_dim_to_aval, block_length)
  x_block_avals = _map(prepend_aval, x_avals)
  y_block_avals = _map(prepend_aval, y_avals)

  # Inner body: an unrolled scan over a single block, using per-step avals.
  f_impl_block = partial(
      _scan_impl_unrolled, reverse=reverse, length=block_length,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

  outs = _scan_impl_loop(
      *consts, *init, *xs_block, reverse=reverse, length=num_blocks,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals)

  carry, ys_blocks = split_list(outs, [num_carry])
  combine = partial(_combine_leading, num_blocks, block_length)
  ys = _map(combine, y_avals, ys_blocks)
  return (*carry, *ys)

# Entry point that picks a scan strategy from `unroll`.
# NOTE(review): this is a flattened diff view -- indentation is lost, and the span
# contains two `split_list(args, ...)` lines (one binding `carry`, one binding
# `init`), so it presumably mixes lines from both sides of the diff; confirm against
# the repository before treating it as one function body.
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
unroll):
from jax.experimental import key_reuse # TODO(jakevdp)[key-reuse]
consts, carry, xs = split_list(args, [num_consts, num_carry])
# Per-step input and output avals are read off the body jaxpr's signature.
_, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
f_impl = core.jaxpr_as_fun(jaxpr)

# unroll == 1: no unrolling, delegate to the plain while-loop implementation.
if unroll == 1:
return _scan_impl_loop(
*args, reverse=reverse, length=length, num_consts=num_consts,
num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals,
y_avals=y_avals)

consts, init, xs = split_list(args, [num_consts, num_carry])
# length = num_blocks * unroll + rem
num_blocks, rem = divmod(length, unroll)
length_div = num_blocks * unroll

# Split off the `rem` leftover rows of xs: the first `rem` rows when reverse,
# otherwise the rows after the first `length_div`.
if rem > 0:
if reverse:
split = partial(_split_leading_dim, rem)
xs_rem, xs = unzip2(_map(split, x_avals, xs))
else:
split = partial(_split_leading_dim, length_div)
xs, xs_rem = unzip2(_map(split, x_avals, xs))

# Block-unrolled loop over the part of xs whose length is a multiple of unroll.
outs = _scan_impl_block_unrolled(
*consts, *init, *xs, reverse=reverse, length=length_div,
num_consts=num_consts, num_carry=num_carry, linear=linear,
block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

carry, ys = split_list(outs, [num_carry])

# Leftover steps run fully unrolled, starting from the carry produced above;
# their outputs are concatenated before (reverse) or after the blocked outputs.
if rem > 0:
outs = _scan_impl_unrolled(
*consts, *carry, *xs_rem, reverse=reverse, length=rem,
num_consts=num_consts, num_carry=num_carry, linear=linear,
f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
carry, ys_rem = split_list(outs, [num_carry])
if reverse:
ys = _map(_concatenate, y_avals, ys_rem, ys)
else:
ys = _map(_concatenate, y_avals, ys, ys_rem)

return (*carry, *ys)

def _stack(aval, vals):
  """Stack `vals` along a new leading axis.

  Args:
    aval: unused; present so the function can be mapped over (aval, values)
      pairs.
    vals: sequence of arrays to stack.

  Returns:
    The concatenation along axis 0 of every value after a size-1 leading axis
    has been added to it.
  """
  vals = [lax.expand_dims(x, (0,)) for x in vals]
  return lax.concatenate(vals, 0)

def _concatenate(aval, x1, x2):
  """Concatenate `x1` and `x2` along axis 0.

  Args:
    aval: unused; present so the function can be mapped over (aval, x1, x2)
      triples.
    x1: array placed first.
    x2: array placed second.

  Returns:
    `lax.concatenate([x1, x2], 0)`.
  """
  return lax.concatenate([x1, x2], 0)
ys = _map(partial(_empty_array, length), y_avals)
num_trips, remainder = divmod(length, unroll)

def cond_fun(while_carry):
i, _, _ = while_carry
return i < num_trips * unroll
def body_fun(while_carry):
i, carry, ys = while_carry
for _ in range(unroll):
i, carry, ys = _step(i, carry, ys)
return i, carry, ys
def _step(i, carry, ys):
i_ = length - i - 1 if reverse else i
# TODO(jakevdp)[key-reuse]: logic not right, scan may consume keys within
xs_unconsumed = _map(key_reuse.reuse_key, xs)
x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed)
out = eval_jaxpr_p.bind(*consts, *carry, *x, jaxpr=jaxpr)
carry, y = split_list(out, [num_carry])
ys = _map(partial(_update_array, i_), y_avals, ys, y)
return i + 1, carry, ys

def _split_leading_dim(i, aval, x):
  """Split `x` along axis 0 at index `i`.

  Args:
    i: split point; the first piece holds indices `[0, i)`.
    aval: unused; present so the function can be mapped over (aval, x) pairs.
    x: array with at least one dimension.

  Returns:
    A pair `(x[:i], x[i:])` built with `slicing.slice_in_dim`.
  """
  assert x.ndim >= 1
  return (slicing.slice_in_dim(x, 0, i),
          slicing.slice_in_dim(x, i, x.shape[0]))
i = lax._const(length, 0)
if num_trips:
i, carry, ys = jax.lax.while_loop(cond_fun, body_fun, (i, carry, ys))
for _ in range(remainder):
i, carry, ys = _step(i, carry, ys)
return [*carry, *ys]

def _dynamic_index_array(i, aval, x):
  """Return the slice of `x` at index `i`, with the indexed axis dropped.

  Args:
    i: index to read.
    aval: unused; present so the function can be mapped over (aval, x) pairs.
    x: array to index.

  Returns:
    `slicing.dynamic_index_in_dim(x, i, keepdims=False)`.
  """
  return slicing.dynamic_index_in_dim(x, i, keepdims=False)
Expand All @@ -510,16 +407,14 @@ def _empty_array(sz, aval):
def _update_array(i, aval, xs, x):
  """Return `xs` with the slice at index `i` of axis 0 replaced by `x`.

  Args:
    i: index along axis 0 to overwrite.
    aval: unused; present so the function can be mapped over
      (aval, xs, x) triples.
    xs: array being updated.
    x: replacement slice.

  Returns:
    `slicing.dynamic_update_index_in_dim(xs, x, i, 0)`.
  """
  return slicing.dynamic_update_index_in_dim(xs, x, i, 0)

def _partition_leading(sz0, sz1, aval, x):
  """Reshape the leading axis of `x` from `sz0 * sz1` into `(sz0, sz1)`.

  Args:
    sz0: size of the new outer axis.
    sz1: size of the new inner axis.
    aval: unused; present so the function can be mapped over (aval, x) pairs.
    x: array whose leading dimension equals `sz0 * sz1`.

  Returns:
    `x` reshaped to `(sz0, sz1, *x.shape[1:])`.
  """
  assert x.ndim >= 1
  assert x.shape[0] == sz0 * sz1
  return lax.reshape(x, (sz0, sz1, *x.shape[1:]))

def _combine_leading(sz0, sz1, aval, x):
  """Merge the two leading axes of `x`, of sizes `(sz0, sz1)`, into one.

  Args:
    sz0: expected size of axis 0.
    sz1: expected size of axis 1.
    aval: unused; present so the function can be mapped over (aval, x) pairs.
    x: array with at least two dimensions and shape `(sz0, sz1, ...)`.

  Returns:
    `lax.collapse(x, 0, 2)`, i.e. `x` with axes 0 and 1 collapsed into a
    single axis.
  """
  assert x.ndim >= 2
  assert x.shape[0] == sz0
  assert x.shape[1] == sz1
  return lax.collapse(x, 0, 2)
eval_jaxpr_p = core.Primitive('eval_jaxpr')
eval_jaxpr_p.multiple_results = True
def _stage_jaxpr(trace, *tracers, jaxpr):
  """Stage an `eval_jaxpr` application as a `core.closed_call_p` equation.

  Args:
    trace: the trace doing the staging; must provide
      `default_process_primitive`.
    *tracers: the operands of the call.
    jaxpr: the jaxpr to call; passed on as the `call_jaxpr` parameter.

  Returns:
    Whatever `trace.default_process_primitive` returns for the closed call.
  """
  params = dict(call_jaxpr=jaxpr)
  return trace.default_process_primitive(core.closed_call_p, tracers, params)
pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr
@eval_jaxpr_p.def_effectful_abstract_eval  # abstract eval only used for jax2tf
def _stage_jaxpr_abstract_eval(*_, jaxpr):
  """Abstract-eval rule for `eval_jaxpr_p`.

  Ignores the operand avals and reports the output avals and effects recorded
  on `jaxpr`.
  """
  out_avals = jaxpr.out_avals
  effects = jaxpr.effects
  return out_avals, effects

def _prepend_dim_to_aval(sz, aval):
  """Return `aval` with a new dimension of size `sz` at axis 0.

  Thin wrapper over `core.unmapped_aval(sz, core.no_axis_name, 0, aval)`.
  """
  return core.unmapped_aval(sz, core.no_axis_name, 0, aval)
Expand Down
6 changes: 4 additions & 2 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,11 +1464,13 @@ def get_primitive_impl(self, p: core.Primitive) -> tuple[Callable, bool]:
def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
  """Fail loudly when a primitive that should never reach this point shows up.

  Args:
    p: the offending primitive; interpolated into the error message.
    *args: ignored.
    **kwargs: ignored.

  Raises:
    AssertionError: always.
  """
  # Raise explicitly instead of `assert False`, which is stripped under
  # `python -O` and would let the unexpected primitive through silently.
  raise AssertionError(f"Encountered unexpected primitive {p}")


# Call primitives are inlined
for unexpected in [core.call_p, maps.xmap_p]:
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)

tf_impl[lax_control_flow.loops.eval_jaxpr_p] = \
lambda *args, jaxpr: _interpret_jaxpr(
jaxpr, *args, fresh_constant_cache=False, extra_name_stack=None)

# Primitives that are not yet implemented must be explicitly declared here.
tf_not_yet_impl = [
"clz",
Expand Down

0 comments on commit 330afdc

Please sign in to comment.