Skip to content

Commit

Permalink
[pallas:mosaic_gpu] Do not DCE the jaxpr in the lowering pass
Browse files Browse the repository at this point in the history
There isn't an obvious reason for doing DCE there.

PiperOrigin-RevId: 680534567
  • Loading branch information
superbobry authored and Google-ML-Automation committed Sep 30, 2024
1 parent 21fea5b commit b3fca90
Showing 1 changed file with 0 additions and 5 deletions.
5 changes: 0 additions & 5 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,6 @@ def lower_jaxpr_to_module(
"Only Blocked indexing mode is supported in Mosaic GPU lowering."
)

with grid_mapping.trace_env():
jaxpr, _ = pe.dce_jaxpr(
jaxpr, [True] * len(jaxpr.outvars), instantiate=True
)

block = (128, 1, 1)
params = compiler_params.get("mosaic_gpu", {})
approx_math = params.get("approx_math", False)
Expand Down

0 comments on commit b3fca90

Please sign in to comment.