Skip to content

Commit

Permalink
document jax config to disable remat HLO pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Keshav Balasubramanian committed Oct 18, 2024
1 parent ade480f commit 318d6e8
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/gpu_memory_allocation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,9 @@ Common causes of OOM failures
**Running JAX on the display GPU.**
Use :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` or
:code:`XLA_PYTHON_CLIENT_PREALLOCATE`.

**Disabling rematerialization HLO pass**
Sometimes disabling the rematerialization HLO pass is favorable to avoid
poor remat choices by the compiler. The pass can be disabled by
:code:`jax.config.update('enable_remat_opt_pass', False)`. But this can
sometimes lead to OOM failures.

0 comments on commit 318d6e8

Please sign in to comment.