Skip to content

Commit

Permalink
Merge pull request #19532 from mattjj:jax-attrs2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 602079647
  • Loading branch information
jax authors committed Jan 28, 2024
2 parents d4660a0 + 22160df commit 8bbcbb6
Show file tree
Hide file tree
Showing 36 changed files with 365 additions and 165 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ py_library_providing_imports_info(
"_src/internal_test_util/**",
],
) + [
"experimental/attrs.py",
# until new parallelism APIs are moved out of experimental
"experimental/maps.py",
"experimental/pjit.py",
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def _trace_to_jaxpr(fun, in_tree, in_avals):
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, out_tree, True, "checkpoint")
try:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
except core.ConcretizationTypeError as e:
msg, = e.args
if 'for checkpoint' not in msg:
Expand Down Expand Up @@ -620,7 +620,7 @@ def transposed(*args_flat):
in_cts_nz, _ = partition_list(in_zeros, in_cts)
return in_cts_nz

transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
return transposed_jaxpr, cell.in_cts_zero # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def computation_maker(*args, **kwargs):
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
ordered_effects = list(
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def jaxpr_to_checkify_jaxpr(
fun = lu.wrap_init(checkify_jaxpr_partial)
fun, metadata = _flatten_and_get_error_metadata_thunk(fun)

new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
out_tree, error_effects = metadata()
return checked_jaxpr, out_tree, error_effects
Expand Down Expand Up @@ -832,7 +832,7 @@ def new_body_f(*c_consts_and_vals):
return out
new_body_f_ = lu.wrap_init(new_body_f)
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
*body_jaxpr.in_avals])
closed_jaxpr = pe.close_jaxpr(jaxpr)
err_vals, err_tree = jtu.tree_flatten(error)
Expand Down Expand Up @@ -1128,7 +1128,7 @@ def checked_fun(*args, **kwargs):
# stage:
fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
debug = pe.debug_info(closed_f, in_tree, out_tree, False, 'checkify')
jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
# checkify:
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/custom_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __call__(self, *args, **kwargs):
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))
assert self.vmap_rule is not None
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _resolve_kwargs(fun, args, kwargs):
return ba.args

def _initial_style_jaxpr(fun, in_avals):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return jaxpr, consts

def _close_jaxpr(jaxpr):
Expand Down Expand Up @@ -977,7 +977,7 @@ def fwd(*args, **kwargs):
ans_flat, out_tree = tree_flatten((ans,))
rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
return ans, Residuals(jaxpr, in_tree(), out_tree, consts)

def bwd(res, cts):
Expand Down Expand Up @@ -1102,7 +1102,7 @@ def _maybe_perturbed(x: Any) -> bool:
@cache()
def _closure_convert_for_avals(fun, in_tree, in_avals):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
jaxpr, out_pvals, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
out_tree = out_tree()

(closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
nonzeros)
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()

@lu.transformation_with_aux
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ def _batch_jaxpr2(
avals_in2 = [core.unmapped_aval(axis_size, axis_name, b, aval)
if b is not not_mapped else aval
for aval, b in unsafe_zip(avals_in, in_axes2)]
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in2)
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)
return core.ClosedJaxpr(jaxpr_out, consts), out_axes()

def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis,
Expand Down Expand Up @@ -834,7 +834,7 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
main_type)
avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped
else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, consts), out_batched()

@lu.transformation_with_aux
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,7 +1797,7 @@ def f_lowered(ctx, *args, **params):
wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
else:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?

out, tokens = jaxpr_subcomp(
Expand Down
68 changes: 43 additions & 25 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def _closed_call_param_updater(params, _, __):
call_param_updaters[core.closed_call_p] = _closed_call_param_updater

def abstract_eval_fun(fun, *avals, debug_info=None, **params):
_, avals_out, _ = trace_to_jaxpr_dynamic(
_, avals_out, _, () = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params), avals, debug_info)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
return avals_out
Expand Down Expand Up @@ -1113,7 +1113,7 @@ def fun(*known_vals_in):
return [*known_vals_out, *residuals]

known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]
jaxpr_known, _, consts_known = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
(out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking

# check jaxpr_known and jaxpr_unknown in isolation
Expand Down Expand Up @@ -1754,6 +1754,9 @@ class JaxprStackFrame:
eqns: list[JaxprEqn]
invars: list[Var]
effects: core.Effects
attrs_tracked: list[tuple[Any, str]]
attrs_inits: list
attrs_vars: list[Var]
debug_info: DebugInfo | None

def __init__(self):
Expand All @@ -1765,23 +1768,29 @@ def __init__(self):
self.eqns = [] # cleared when we pop frame from main
self.invars = []
self.effects = set()
self.attrs_tracked = []
self.attrs_inits = []
self.attrs_vars = []
self.debug_info = None

def add_eqn(self, eqn: core.JaxprEqn):
self.eqns.append(eqn)

def to_jaxpr(self, out_tracers: Sequence[Tracer]) -> tuple[Jaxpr, list[Any]]:
def to_jaxpr(self, out_tracers: Sequence[Tracer]
) -> tuple[Jaxpr, list[Any], list[tuple[Any, str]]]:
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
constvals: Sequence[Any]
invars = self.attrs_vars + self.invars
state_outvars = [self.tracer_to_var[id(t)] for t in get_states(self.attrs_tracked)]
explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
outvars = state_outvars + explicit_outvars
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars,
self.eqns)
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects)
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns)
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects)
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
return jaxpr, list(constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore
set_states(self.attrs_tracked, self.attrs_inits)
return jaxpr, list(constvals), self.attrs_tracked

def to_jaxpr2(self, out_tracers):
# It's not necessary, but we keep the tracer-to-var mapping injective:
Expand Down Expand Up @@ -2064,7 +2073,7 @@ def process_map(self, map_primitive, f, tracers, params):
for a, in_axis in zip(in_avals, params['in_axes'])]
with core.extend_axis_env(axis_name, params["global_axis_size"], None): # type: ignore
with core.new_sublevel():
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
jaxpr, reduced_out_avals, consts, () = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals,
debug_info=debug_info_final(f, map_primitive.name))
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
Expand Down Expand Up @@ -2098,7 +2107,7 @@ def post_process_map(self, map_primitive, out_tracers, params):
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
main_ = ref(self.main)

