From 1e071986adc606132286b9b0d71a4eb3af6b976c Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Fri, 14 Jun 2024 12:43:25 +0100
Subject: [PATCH] wip

---
 examples/scalify-quickstart.ipynb | 284 ++++++++++++++++++++++++------
 1 file changed, 234 insertions(+), 50 deletions(-)

diff --git a/examples/scalify-quickstart.ipynb b/examples/scalify-quickstart.ipynb
index 2fc2287..a5d23fe 100644
--- a/examples/scalify-quickstart.ipynb
+++ b/examples/scalify-quickstart.ipynb
@@ -10,6 +10,7 @@
    "**JAX Scalify** is a library implementing general scaled arithmetic in JAX, allowing end-to-end scale propagation in computational graphs and easy training/inference of deep neural networks in low precision (mainly FP16 & FP8).\n",
    "\n",
    "JAX Scalify supports converting any computational graph into a scaled computational graph, i.e. with `ScaledArray` inputs/outputs.\n",
+    "\n",
    "```python\n",
    "@dataclass\n",
    "class ScaledArray:\n",
@@ -19,21 +20,34 @@
    "It fully decouples scale propagation from model definition, allowing easy experimentation and debugging with low precision formats FP16 and FP8."
   ]
  },
+ {
+  "cell_type": "markdown",
+  "id": "39019611",
+  "metadata": {},
+  "source": [
+   "## Scaled array representation\n",
+   "\n",
+   "In Scalify, every tensor is a `ScaledArray`. This systematic approach simplifies the use of FP16 and FP8 in LLM training, decoupling the scale and numerical-stability questions from the high-level model definition.\n",
+   "\n",
+   "Below are the basics of `ScaledArray` construction, and how it helps represent very large or very small tensors."
+  ]
+ },
 {
  "cell_type": "code",
- "execution_count": 16,
+ "execution_count": 70,
  "id": "30940677-4296-40fa-b418-351fcfb62098",
  "metadata": {},
  "outputs": [],
  "source": [
-  "import numpy as np\n",
   "import jax\n",
+  "import numpy as np\n",
+  "import jax.numpy as jnp\n",
   "import jax_scalify as jsa"
  ]
 },
 {
  "cell_type": "code",
- "execution_count": 17,
+ "execution_count": 81,
  "id": "e0e729aa-7a81-4001-8a34-9a00ec822948",
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "`a` : [1. 2.]\n",
-    "`sa`: ScaledArray(data=array([1., 2.], dtype=float16), scale=1.0) ~ [1. 2.]\n"
+    "Normal `a`: [1. 2.]\n",
+    "Scaled `a`: ScaledArray(data=array([1., 2.], dtype=float16), scale=1.0) ~ [1. 2.]\n"
    ]
   }
  ],
  "source": [
   "# Let's start at the beginning: convert an array to a ScaledArray.\n",
   "a = np.array([1, 2], np.float16)\n",
-  "# Analogue of `np.asarray`, with passing of the scale value to use.\n",
+  "# Analogue of `np.asarray`, additionally passing the scale to use.\n",
+  "# NOTE: the scale dtype does not have to match the core data dtype; np.float32 is common.\n",
   "sa = jsa.as_scaled_array(a, scale=np.float32(1))\n",
   "assert isinstance(sa, jsa.ScaledArray)\n",
   "\n",
   "# `a` and `sa` represent the same formal tensor.\n",
-  "print(\"`a` :\", a)\n",
-  "print(\"`sa`:\", sa, \" ~ \", np.asarray(sa))"
+  "print(\"Normal `a`:\", a)\n",
+  "print(\"Scaled `a`:\", sa, \" ~ \", np.asarray(sa))"
  ]
 },
 {
  "cell_type": "code",
- "execution_count": 18,
+ "execution_count": 72,
  "id": "5f624725",
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "`a` : [1. 2.]\n",
-    "`sa`: ScaledArray(data=array([2., 4.], dtype=float16), scale=0.5) ~ [1. 2.]\n"
+    "Normal `a`: [1. 2.]\n",
+    "Scaled `a`: ScaledArray(data=array([2., 4.], dtype=float16), scale=0.5) ~ [1. 2.]\n"
    ]
   }
  ],
  "source": [
   "# Scalify preserves the semantics of arrays and computational graphs.\n",
-  "# Passing a different scale does not change the \"value\" of a tensor.\n",
+  "# Passing a different scale does not change the \"value\" of the represented tensor.\n",
   "sa = jsa.as_scaled_array(a, scale=np.float32(0.5))\n",
   "# `a` and `sa` still represent the same formal tensor.\n",
-  "print(\"`a` :\", a)\n",
-  "print(\"`sa`:\", sa, \" ~ \", np.asarray(sa))"
+  "print(\"Normal `a`:\", a)\n",
+  "print(\"Scaled `a`:\", sa, \" ~ \", np.asarray(sa))"
  ]
 },
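+ {
+  "cell_type": "markdown",
+  "id": "3a7f2c91",
+  "metadata": {},
+  "source": [
+   "A quick sanity check (an illustrative sketch, relying only on the `data` and `scale` fields of the `ScaledArray` dataclass shown above): the tensor represented is simply `data * scale`, which is what `np.asarray` materializes."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "9c41d2ab",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "# The represented value is the elementwise product of the two fields.\n",
+   "# (Sketch for illustration, not part of the library API.)\n",
+   "np.testing.assert_allclose(np.asarray(sa), sa.data * sa.scale)"
+  ]
+ },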
 {
  "cell_type": "code",
- "execution_count": 19,
- "id": "f374e654-97e4-43ef-902a-a890d36a52b9",
+ "execution_count": 73,
+ "id": "c49c5c55",
  "metadata": {},
- "outputs": [],
- "source": [
-  "# `scalify` interpreter is tracing the graph, adding scale propagation where necessary.\n",
-  "@jsa.scalify\n",
-  "def fn(a, b):\n",
-  "    return a + b"
- ]
-},
+ "outputs": [
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "<< Scaled Arrays with large values >>\n",
+    "Normal `a` FP32: [131072. 262144.]\n",
+    "Normal `a` FP16: [inf inf]\n",
+    "Scaled `a` FP16: ScaledArray(data=array([1., 2.], dtype=float16), scale=131072.0) ~ [131072. 262144.]\n",
+    "\n",
+    "<< Scaled Arrays with small values >>\n",
+    "Normal `a` FP32: [2.9802322e-08 5.9604645e-08]\n",
+    "Normal `a` FP16: [0.e+00 6.e-08]\n",
+    "Scaled `a` FP16: ScaledArray(data=array([0.001953, 0.003906], dtype=float16), scale=1.5258789e-05) ~ [2.9802322e-08 5.9604645e-08]\n"
+   ]
+  },
+  {
+   "name": "stderr",
+   "output_type": "stream",
+   "text": [
+    "/tmp/ipykernel_392367/4076835521.py:5: RuntimeWarning: overflow encountered in cast\n",
+    "  a_fp16 = a.astype(np.float16)\n"
+   ]
+  }
+ ],
+ "source": [
+  "# Why use scaled arrays? => to represent very \"small\" or \"large\" tensors.\n",
+  "# Large FP32 tensor.\n",
+  "a = np.array([2, 4], np.float32) * 256**2\n",
+  "# Overflows to Inf in FP16.\n",
+  "a_fp16 = a.astype(np.float16)\n",
+  "# \"Properly\" represented with a large scale.\n",
+  "sa_fp16 = jsa.as_scaled_array(a, scale=np.float32(256**2 * 2)).astype(np.float16)\n",
+  "\n",
+  "print(\"<< Scaled Arrays with large values >>\")\n",
+  "print(\"Normal `a` FP32:\", a)\n",
+  "print(\"Normal `a` FP16:\", a_fp16)\n",
+  "# The FP16 scaled representation does not overflow.\n",
+  "print(\"Scaled `a` FP16:\", sa_fp16, \" ~ \", np.asarray(sa_fp16, dtype=np.float32))\n",
+  "\n",
+  "a = np.array([2, 4], np.float32) * (256*32)**-2\n",
+  "a_fp16 = a.astype(np.float16)\n",
+  "sa_fp16 = jsa.as_scaled_array(a, scale=np.float32(256**-2)).astype(np.float16)\n",
+  "\n",
+  "print(\"\\n<< Scaled Arrays with small values >>\")\n",
+  "print(\"Normal `a` FP32:\", a)\n",
+  "# Underflow + loss of precision in the subnormal representation.\n",
+  "print(\"Normal `a` FP16:\", a_fp16)\n",
+  "# The FP16 scaled representation does not underflow.\n",
+  "# NOTE: the scale factor does not need to be \"perfect\" to keep an accurate representation.\n",
+  "print(\"Scaled `a` FP16:\", sa_fp16, \" ~ \", np.asarray(sa_fp16, dtype=np.float32))"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a018d505",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+},
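+{
+ "cell_type": "markdown",
+ "id": "b5d90e12",
+ "metadata": {},
+ "source": [
+  "To quantify the benefit, the sketch below (added for illustration, reusing the small-values example above) measures the relative error of the scaled FP16 representation against the FP32 original."
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c6e1f203",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# Relative error of the scaled FP16 representation vs. the FP32 original.\n",
+  "a = np.array([2, 4], np.float32) * (256 * 32) ** -2\n",
+  "sa_fp16 = jsa.as_scaled_array(a, scale=np.float32(256 ** -2)).astype(np.float16)\n",
+  "rel_err = np.abs(np.asarray(sa_fp16, dtype=np.float32) - a) / np.abs(a)\n",
+  "print(\"Max relative error:\", rel_err.max())"
+ ]
+},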
+{
+ "cell_type": "markdown",
+ "id": "e91afff9",
+ "metadata": {},
+ "source": [
+  "### Scaled array and FP8 formats\n",
+  "\n",
+  "How does it work with FP8? Exactly the same way. :)\n",
+  "\n",
+  "Latest-generation GPUs support two FP8 formats, defined by the OCP FP8 specification (https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1):\n",
+  "* `float8_e4m3fn`: 4 exponent bits, 3 mantissa bits, no infinity;\n",
+  "* `float8_e5m2`: 5 exponent bits, 2 mantissa bits, with infinity (JAX also exposes the finite-only variant `float8_e5m2fnuz`, used below).\n",
+  "\n",
+  "**Note:** there is still ongoing IEEE standardization work on FP8 formats (see https://github.com/P3109/Public/blob/main/Shared%20Reports/P3109%20WG%20Interim%20Report.pdf)."
+ ]
+},
 {
  "cell_type": "code",
- "execution_count": 3,
- "id": "8c59245d-27e5-41a7-bfef-f40849a7b550",
+ "execution_count": 74,
+ "id": "aa737550",
  "metadata": {},
  "outputs": [
-  {
-   "name": "stderr",
-   "output_type": "stream",
-   "text": [
-    "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
-   ]
-  },
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "INPUTS: [1. 2.] [3. 6.]\n",
-    "OUTPUT: [4. 8.] \n"
+    "FP8-E4M3: Machine parameters for float8_e4m3fn\n",
+    "---------------------------------------------------------------\n",
+    "precision =   1   resolution = 1.00e-01\n",
+    "machep =     -3   eps =        1.25e-01\n",
+    "negep =      -4   epsneg =     6.25e-02\n",
+    "minexp =     -6   tiny =       1.56e-02\n",
+    "maxexp =      9   max =        4.48e+02\n",
+    "nexp =        4   min =        -max\n",
+    "smallest_normal = 1.56e-02   smallest_subnormal = 1.95e-03\n",
+    "---------------------------------------------------------------\n",
+    "\n",
+    "FP8-E5M2: Machine parameters for float8_e5m2fnuz\n",
+    "---------------------------------------------------------------\n",
+    "precision =   1   resolution = 1.00e-01\n",
+    "machep =     -2   eps =        2.50e-01\n",
+    "negep =      -3   epsneg =     1.25e-01\n",
+    "minexp =    -15   tiny =       3.05e-05\n",
+    "maxexp =     16   max =        5.73e+04\n",
+    "nexp =        5   min =        -max\n",
+    "smallest_normal = 3.05e-05   smallest_subnormal = 7.63e-06\n",
+    "---------------------------------------------------------------\n",
+    "\n"
    ]
   }
  ],
  "source": [
-  "# Let's start with standard JAX inputs\n",
-  "a = np.array([1, 2], np.float16)\n",
-  "b = np.array([3, 6], np.float16)\n",
-  "out = fn(a, b)\n",
-  "\n",
-  "print(\"INPUTS:\", a, b)\n",
-  "# No scaled arithmetics => \"normal\" JAX mode.\n",
-  "print(\"OUTPUT:\", out, type(out))"
+  "# FP8 format properties.\n",
+  "print(\"FP8-E4M3:\", jnp.finfo(jnp.float8_e4m3fn))\n",
+  "print(\"FP8-E5M2:\", jnp.finfo(jnp.float8_e5m2fnuz))"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "id": "e60cf138-d92d-4ab9-89d4-bacc9e28c39f",
  "metadata": {},
  "outputs": [],
  "source": []
 },
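+{
+ "cell_type": "markdown",
+ "id": "d7f30a44",
+ "metadata": {},
+ "source": [
+  "The trade-off between the two formats is their dynamic range vs. precision. A rough comparison (an illustrative sketch using only the `jnp.finfo` attributes printed above):"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f8a21b56",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# Compare the dynamic range of the two FP8 formats from their finfo.\n",
+  "for fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2fnuz):\n",
+  "    info = jnp.finfo(fp8_dtype)\n",
+  "    print(f\"{info.dtype}: max={float(info.max):.2e}, smallest normal={float(info.tiny):.2e}\")"
+ ]
+},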
+{
+ "cell_type": "code",
+ "execution_count": 75,
+ "id": "70e85309",
+ "metadata": {},
+ "outputs": [
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "Normal `a` FP32: [400. 448. 512.]\n",
+    "Normal `a` FP8-E4M3: [384 448 nan]\n",
+    "Scaled `a` FP8-E4M3: ScaledArray(data=Array([3, 3.5, 4], dtype=float8_e4m3fn), scale=128.0) ~ [384. 448. 512.]\n"
+   ]
+  }
+ ],
+ "source": [
+  "a = jnp.array([400, 448, 512], np.float32)\n",
+  "# Overflows to NaN, as E4M3 has no Inf.\n",
+  "a_fp8_e4m3 = a.astype(jnp.float8_e4m3fn)\n",
+  "# Scaled representation, without overflow.\n",
+  "as_fp8_e4m3 = jsa.as_scaled_array(a, scale=np.float32(128)).astype(jnp.float8_e4m3fn)\n",
+  "\n",
+  "print(\"Normal `a` FP32:\", a)\n",
+  "# NOTE: some precision is lost, due to the 3-bit mantissa.\n",
+  "print(\"Normal `a` FP8-E4M3:\", a_fp8_e4m3)\n",
+  "print(\"Scaled `a` FP8-E4M3:\", as_fp8_e4m3, \" ~ \", np.asarray(as_fp8_e4m3, dtype=np.float32))"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ab192562",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+},
+{
+ "cell_type": "markdown",
+ "id": "8442121f",
+ "metadata": {},
+ "source": [
+  "### Scalify: end-to-end scale propagation\n",
+  "\n",
+  "The `scalify` transform performs end-to-end scale propagation, applying \"unit-scaling\"-style rules."
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": 77,
+ "id": "f374e654-97e4-43ef-902a-a890d36a52b9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# The `scalify` transform traces the graph, adding scale propagation where necessary.\n",
+  "@jsa.scalify\n",
+  "def fn(a, b):\n",
+  "    return a + b"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": 96,
+ "id": "8c59245d-27e5-41a7-bfef-f40849a7b550",
+ "metadata": {},
+ "outputs": [
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "INPUTS: [1. 2.] [3. 6.]\n",
+    "OUTPUT: [4. 8.] float16 <class 'jaxlib.xla_extension.ArrayImpl'>\n"
+   ]
+  }
+ ],
+ "source": [
+  "# Let's start with standard JAX inputs.\n",
+  "a = np.array([1, 2], np.float16)\n",
+  "b = np.array([3, 6], np.float16)\n",
+  "# With unscaled inputs, the function `fn` is unchanged.\n",
+  "out = fn(a, b)\n",
+  "\n",
+  "print(\"INPUTS:\", a, b)\n",
+  "# \"Unscaled\" inputs => \"normal\" JAX mode, with unscaled outputs.\n",
+  "print(\"OUTPUT:\", out, out.dtype, type(out))"
+ ]
+},
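+{
+ "cell_type": "markdown",
+ "id": "a93c4e77",
+ "metadata": {},
+ "source": [
+  "This pass-through behaviour is not specific to `fn`: any scalified function behaves as plain JAX on unscaled inputs. `mul_fn` below is a hypothetical second example, added for illustration."
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b1d52f88",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# A hypothetical second scalified function, for illustration.\n",
+  "@jsa.scalify\n",
+  "def mul_fn(a, b):\n",
+  "    return a * b\n",
+  "\n",
+  "# Unscaled inputs => plain JAX semantics, unscaled output.\n",
+  "print(\"MUL OUTPUT:\", mul_fn(a, b))"
+ ]
+},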
 {
  "cell_type": "code",
- "execution_count": 4,
+ "execution_count": 97,
  "id": "e7efaa2e-00a1-40e8-9bbb-685edc975636",
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "SCALED inputs: ScaledArray(data=array([1., 2.], dtype=float16), scale=1.0) ScaledArray(data=array([1.5, 3. ], dtype=float16), scale=2.0)\n",
-    "UNSCALED inputs: [1. 2.] [3. 6.]\n"
+    "Scaled inputs:\n",
+    "\tScaledArray(data=array([0.5, 1. ], dtype=float16), scale=2.0)\n",
+    "\tScaledArray(data=array([0.75, 1.5 ], dtype=float16), scale=4.0)\n",
+    "Equivalent input arrays: [1. 2.] [3. 6.]\n"
    ]
   }
  ],
  "source": [
   "# Let's create input scaled arrays.\n",
-  "# NOTE: scale dtype does not have to match core data dtype.\n",
-  "sa = jsa.as_scaled_array(a, scale=np.float32(1))\n",
-  "sb = jsa.as_scaled_array(b, scale=np.float32(2))\n",
+  "sa = jsa.as_scaled_array(a, scale=np.float32(2))\n",
+  "sb = jsa.as_scaled_array(b, scale=np.float32(4))\n",
   "\n",
-  "print(\"SCALED inputs:\", sa, sb)\n",
-  "# `as_scaled_array` does not change the value of tensor:\n",
-  "print(\"UNSCALED inputs:\", np.asarray(sa), np.asarray(sb))"
+  "print(f\"Scaled inputs:\\n\\t{sa}\\n\\t{sb}\")\n",
+  "# `as_scaled_array` does not change the semantics: the same underlying tensor is represented.\n",
+  "print(\"Equivalent input arrays:\", np.asarray(sa), np.asarray(sb))"
  ]
 },
 {
  "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 102,
  "id": "1f457243-a0b8-4e4d-b45d-7444d0566b37",
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "SCALED OUTPUT: ScaledArray(data=DeviceArray([2., 4.], dtype=float16), scale=DeviceArray(2., dtype=float32))\n",
-    "No scale rounding: ScaledArray(data=DeviceArray([1.789, 3.578], dtype=float16), scale=DeviceArray(2.236068, dtype=float32))\n"
+    "Scaled output: ScaledArray(data=Array([1., 2.], dtype=float16), scale=Array(4., dtype=float32))\n",
+    "Equivalent unscaled output: [4. 8.]\n",
+    "\n",
+    "Scaled output without scale rounding: ScaledArray(data=Array([0.8945, 1.789 ], dtype=float16), scale=Array(4.472136, dtype=float32))\n"
    ]
   }
  ],
  "source": [
   "# Running `fn` on scaled arrays triggers the `scalify` graph transformation & scale propagation.\n",
   "sout = fn(sa, sb)\n",
   "# NOTE: by default, scale propagation uses power-of-2 scales.\n",
-  "print(\"SCALED OUTPUT:\", sout)\n",
+  "assert isinstance(sout, jsa.ScaledArray)\n",
+  "print(\"Scaled output:\", sout)\n",
+  "print(\"Equivalent unscaled output:\", np.asarray(sout))\n",
   "\n",
   "# To choose a different scale rounding:\n",
   "with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.NONE):\n",
-  "    print(\"No scale rounding:\", fn(sa, sb))"
+  "    print(\"\\nScaled output without scale rounding:\", fn(sa, sb))"
  ]
 },
@@ -294,6 +470,14 @@
  "metadata": {},
  "outputs": [],
  "source": []
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "b0e6a804",
+  "metadata": {},
+  "outputs": [],
+  "source": []
  }
 ],
 "metadata": {