-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Giving a NumPy-like quick presentation, describing the main features of the JAX Scaled Arithmetics library.
- Loading branch information
Showing
1 changed file
with
269 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,269 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "7c85dead-5274-487c-91ff-7137fbaca393", | ||
"metadata": {}, | ||
"source": [ | ||
"# JAX Scaled Arithmetics / AutoScale quickstart\n", | ||
"\n", | ||
"**JAX Scaled Arithmetics** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of\n", | ||
"deep neural networks in low precision (BF16, FP16, FP8) with full scale propagation." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "30940677-4296-40fa-b418-351fcfb62098", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import jax\n", | ||
"import jax_scaled_arithmetics as jsa" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e0e729aa-7a81-4001-8a34-9a00ec822948", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "f374e654-97e4-43ef-902a-a890d36a52b9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# The `autoscale` interpreter traces the graph, adding scale propagation where necessary.\n", | ||
"@jsa.autoscale\n", | ||
"def fn(a, b):\n", | ||
" return a + b" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "8c59245d-27e5-41a7-bfef-f40849a7b550", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"INPUTS: [1. 2.] [3. 6.]\n", | ||
"OUTPUT: [4. 8.] <class 'jaxlib.xla_extension.DeviceArray'>\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Let's start with standard (unscaled) NumPy inputs.\n", | ||
"a = np.array([1, 2], np.float16)\n", | ||
"b = np.array([3, 6], np.float16)\n", | ||
"out = fn(a, b)\n", | ||
"\n", | ||
"print(\"INPUTS:\", a, b)\n", | ||
"# No scaled arithmetics => \"normal\" JAX mode.\n", | ||
"print(\"OUTPUT:\", out, type(out))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e60cf138-d92d-4ab9-89d4-bacc9e28c39f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "e7efaa2e-00a1-40e8-9bbb-685edc975636", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"SCALED inputs: ScaledArray(data=array([1., 2.], dtype=float16), scale=1.0) ScaledArray(data=array([1.5, 3. ], dtype=float16), scale=2.0)\n", | ||
"UNSCALED inputs: [1. 2.] [3. 6.]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Let's create input scaled arrays.\n", | ||
"# NOTE: the scale dtype does not have to match the dtype of the underlying data.\n", | ||
"sa = jsa.as_scaled_array(a, scale=np.float32(1))\n", | ||
"sb = jsa.as_scaled_array(b, scale=np.float32(2))\n", | ||
"\n", | ||
"print(\"SCALED inputs:\", sa, sb)\n", | ||
"# `as_scaled_array` does not change the value of the tensor:\n", | ||
"print(\"UNSCALED inputs:\", np.asarray(sa), np.asarray(sb))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "1f457243-a0b8-4e4d-b45d-7444d0566b37", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"SCALED OUTPUT: ScaledArray(data=DeviceArray([2., 4.], dtype=float16), scale=DeviceArray(2., dtype=float32))\n", | ||
"No scale rounding: ScaledArray(data=DeviceArray([1.789, 3.578], dtype=float16), scale=DeviceArray(2.236068, dtype=float32))\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Running `fn` on scaled arrays triggers `autoscale` graph transformation\n", | ||
"sout = fn(sa, sb)\n", | ||
"# NOTE: by default, scale propagation rounds scales to a power of two.\n", | ||
"print(\"SCALED OUTPUT:\", sout)\n", | ||
"\n", | ||
"# To choose a different scale rounding:\n", | ||
"with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.NONE):\n", | ||
" print(\"No scale rounding:\", fn(sa, sb))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c2429c10-00d9-44f8-b0b6-a1fdf13ed264", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "307ee27d-6ed2-4ab6-a152-83947dbf47fd", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"RESCALED OUTPUT: ScaledArray(data=DeviceArray([0.5, 1. ], dtype=float16), scale=DeviceArray(8., dtype=float32))\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7fc7b337c4c0>, <function dynamic_rescale_l1_base at 0x7fc7b3380430>)" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# JAX Scaled Arithmetics offers basic dynamic rescaling methods, e.g. max, l1, l2.\n", | ||
"sout_rescaled = jsa.ops.dynamic_rescale_max(sout)\n", | ||
"print(\"RESCALED OUTPUT:\", sout_rescaled)\n", | ||
"\n", | ||
"# Equivalent methods are available to dynamically rescale gradients:\n", | ||
"jsa.ops.dynamic_rescale_l1_grad" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "32930d15-d7ff-41d1-85be-eee958bb0741", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"True" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# NOTE: in normal JAX mode, these rescale operations are no-ops:\n", | ||
"jsa.ops.dynamic_rescale_max(a) is a" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ea5942e7-0279-4dc4-a720-b8c7323ab6a1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"id": "9920f44a-26e2-4e20-89c3-4e2b2548239f", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"ScaledArray(data=DeviceArray([16., 20.], dtype=float32), scale=1.0)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import ml_dtypes\n", | ||
"# Minimal simulated FP8 support is provided using `jax.lax.reduce_precision` and `ml_dtypes`.\n", | ||
"# Similarly to `dynamic_rescale`, `cast_ml_dtype(_grad)` ops are available to cast in the forward and backward passes.\n", | ||
"sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(1))\n", | ||
"\n", | ||
"@jsa.autoscale\n", | ||
"def cast_fn(v):\n", | ||
" return jsa.ops.cast_ml_dtype(v, ml_dtypes.float8_e4m3fn)\n", | ||
"\n", | ||
"sc_fp8 = cast_fn(sc)\n", | ||
"print(sc_fp8)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1bd7c1d5-4ea2-4ded-a066-818d9146b8a6", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |