From cb5b8e5dd3ac1249f52f239b777a1eef88eacb9d Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Thu, 21 Sep 2023 20:35:30 +0000 Subject: [PATCH] gonna show andi something --- pedagogical_examples/xaot_shardings.py | 9 ++++++--- xaot_minimal.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pedagogical_examples/xaot_shardings.py b/pedagogical_examples/xaot_shardings.py index b9aae3390..853d74202 100644 --- a/pedagogical_examples/xaot_shardings.py +++ b/pedagogical_examples/xaot_shardings.py @@ -152,7 +152,7 @@ def simple_timeit(f, tries = 5, verbose = True): chips_per_host_bounds=(2, 2, 2), num_slices=1, ).devices -devices = jax.devices() # topology_devices or jax.devices() +devices = topology_devices # topology_devices or jax.devices() num_devices = len(devices) print(f"Devices: {devices} (num_devices: {num_devices})") assert len(devices) > 1, "You must have at least two devices" @@ -261,6 +261,7 @@ def xprint(string, verbose): compiled = lowered.compile() serialized, in_tree, out_tree = serialize(compiled) + print(f"{type(serialized)=}") # save the serialized via pickle xprint("Saving the serialized compiled train step...", verbose) with open(pickle_filename, "wb") as f: @@ -271,13 +272,15 @@ def xprint(string, verbose): save_xaot = True use_mesh = Mesh(mesh.devices, mesh.axis_names) key = jax.random.PRNGKey(0) +fake_key = jax.core.ShapedArray(key.shape, key.dtype) +print(f"{key=}") if save_xaot: print("saving gen_data...") pjit_gen_data, _, _, _, _ = xaot_save( gen_data, use_mesh, 'data_sharding.pkl', - key, + fake_key, in_shardings=None, out_shardings=data_sharding, verbose=False @@ -289,7 +292,7 @@ def xprint(string, verbose): gen_layers, use_mesh, 'layers_sharding.pkl', - key, + fake_key, in_shardings=None, out_shardings=parameter_sharding, verbose=False diff --git a/xaot_minimal.py b/xaot_minimal.py index 2b5d10b79..c6fd72ba1 100644 --- a/xaot_minimal.py +++ b/xaot_minimal.py @@ -47,8 +47,9 @@ def fun(x): lowered = jitted.lower( jax.core.ShapedArray(shape=(128, 128), dtype=np.float32) ) +print(f"{type(lowered)=}") orig_compiled = lowered.compile() - +print(f"{type(orig_compiled)=}") serialized, in_tree, out_tree = serialize(orig_compiled) print(f"{in_tree=}")