From 3358d826469a31ac237568478cd9f5352978621b Mon Sep 17 00:00:00 2001 From: Yue Sheng Date: Thu, 29 Feb 2024 17:30:19 -0800 Subject: [PATCH] Update `CompiledMemoryStats` in xla/python to include host memory stats and add a few tests to memories_test.py PiperOrigin-RevId: 611649314 --- tests/memories_test.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index 5c0cff5039d2..62c605ba210c 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -23,6 +23,7 @@ from jax import lax from jax._src import test_util as jtu from jax._src import xla_bridge as xb +from jax._src.lib import xla_extension_version from jax._src import config from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.numpy as jnp @@ -1111,12 +1112,19 @@ def f(x): f = jax.jit(jax.grad(f)) f(inp) # doesn't crash - compiled_text = f.lower(inp).compile().as_text() + compiled_f = f.lower(inp).compile() + + compiled_text = compiled_f.as_text() if compiled_text is not None: self.assertIn('S(5)', compiled_text) self.assertRegex(compiled_text, r"copy-start.*S\(5\)") self.assertRegex(compiled_text, r"copy-done.*S\(5\)") + compiled_stats = compiled_f.memory_analysis() + if compiled_stats is not None: + if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43): + self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) + def test_remat_scan_jaxpr_offloadable(self): mesh = jtu.create_global_mesh((2,), ("x",)) shape = (256, 128) @@ -1161,12 +1169,19 @@ def g(ys, _): f = jax.jit(jax.grad(f)) f(inp) # doesn't crash - compiled_text = f.lower(inp).compile().as_text() + compiled_f = f.lower(inp).compile() + + compiled_text = compiled_f.as_text() if compiled_text is not None: self.assertIn('S(5)', compiled_text) self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)") self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)") + compiled_stats = compiled_f.memory_analysis() + if compiled_stats is not None: + if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43): + self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) + def test_remat_scan_layout_change_offloadable(self): mesh = jtu.create_global_mesh((2,), ("x",)) shape = (256, 128) @@ -1194,12 +1209,19 @@ def g(ys, _): f = jax.jit(jax.grad(f)) f(inp) # doesn't crash - compiled_text = f.lower(inp).compile().as_text() + compiled_f = f.lower(inp).compile() + + compiled_text = compiled_f.as_text() if compiled_text is not None: self.assertIn('S(5)', compiled_text) self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)") self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)") + compiled_stats = compiled_f.memory_analysis() + if compiled_stats is not None: + if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43): + self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) + def test_remat_checkpoint_dots_with_no_batch_dims(self): policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims( "device", "pinned_host") @@ -1219,12 +1241,18 @@ def f(x): f = jax.jit(jax.grad(f)) f(inp) # doesn't crash - compiled_text = f.lower(inp).compile().as_text() + compiled_f = f.lower(inp).compile() + + compiled_text = compiled_f.as_text() if compiled_text is not None: self.assertIn('S(5)', compiled_text) self.assertRegex(compiled_text, r"copy-start.*S\(5\)") self.assertRegex(compiled_text, r"copy-done.*S\(5\)") + compiled_stats = compiled_f.memory_analysis() + if compiled_stats is not None: + if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43): + self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())