Skip to content

Commit

Permalink
Merge pull request #2098 from mattjj:avoid-units-axes-scan
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 445929745
  • Loading branch information
Flax Authors committed May 2, 2022
2 parents 8e05b0c + 9063cba commit 0d673cb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion flax/core/axes_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def body_fn(c, xs, init_mode=False):
f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
lu.wrap_init(broadcast_body), in_tree)
in_pvals = list(map(pe.PartialVal.unknown, in_avals))
_, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)

out_flat = []
for pv, const in out_pvals:
Expand Down

0 comments on commit 0d673cb

Please sign in to comment.