Performance degradation of map_coordinates with the new CPU runtime #26152
Comments
Ping @ezhulenev, @penpornk re: CPU thunks runtime performance.

This looks to me like the usual small-while-loop performance issue, but the lowered HLO doesn't contain any loops (they are added during compilation), so there might be something else to look into. See the HLO below:

Input HLO for grad(fun)
Compiled HLO for grad(fun)
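One way to reproduce the two HLO dumps mentioned above is JAX's ahead-of-time API: `jax.jit(...).lower(...)` gives the program before XLA optimization, and `.compile()` gives the optimized CPU program. The sketch below is an illustration under assumptions: the issue's `fun` is not shown, so `fun`, `image` and `coords` are hypothetical stand-ins, and the `"while" in ...` test is only a crude substring check for loops.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


# Hypothetical stand-in for the issue's `fun`; the real function is not shown.
def fun(img, crd):
    return jnp.sum(map_coordinates(img, crd, order=1, mode="nearest"))


image = jnp.ones((32, 32), dtype=jnp.float32)
# Shape (2, 100): one row of coordinates per image dimension.
coords = jnp.stack([jnp.linspace(0.0, 31.0, 100), jnp.linspace(31.0, 0.0, 100)])

jitted = jax.jit(jax.grad(fun))
lowered = jitted.lower(image, coords)  # program before XLA optimization passes
compiled = lowered.compile()           # program after XLA CPU compilation

input_hlo = lowered.as_text(dialect="hlo")
compiled_hlo = compiled.as_text()

# Crude check: does each stage mention a while loop at all?
print("while in input HLO:   ", "while" in input_hlo)
print("while in compiled HLO:", "while" in compiled_hlo)
```

Comparing the two texts shows which loops, if any, XLA introduces during compilation rather than at lowering time.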
This small-while-loop issue is new to me (although I use a lot of while loops and have encountered CPU performance issues with them). Is there any documentation I could refer to on this?
Description
With the new CPU runtime, I experience a performance degradation of `jax.scipy.ndimage.map_coordinates` if I combine it with `jax.jit` and `jax.value_and_grad`. Below is a reproducer with a switch to change between the old and new CPU runtime by setting `--xla_cpu_use_thunk_runtime=false`:

For the old CPU runtime (`old_cpu_runtime = True`) I obtain:

With the new CPU runtime (`old_cpu_runtime = False`) I get:

Thus, the new CPU runtime degrades the performance of `map_coordinates` by approximately a factor of 10 when it is combined with `jax.jit` and `jax.value_and_grad`.

System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.38
jaxlib: 0.4.38
numpy: 2.2.2
python: 3.12.7 (main, Nov 5 2024, 15:03:07) [Clang 16.0.0 (clang-1600.0.26.3)]
device info: cpu, 1 local devices
process_count: 8
uname_result(system='Darwin', node='MacBook-Pro.local', release='24.2.0', version='Darwin Kernel Version 24.2.0: Fri Dec 6 18:41:43 PST 2024, machine='x86_64')
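The reporter's reproducer script and its timings are not included on this page. As a rough illustration of the benchmark described above, here is a minimal sketch, not the author's script. The image size, the coordinates and the `loss` function are hypothetical. The sketch also assumes the flag is passed through the `XLA_FLAGS` environment variable before JAX initializes its CPU backend. `old_cpu_runtime` defaults to `False` here because setting it to `True` only works on jaxlib versions that still accept the flag, such as the 0.4.38 release reported above.

```python
import os
import time

# Assumption: the runtime switch goes through XLA_FLAGS and must be set before
# JAX initializes. Only set True on jaxlib versions that still know the flag.
old_cpu_runtime = False
if old_cpu_runtime:
    flags = os.environ.get("XLA_FLAGS", "")
    os.environ["XLA_FLAGS"] = (flags + " --xla_cpu_use_thunk_runtime=false").strip()

import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

# Hypothetical test problem: interpolate a 64x64 image at 10,000 random points.
image = jnp.arange(64 * 64, dtype=jnp.float32).reshape(64, 64)
coords = jax.random.uniform(jax.random.PRNGKey(0), (2, 10_000), minval=0.0, maxval=63.0)


def loss(img, crd):
    # order=1 selects linear interpolation.
    values = map_coordinates(img, crd, order=1, mode="nearest")
    return jnp.sum(values ** 2)


# Same combination as in the report: jit of value_and_grad (w.r.t. the image).
step = jax.jit(jax.value_and_grad(loss))

value, grad = step(image, coords)  # first call triggers compilation
grad.block_until_ready()

n_iter = 100
start = time.perf_counter()
for _ in range(n_iter):
    value, grad = step(image, coords)
grad.block_until_ready()  # wait for asynchronous dispatch to finish
elapsed = (time.perf_counter() - start) / n_iter
print(f"mean time per call: {elapsed * 1e3:.3f} ms")
```

Running the script once with each setting of `old_cpu_runtime` (in separate processes, since the flag is read at startup) gives a side-by-side comparison of the two runtimes.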