Skip to content

Commit

Permalink
Update CompiledMemoryStats in xla/python to include host memory sta…
Browse files Browse the repository at this point in the history
…ts and add a few tests to memories_test.py

PiperOrigin-RevId: 611649314
  • Loading branch information
yueshengys authored and jax authors committed Mar 1, 2024
1 parent 48e6e0d commit 3358d82
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from jax import lax
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_extension_version
from jax._src import config
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.numpy as jnp
Expand Down Expand Up @@ -1111,12 +1112,19 @@ def f(x):
f = jax.jit(jax.grad(f))
f(inp) # doesn't crash

compiled_text = f.lower(inp).compile().as_text()
compiled_f = f.lower(inp).compile()

compiled_text = compiled_f.as_text()
if compiled_text is not None:
self.assertIn('S(5)', compiled_text)
self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
self.assertRegex(compiled_text, r"copy-done.*S\(5\)")

compiled_stats = compiled_f.memory_analysis()
if compiled_stats is not None:
if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

def test_remat_scan_jaxpr_offloadable(self):
mesh = jtu.create_global_mesh((2,), ("x",))
shape = (256, 128)
Expand Down Expand Up @@ -1161,12 +1169,19 @@ def g(ys, _):
f = jax.jit(jax.grad(f))
f(inp) # doesn't crash

compiled_text = f.lower(inp).compile().as_text()
compiled_f = f.lower(inp).compile()

compiled_text = compiled_f.as_text()
if compiled_text is not None:
self.assertIn('S(5)', compiled_text)
self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")

compiled_stats = compiled_f.memory_analysis()
if compiled_stats is not None:
if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

def test_remat_scan_layout_change_offloadable(self):
mesh = jtu.create_global_mesh((2,), ("x",))
shape = (256, 128)
Expand Down Expand Up @@ -1194,12 +1209,19 @@ def g(ys, _):
f = jax.jit(jax.grad(f))
f(inp) # doesn't crash

compiled_text = f.lower(inp).compile().as_text()
compiled_f = f.lower(inp).compile()

compiled_text = compiled_f.as_text()
if compiled_text is not None:
self.assertIn('S(5)', compiled_text)
self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")

compiled_stats = compiled_f.memory_analysis()
if compiled_stats is not None:
if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

def test_remat_checkpoint_dots_with_no_batch_dims(self):
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
"device", "pinned_host")
Expand All @@ -1219,12 +1241,18 @@ def f(x):
f = jax.jit(jax.grad(f))
f(inp) # doesn't crash

compiled_text = f.lower(inp).compile().as_text()
compiled_f = f.lower(inp).compile()

compiled_text = compiled_f.as_text()
if compiled_text is not None:
self.assertIn('S(5)', compiled_text)
self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
self.assertRegex(compiled_text, r"copy-done.*S\(5\)")

compiled_stats = compiled_f.memory_analysis()
if compiled_stats is not None:
if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 3358d82

Please sign in to comment.