diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index 1fde02a14655..8e25807dcb95 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -60,3 +60,9 @@ Common causes of OOM failures **Running JAX on the display GPU.** Use :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` or :code:`XLA_PYTHON_CLIENT_PREALLOCATE`. + +**Disabling rematerialization HLO pass** + Sometimes disabling the rematerialization HLO pass is favorable to avoid + poor remat choices by the compiler. The pass can be disabled by + :code:`jax.config.update('enable_remat_opt_pass', False)`. But this can + sometimes lead to OOM failures.