Skip to content

Commit

Permalink
gonna show andi something
Browse files Browse the repository at this point in the history
  • Loading branch information
gobbleturk committed Sep 21, 2023
1 parent bb54f47 commit cb5b8e5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 6 additions & 3 deletions pedagogical_examples/xaot_shardings.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def simple_timeit(f, tries = 5, verbose = True):
chips_per_host_bounds=(2, 2, 2),
num_slices=1,
).devices
devices = jax.devices() # topology_devices or jax.devices()
devices = topology_devices # topology_devices or jax.devices()
num_devices = len(devices)
print(f"Devices: {devices} (num_devices: {num_devices})")
assert len(devices) > 1, "You must have at least two devices"
Expand Down Expand Up @@ -261,6 +261,7 @@ def xprint(string, verbose):
compiled = lowered.compile()
serialized, in_tree, out_tree = serialize(compiled)

print(f"{type(serialized)=}")
# save the serialized via pickle
xprint("Saving the serialized compiled train step...", verbose)
with open(pickle_filename, "wb") as f:
Expand All @@ -271,13 +272,15 @@ def xprint(string, verbose):
save_xaot = True
use_mesh = Mesh(mesh.devices, mesh.axis_names)
key = jax.random.PRNGKey(0)
fake_key = jax.core.ShapedArray(key.shape, key.dtype)
print(f"{key=}")
if save_xaot:
print("saving gen_data...")
pjit_gen_data, _, _, _, _ = xaot_save(
gen_data,
use_mesh,
'data_sharding.pkl',
key,
fake_key,
in_shardings=None,
out_shardings=data_sharding,
verbose=False
Expand All @@ -289,7 +292,7 @@ def xprint(string, verbose):
gen_layers,
use_mesh,
'layers_sharding.pkl',
key,
fake_key,
in_shardings=None,
out_shardings=parameter_sharding,
verbose=False
Expand Down
3 changes: 2 additions & 1 deletion xaot_minimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def fun(x):
lowered = jitted.lower(
jax.core.ShapedArray(shape=(128, 128), dtype=np.float32)
)
print(f"{type(lowered)=}")
orig_compiled = lowered.compile()

print(f"{type(orig_compiled)=}")

serialized, in_tree, out_tree = serialize(orig_compiled)
print(f"{in_tree=}")
Expand Down

0 comments on commit cb5b8e5

Please sign in to comment.