Skip to content

Commit

Permalink
AutoScale quickstart notebook.
Browse files Browse the repository at this point in the history
Gives a NumPy-like quick presentation, describing the main features of the JAX Scaled Arithmetics library.
  • Loading branch information
balancap committed Jan 17, 2024
1 parent a1ea373 commit 56c0646
Showing 1 changed file with 269 additions and 0 deletions.
269 changes: 269 additions & 0 deletions examples/autoscale-quickstart.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7c85dead-5274-487c-91ff-7137fbaca393",
"metadata": {},
"source": [
"# JAX Scaled Arithmetics / AutoScale quickstart\n",
"\n",
"**JAX Scaled Arithmetics** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of\n",
"deep neural networks in low precision (BF16, FP16, FP8) with full scale propagation."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "30940677-4296-40fa-b418-351fcfb62098",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax_scaled_arithmetics as jsa"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0e729aa-7a81-4001-8a34-9a00ec822948",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f374e654-97e4-43ef-902a-a890d36a52b9",
"metadata": {},
"outputs": [],
"source": [
"# The `autoscale` interpreter traces the graph, adding scale propagation where necessary.\n",
"@jsa.autoscale\n",
"def fn(a, b):\n",
" return a + b"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8c59245d-27e5-41a7-bfef-f40849a7b550",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INPUTS: [1. 2.] [3. 6.]\n",
"OUTPUT: [4. 8.] <class 'jaxlib.xla_extension.DeviceArray'>\n"
]
}
],
"source": [
"# Let's start with standard JAX inputs\n",
"a = np.array([1, 2], np.float16)\n",
"b = np.array([3, 6], np.float16)\n",
"out = fn(a, b)\n",
"\n",
"print(\"INPUTS:\", a, b)\n",
"# No scaled arithmetics => \"normal\" JAX mode.\n",
"print(\"OUTPUT:\", out, type(out))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e60cf138-d92d-4ab9-89d4-bacc9e28c39f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e7efaa2e-00a1-40e8-9bbb-685edc975636",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SCALED inputs: ScaledArray(data=array([1., 2.], dtype=float16), scale=1.0) ScaledArray(data=array([1.5, 3. ], dtype=float16), scale=2.0)\n",
"UNSCALED inputs: [1. 2.] [3. 6.]\n"
]
}
],
"source": [
"# Let's create input scaled arrays.\n",
"# NOTE: scale dtype does not have to match core data dtype.\n",
"sa = jsa.as_scaled_array(a, scale=np.float32(1))\n",
"sb = jsa.as_scaled_array(b, scale=np.float32(2))\n",
"\n",
"print(\"SCALED inputs:\", sa, sb)\n",
"# `as_scaled_array` does not change the value of the tensor:\n",
"print(\"UNSCALED inputs:\", np.asarray(sa), np.asarray(sb))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1f457243-a0b8-4e4d-b45d-7444d0566b37",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SCALED OUTPUT: ScaledArray(data=DeviceArray([2., 4.], dtype=float16), scale=DeviceArray(2., dtype=float32))\n",
"No scale rounding: ScaledArray(data=DeviceArray([1.789, 3.578], dtype=float16), scale=DeviceArray(2.236068, dtype=float32))\n"
]
}
],
"source": [
"# Running `fn` on scaled arrays triggers the `autoscale` graph transformation.\n",
"sout = fn(sa, sb)\n",
"# NOTE: by default, scale propagation is using power-of-2.\n",
"print(\"SCALED OUTPUT:\", sout)\n",
"\n",
"# To choose a different scale rounding:\n",
"with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.NONE):\n",
" print(\"No scale rounding:\", fn(sa, sb))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2429c10-00d9-44f8-b0b6-a1fdf13ed264",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "307ee27d-6ed2-4ab6-a152-83947dbf47fd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RESCALED OUTPUT: ScaledArray(data=DeviceArray([0.5, 1. ], dtype=float16), scale=DeviceArray(8., dtype=float32))\n"
]
},
{
"data": {
"text/plain": [
"functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7fc7b337c4c0>, <function dynamic_rescale_l1_base at 0x7fc7b3380430>)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# JAX Scaled Arithmetics offers basic dynamic rescaling methods, e.g. max, l1, l2.\n",
"sout_rescaled = jsa.ops.dynamic_rescale_max(sout)\n",
"print(\"RESCALED OUTPUT:\", sout_rescaled)\n",
"\n",
"# Equivalent methods are available to dynamically rescale gradients:\n",
"jsa.ops.dynamic_rescale_l1_grad"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "32930d15-d7ff-41d1-85be-eee958bb0741",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# NOTE: in normal JAX mode, these rescale operations are no-ops:\n",
"jsa.ops.dynamic_rescale_max(a) is a"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea5942e7-0279-4dc4-a720-b8c7323ab6a1",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9920f44a-26e2-4e20-89c3-4e2b2548239f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ScaledArray(data=DeviceArray([16., 20.], dtype=float32), scale=1.0)\n"
]
}
],
"source": [
"import ml_dtypes\n",
"# Minimal simulated FP8 support is provided using jax.lax.reduce_precision and ml_dtypes.\n",
"# Similarly to `dynamic_rescale`, `cast_ml_dtype(_grad)` are available to cast in forward and backward passes.\n",
"sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(1))\n",
"\n",
"@jsa.autoscale\n",
"def cast_fn(v):\n",
" return jsa.ops.cast_ml_dtype(v, ml_dtypes.float8_e4m3fn)\n",
"\n",
"sc_fp8 = cast_fn(sc)\n",
"print(sc_fp8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bd7c1d5-4ea2-4ded-a066-818d9146b8a6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 56c0646

Please sign in to comment.