
Commit

wip
balancap committed Aug 20, 2024
1 parent 22c2cc9 commit 5f209de
Showing 1 changed file with 62 additions and 8 deletions.
70 changes: 62 additions & 8 deletions docs/JAX FP8 matmul tutorial.ipynb
@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 2,
"id": "fb62c752-f7ba-4714-8605-88e2afcff88f",
"metadata": {},
"outputs": [
@@ -115,10 +115,17 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 3,
"id": "9be90f27-5520-45f6-a42d-b309572e6e91",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@@ -146,10 +153,58 @@
"print(\"E4M3 @ E5M2 FP8 matmul output:\", c.aval)"
]
},
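For reference, the cells constructing the FP8 inputs a and b are collapsed above. A minimal sketch of how such inputs could be built, assuming random normal values cast down to FP8 and the (32, 64) / (64, 128) shapes from the HLO module below (seed and construction are illustrative, not from the diff):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)
# Cast f32 matrices down to the two FP8 formats (values saturate/round to the FP8 grid).
a = jax.random.normal(k1, (32, 64)).astype(jnp.float8_e4m3fn)
b = jax.random.normal(k2, (64, 128)).astype(jnp.float8_e5m2)
c = jax.lax.dot(a, b)  # mixed E4M3 x E5M2 matmul, as in the cell above
print("E4M3 @ E5M2 FP8 matmul output:", c.aval)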
{
"cell_type": "markdown",
"id": "d08b0c24-1ac5-458b-a9ae-e269ec34862e",
"metadata": {},
"source": [
"## FP8 compiled HLO\n",
"\n",
"Let's have a look at the compiled HLO module generated by JAX + XLA. "
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7edfa758-bf4e-49fa-8c5d-5dc9c0c2c346",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"HloModule jit_matmul_fn, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e5m2[64,128]{1,0})->f8e5m2[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}\n",
"\n",
"ENTRY %main.4 (Arg_0.1: f8e4m3fn[32,64], Arg_1.2: f8e5m2[64,128]) -> f8e5m2[32,128] {\n",
" %Arg_1.2 = f8e5m2[64,128]{1,0} parameter(1)\n",
" %convert.5 = f32[64,128]{1,0} convert(f8e5m2[64,128]{1,0} %Arg_1.2)\n",
" %Arg_0.1 = f8e4m3fn[32,64]{1,0} parameter(0)\n",
" %convert.4 = f32[32,64]{1,0} convert(f8e4m3fn[32,64]{1,0} %Arg_0.1)\n",
" %dot.0 = f32[32,128]{1,0} dot(f32[32,64]{1,0} %convert.4, f32[64,128]{1,0} %convert.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n",
" ROOT %convert.3 = f8e5m2[32,128]{1,0} convert(f32[32,128]{1,0} %dot.0)\n",
"}\n",
"\n",
"\n"
]
}
],
"source": [
"from jax_scalify.utils import print_hlo_module\n",
"\n",
"def matmul_fn(a_fp8, b_fp8):\n",
" # FP8 x FP8 -> FP8 matmul\n",
" return jax.lax.dot(a_fp8, b_fp8)\n",
"\n",
"# AOT compilation with JAX, inspecting the (final) HLO module generated.\n",
"fn_compiled = jax.jit(matmul_fn).lower(a, b).compile()\n",
"# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config).\n",
"print_hlo_module(fn_compiled, backend_cfg=True)"
]
},
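Note on the HLO module above: XLA does not emit a native FP8 dot here; it upcasts both operands to f32, runs an f32 dot, and casts the result back to f8e5m2 (consistent with the CPU-fallback warning earlier). A minimal JAX-level sketch of the same computation, with the helper name emulated_fp8_matmul chosen purely for illustration:

def emulated_fp8_matmul(a_fp8, b_fp8):
    # Same op sequence as the compiled module: convert -> f32 dot -> convert.
    a_f32 = a_fp8.astype(jnp.float32)
    b_f32 = b_fp8.astype(jnp.float32)
    return jax.lax.dot(a_f32, b_f32).astype(jnp.float8_e5m2)

On FP8-capable GPUs (e.g. H100), XLA would typically lower the same jax.lax.dot to a cublasLt FP8 GEMM custom call instead, which is where the GEMM backend config mentioned in the cells below comes in.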
{
"cell_type": "code",
"execution_count": null,
"id": "a617ad86-5570-4792-bbf6-c1f70dba8d3f",
"id": "72d805ea-89b6-457d-9558-ff31fdd23d35",
"metadata": {},
"outputs": [],
"source": []
@@ -172,8 +227,7 @@
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from jax_scalify.utils import print_hlo_module"
"import jax.numpy as jnp\n"
]
},
{
@@ -191,7 +245,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 7,
"id": "a6142f8d-08ee-4fa6-962f-2b85a1bcecb6",
"metadata": {},
"outputs": [
@@ -220,8 +274,8 @@
"\n",
"# AOT compilation with JAX, inspecting the (final) HLO module generated.\n",
"fn_compiled = jax.jit(matmul_fn).lower(a_aval, b_aval).compile()\n",
"# (Human readable) optimized Hlo module generated by XLA.\n",
"print_hlo_module(fn_compiled, backend_cfg=True)"
"# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config)\n",
"print_hlo_module(fn_compiled, backend_cfg=False)"
]
},
{
