Skip to content

Commit

Permalink
need to clean up still
Browse files Browse the repository at this point in the history
  • Loading branch information
gobbleturk committed Oct 9, 2023
1 parent cb5b8e5 commit 9e81d41
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions xaot_minimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@



topo='v4-8'
topo='v4-16'
if topo=='v4-8':
topology_devices = get_topology_desc(
platform='tpu',
Expand Down Expand Up @@ -52,8 +52,8 @@ def fun(x):
print(f"{type(orig_compiled)=}")

serialized, in_tree, out_tree = serialize(orig_compiled)
print(f"{in_tree=}")
print(f"{out_tree=}")
# print(f"{in_tree=}")
# print(f"{out_tree=}")

with open(f"x_aot_{topo}.pickle", "wb") as f:
pickle.dump(serialized, f)
Expand All @@ -76,8 +76,8 @@ def fun(x):

out_shaped = jax.eval_shape(fun, ex_input)
flat_out_shaped, out_tree_recreated = jax.tree_util.tree_flatten(out_shaped)
print(f"{out_tree_recreated=}")
# print(f"{out_tree_recreated=}")

ex_input = jax.core.ShapedArray(shape=(128, 128), dtype=np.float32)
flat_in_shaped, in_tree_recreated = tree_util.tree_flatten(((ex_input,),{}))
print(f"{in_tree_recreated=}")
# print(f"{in_tree_recreated=}")

0 comments on commit 9e81d41

Please sign in to comment.