
Commit

wip
balancap committed Jun 13, 2024
1 parent 41f209e commit 8ce29e2
Showing 2 changed files with 61 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
# JAX Scalify: end-to-end scaled Arithmetics
# JAX Scalify: end-to-end scaled arithmetic

**JAX Scalify** is a thin library implementing numerically stable scaled arithmetic, allowing easy training and inference of
deep neural networks in low precision (BF16, FP16, FP8).
@@ -5,15 +5,23 @@
"id": "7c85dead-5274-487c-91ff-7137fbaca393",
"metadata": {},
"source": [
"# JAX Scaled Arithmetics / AutoScale quickstart\n",
"# JAX Scalify: Quickstart on end-to-end scaled arithmetic\n",
"\n",
"**JAX Scaled Arithmetics** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of\n",
"deep neural networks in low precision (BF16, FP16, FP8) with full scale propagation."
"**JAX Scalify** is a library implementing general scaled arithmetic in JAX, allowing end-to-end scale propagation in computational graphs and easy training/inference of deep neural networks in low precision (mainly FP16 & FP8).\n",
"\n",
"JAX Scalify supports converting any computational graph into a scaled computational graph, i.e. with `ScaledArray` inputs/outputs.\n",
"```python\n",
"@dataclass\n",
"class ScaledArray:\n",
" data: Array\n",
" scale: Array\n",
"```\n",
"It fully decouples scale propagation from the model definition, allowing easy experimentation and debugging with the low-precision formats FP16 and FP8."
]
},
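The `ScaledArray` semantics above can be sketched in plain NumPy. The snippet below is an illustrative mock-up, not the library's actual implementation; the key invariant is that the tensor value represented is always `data * scale`.

```python
from dataclasses import dataclass

import numpy as np

@dataclass
class ScaledArray:
    # Illustrative sketch only (not JAX Scalify's real class):
    # the formal tensor value represented is `data * scale`.
    data: np.ndarray   # low-precision payload, e.g. float16
    scale: np.float32  # scalar scale, kept in higher precision

    def to_array(self) -> np.ndarray:
        # Materialize the represented value in float32.
        return np.asarray(self.data, np.float32) * self.scale

sa = ScaledArray(np.array([1.0, 2.0], np.float16), np.float32(1.0))
print(sa.to_array())  # [1. 2.]
```

Keeping the scale in a wider dtype than the data is what lets the payload stay in FP16/FP8 without overflowing its narrow dynamic range.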
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 16,
"id": "30940677-4296-40fa-b418-351fcfb62098",
"metadata": {},
"outputs": [],
@@ -25,15 +25,58 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"id": "e0e729aa-7a81-4001-8a34-9a00ec822948",
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"`a` : [1. 2.]\n",
"`sa`: ScaledArray(data=array([1., 2.], dtype=float16), scale=1.0) ~ [1. 2.]\n"
]
}
],
"source": [
"# Let's start at the beginning: convert an array to a ScaledArray.\n",
"a = np.array([1, 2], np.float16)\n",
"# Analogue of `np.asarray`, additionally taking the scale value to use.\n",
"sa = jsa.as_scaled_array(a, scale=np.float32(1))\n",
"assert isinstance(sa, jsa.ScaledArray)\n",
"\n",
"# `a` and `sa` represent the same formal tensor.\n",
"print(\"`a` :\", a)\n",
"print(\"`sa`:\", sa, \" ~ \", np.asarray(sa))"
]
},
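The conversion in the cell above can be mimicked with a small stand-in function. `as_scaled` below is hypothetical (not the actual `jsa.as_scaled_array` implementation); it only illustrates the idea: divide the data by the chosen scale so the represented value is unchanged.

```python
import numpy as np

def as_scaled(a: np.ndarray, scale: float):
    # Hypothetical stand-in for `jsa.as_scaled_array` (illustration only):
    # divide the data by the scale so that `data * scale` equals `a`.
    data = (np.asarray(a, np.float32) / scale).astype(a.dtype)
    return data, np.float32(scale)

a = np.array([1.0, 2.0], np.float16)
data, scale = as_scaled(a, 1.0)
# `a` and (data, scale) represent the same formal tensor.
print(data, scale)  # [1. 2.] 1.0
```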
{
"cell_type": "code",
"execution_count": 18,
"id": "5f624725",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"`a` : [1. 2.]\n",
"`sa`: ScaledArray(data=array([2., 4.], dtype=float16), scale=0.5) ~ [1. 2.]\n"
]
}
],
"source": [
"# Scalify preserves the semantics of arrays and computational graphs.\n",
"# Passing a different scale does not change the \"value\" of a tensor.\n",
"sa = jsa.as_scaled_array(a, scale=np.float32(0.5))\n",
"# `a` and `sa` still represent the same formal tensor.\n",
"print(\"`a` :\", a)\n",
"print(\"`sa`:\", sa, \" ~ \", np.asarray(sa))"
]
},
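Beyond conversion, the point of end-to-end scale propagation is that operations combine scales separately from the low-precision data. A sketch of one possible propagation rule for multiplication (illustrative only, not Scalify's actual implementation):

```python
import numpy as np

def scaled_mul(d1, s1, d2, s2):
    # Sketch of a propagation rule for multiply (illustration only):
    # (d1 * s1) * (d2 * s2) == (d1 * d2) * (s1 * s2),
    # so the low-precision multiply runs on the data while the
    # scalar scales combine on the side.
    return d1 * d2, s1 * s2

d1, s1 = np.array([2.0, 4.0], np.float16), np.float32(0.5)  # value [1. 2.]
d2, s2 = np.array([3.0, 3.0], np.float16), np.float32(1.0)  # value [3. 3.]
data, scale = scaled_mul(d1, s1, d2, s2)
print(data * scale)  # [3. 6.]
```

The multiply on `data` stays entirely in the low-precision dtype; only the cheap scalar scale arithmetic happens in higher precision.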
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 19,
"id": "f374e654-97e4-43ef-902a-a890d36a52b9",
"metadata": {},
"outputs": [],
@@ -261,7 +312,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.12"
}
},
"nbformat": 4,