Expand All @@ -2108,7 +2117,7 @@ def jvp_jaxpr_thunk(*in_zeros):
nz_tangent_avals, zero_avals = partition_list(in_zeros, in_avals)
jvp_, out_zeros = _jvp_jaxpr_zeros(jvp, in_zeros, tuple(zero_avals))
in_avals_ = (*in_avals, *nz_tangent_avals)
jaxpr, _, out_consts = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_)
jaxpr, _, out_consts, () = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_)
return jaxpr, out_consts, out_zeros()

out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
Expand All @@ -2132,7 +2141,7 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())

main_ = ref(self.main)
Expand Down Expand Up @@ -2170,7 +2179,7 @@ def process_custom_transpose(self, prim, call, tracers, *,
in_avals_t = [*[t.aval for t in tracers_res], *out_types]

with core.new_sublevel():
call_jaxpr, out_avals, call_consts = trace_to_subjaxpr_dynamic(
call_jaxpr, out_avals, call_consts, () = trace_to_subjaxpr_dynamic(
call, self.main, in_avals_p)
closed_call_jaxpr = core.ClosedJaxpr(
convert_constvars_jaxpr(call_jaxpr), ())
Expand All @@ -2183,8 +2192,8 @@ def process_custom_transpose(self, prim, call, tracers, *,
@_memoize
def transpose_jaxpr_thunk():
for store in transpose_flat.stores: store.reset()
jaxpr, _, consts = trace_to_subjaxpr_dynamic(transpose_flat, main_(),
in_avals_t)
jaxpr, _, consts, () = trace_to_subjaxpr_dynamic(
transpose_flat, main_(), in_avals_t)
return jaxpr, consts

out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
Expand Down Expand Up @@ -2299,13 +2308,13 @@ def trace_to_jaxpr_dynamic(
debug_info: DebugInfo | None = None,
*,
keep_inputs: list[bool] | None = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]:
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
del main, fun
return jaxpr, out_avals, consts
return jaxpr, out_avals, consts, attrs_tracked


def trace_to_subjaxpr_dynamic(
Expand All @@ -2315,7 +2324,7 @@ def trace_to_subjaxpr_dynamic(
*,
keep_inputs: Sequence[bool] | None = None,
debug_info: DebugInfo | None = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]:
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs

frame = JaxprStackFrame()
Expand All @@ -2326,10 +2335,10 @@ def trace_to_subjaxpr_dynamic(
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
ans = fun.call_wrapped(*in_tracers_)
out_tracers = map(trace.full_raise, ans)
jaxpr, consts = frame.to_jaxpr(out_tracers)
jaxpr, consts, attrs_tracked = frame.to_jaxpr(out_tracers)
del fun, main, trace, frame, in_tracers, out_tracers, ans
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, [v.aval for v in jaxpr.outvars], consts
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked


@profiler.annotate_function
Expand Down Expand Up @@ -2380,7 +2389,7 @@ def trace_to_jaxpr_final(
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(
fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
del fun, main
return jaxpr, out_avals, consts
Expand All @@ -2404,6 +2413,15 @@ def trace_to_jaxpr_final2(
tuple[AbstractedAxisName, ...],
]

AttrsTracked = list[tuple[Any, str]]
AttrStates = list
def set_states(attrs_tracked: AttrsTracked, vals: AttrStates):
for ((obj, attr), val) in zip(attrs_tracked, vals):
setattr(obj, attr, val)

def get_states(attrs_tracked: AttrsTracked):
return [getattr(obj, attr) for (obj, attr) in attrs_tracked]


def infer_lambda_input_type(
axes_specs: Sequence[AbstractedAxesSpec] | None,
Expand Down Expand Up @@ -2629,7 +2647,7 @@ def substitute(aval: AbstractValue) -> AbstractValue:

in_avals = [substitute(v.aval) for v in jaxpr.invars]
eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
padded_jaxpr, _, padded_consts = trace_to_jaxpr_dynamic(eval_padded, in_avals)
padded_jaxpr, _, padded_consts, () = trace_to_jaxpr_dynamic(eval_padded, in_avals)
return padded_jaxpr, padded_consts

class BoundedAxisSize(NamedTuple):
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, out_tree, False,
primitive_name or "<unknown>")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
return jaxpr, consts, out_tree()

@weakref_lru_cache
Expand Down Expand Up @@ -226,7 +226,7 @@ def _prune_zeros(ts):
return [t for t in ts if type(t) is not ad_util.Zero]

def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
return core.ClosedJaxpr(jaxpr, consts)

def _show_diff(array1, array2):
Expand Down
Loading

0 comments on commit 8bbcbb6

Please sign in to comment.