diff --git a/09a_JAX_patterns.ipynb b/09a_JAX_patterns.ipynb
index 6d2f05a..ade39fb 100644
--- a/09a_JAX_patterns.ipynb
+++ b/09a_JAX_patterns.ipynb
@@ -1,324 +1,399 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# JAX patterns\n",
- "\n",
- "Since Numpyro uses JAX as a backend, it is important to know how to work with JAX efficiently. JAX is a high-performance numerical computing library designed for machine learning research, which leverages XLA (Accelerated Linear Algebra) for optimizing and compiling numerical computations. This combination enables JAX to efficiently execute large-scale mathematical operations on hardware accelerators, boosting performance and scalability."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Array assignments\n",
- "\n",
- "Even thought, `jax.numpy as jnp` is meant to replicate the functionality of `numpy as np`, and `jnp` would behave in many cases indeed indistinguishably from `np`, there is a few differences. Assignments in array is one them: JAX arrays are immutable.\n",
- "\n",
- "In JAX, array indexing adapts NumPy's syntax to fit immutable arrays. Direct assignments like `array[index] = value` are not supported. Instead, JAX uses the `.at` method for updates, allowing modifications in a purely functional style. For example, setting an element is done with \n",
- "\n",
- "`array = array.at[index].set(value)`, \n",
- "\n",
- "and incrementing an element uses \n",
- "\n",
- "`array = array.at[index].add(value)`. \n",
- "\n",
- "This method returns a new array with the desired change, maintaining the original array unchanged, crucial for JAX’s efficiency in automatic differentiation and optimization."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## `vmap`\n",
- "\n",
- "`vmap` in JAX is a vectorizing map that automatically transforms a function to apply it over batched inputs, effectively parallelizing computations across data. This tool helps simplify the process of extending single-data-point functions to operate on batches, improving efficiency and performance without manual adjustments to the code.\n",
- "\n",
- "`vmap` returns a function which is applied acorss the leading axis of an array.\n",
- "\n",
- "\n",
- "Let's say you have a simple function that computes the square of a number. Using `vmap`, you can easily extend this function to square an entire array of numbers in one go. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[[1 1 1 1 1]\n",
- " [4 4 4 4 4]]\n"
- ]
- }
- ],
- "source": [
- "import jax\n",
- "import jax.numpy as jnp\n",
- "from jax import vmap\n",
- "from jax import lax \n",
- "\n",
- "\n",
- "def square(x):\n",
- " return x * x\n",
- "\n",
- "array = jnp.array([jnp.repeat(1,5), jnp.repeat(2,5)])\n",
- "\n",
- "squared_array = vmap(square)(array)\n",
- "\n",
- "print(squared_array)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can use the option `in_axes` to specify along whch axis to apply the function. Not the difference with the `axes` option for the in-built operations in `numpy` and `jax.numpy:`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[3 3 3 3 3]\n",
- "[ 5 10]\n"
- ]
- }
- ],
- "source": [
- "# here `axis=k` says which dimension to collapse \n",
- "print(jnp.sum(array, axis=0))\n",
- "\n",
- "print(jnp.sum(array, axis=1))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[ 5 10]\n",
- "[3 3 3 3 3]\n"
- ]
- }
- ],
- "source": [
- "def sum_array(x):\n",
- " return jnp.sum(x)\n",
- "\n",
- "# default behavior is to sum along the first axis\n",
- "print(jax.vmap(sum_array, in_axes=0)(array))\n",
- "\n",
- "print(jax.vmap(sum_array, in_axes = 1)(array))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Consider a scenario where you have a function that calculates the Euclidean distance between two points, but you want to apply this function across pairs of points stored in two separate arrays. This function involves more complex operations, including subtraction and squaring, which makes `vmap` particularly beneficial for vectorizing such computations efficiently over batches."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[5.8309517 1.4142135 5.8309517]\n"
- ]
- }
- ],
- "source": [
- "def euclidean_distance(x, y):\n",
- " return jnp.sqrt(jnp.sum((x - y) ** 2))\n",
- "\n",
- "points_1 = jnp.array([[1, 2], [3, 4], [5, 6]])\n",
- "points_2 = jnp.array([[6, 5], [4, 3], [2, 1]])\n",
- "\n",
- "distances = vmap(euclidean_distance, in_axes=(0, 0))(points_1, points_2)\n",
- "\n",
- "print(distances)"
- ]
- },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# JAX patterns\n",
+ "\n",
+ "Since Numpyro uses JAX as a backend, it is important to know how to work with JAX efficiently. JAX is a high-performance numerical computing library designed for machine learning research, which leverages XLA (Accelerated Linear Algebra) for optimizing and compiling numerical computations. This combination enables JAX to efficiently execute large-scale mathematical operations on hardware accelerators, boosting performance and scalability."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Array assignments\n",
+ "\n",
+ "Even thought, `jax.numpy as jnp` is meant to replicate the functionality of `numpy as np`, and `jnp` would behave in many cases indeed indistinguishably from `np`, there is a few differences. The reason for that is that `jax.numpy` provides nearly the same API as `numpy`, it uses a different (JAX) backend. \n",
+ "\n",
+ "JAX arrays are immutable. I.e. in JAX, array indexing adapts NumPy's syntax to fit immutable arrays. Direct assignments like `array[index] = value` are not supported. Instead, JAX uses the `.at` method for updates, allowing modifications in a purely functional style. For example, setting value of an element is done with \n",
+ "\n",
+ "`array = array.at[index].set(value)`, \n",
+ "\n",
+ "and incrementing an element uses \n",
+ "\n",
+ "`array = array.at[index].add(value)`. \n",
+ "\n",
+ "This method returns a new array with the desired change, maintaining the original array unchanged, crucial for JAX’s efficiency in automatic differentiation and optimisation."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## `vmap`\n",
+ "\n",
+ "`vmap` in JAX is a vectorizing map that automatically transforms a function to apply it over batched inputs, effectively parallelizing computations across data. This tool helps simplify the process of extending single-data-point functions to operate on batches, improving efficiency and performance without manual adjustments to the code.\n",
+ "\n",
+ "`vmap` returns a function which is applied acorss the leading axis of an array.\n",
+ "\n",
+ "\n",
+ "Let's say you have a simple function that computes the square of a number. Using `vmap`, you can easily extend this function to square an entire array of numbers in one go. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The `in_axes=(0, 0)` argument tells vmap to apply the function across the first dimension (rows) of both inputs.\n",
- "\n",
- "Let's look at nother example of a multidimensional input."
- ]
- },
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[1 1 1 1 1]\n",
+ " [4 4 4 4 4]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "from jax import vmap\n",
+ "from jax import lax \n",
+ "\n",
+ "\n",
+ "def square(x):\n",
+ " return x * x\n",
+ "\n",
+ "array = jnp.array([jnp.repeat(1,5), jnp.repeat(2,5)])\n",
+ "\n",
+ "squared_array = vmap(square)(array)\n",
+ "\n",
+ "print(squared_array)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can use the option `in_axes` to specify along whch axis to apply the function. Not the difference with the `axes` option for the in-built operations in `numpy` and `jax.numpy:`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[[10. 15. 20.]\n",
- " [14. 21. 28.]\n",
- " [18. 27. 36.]]\n"
- ]
- }
- ],
- "source": [
- "def add_and_multiply_scalar(x, y, scalar):\n",
- " return (x + y) * scalar\n",
- "\n",
- "def add_and_multiply_vector(x, y, scalar):\n",
- " return vmap(add_and_multiply_scalar, (0, 0, None))(x, y, scalar)\n",
- "\n",
- "x = jnp.array([1.0, 2.0, 3.0])\n",
- "y = jnp.array([4.0, 5.0, 6.0])\n",
- "scalar = jnp.array([2.0, 3.0, 4.0])\n",
- "\n",
- "result = add_and_multiply_vector(x, y, scalar)\n",
- "print(result)"
- ]
- },
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[3 3 3 3 3]\n",
+ "[ 5 10]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# here `axis=k` says which dimension to collapse \n",
+ "print(jnp.sum(array, axis=0))\n",
+ "\n",
+ "print(jnp.sum(array, axis=1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- " Here `(0, 0, None)` tells `vmap` to apply the function `add_and_multiply_scalar` element-wise across the first dimension of `x` and `y`, while keeping the `scalar` fixed for each corresponding element-wise operation.\n",
- "\n",
- " `vmap` applies a function to the elements of th leading axis of an array indendently. What if there is some carry over quantities that we need to use for the next copmutation based on last? Then we use `lax.scan`."
- ]
- },
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[ 5 10]\n",
+ "[3 3 3 3 3]\n"
+ ]
+ }
+ ],
+ "source": [
+ "def sum_array(x):\n",
+ " return jnp.sum(x)\n",
+ "\n",
+ "# default behavior is to sum along the first axis\n",
+ "print(jax.vmap(sum_array, in_axes=0)(array))\n",
+ "\n",
+ "print(jax.vmap(sum_array, in_axes = 1)(array))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Consider a scenario where you have a function that calculates the Euclidean distance between two points, but you want to apply this function across pairs of points stored in two separate arrays. This function involves more complex operations, including subtraction and squaring, which makes `vmap` particularly beneficial for vectorising such computations efficiently over batches."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## `lax.scan`\n",
- "\n",
- "[`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) is a function used for looping over sequences while carrying state across iterations. Its inputs are as follows:\n",
- "\n",
- "- an input sequences that you want to iterate over: a list, tuple or an array,\n",
- "\n",
- "- an initial state which is carried through each iteration of the loop,\n",
- "\n",
- "- a function that takes the current state and an element from the input sequence and returns a tuple: the new state for the next iteration and the output for the current iteration,\n",
- "\n",
- "`lax.scan` then iterates over the sequence, applying the function at each step, and updating the state accordingly. It collects the outputs from each iteration into a sequence. It returns both the sequence of outputs and the final state after iterating over all elements in the input sequences."
- ]
- },
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[5.8309517 1.4142135 5.8309517]\n"
+ ]
+ }
+ ],
+ "source": [
+ "def euclidean_distance(x, y):\n",
+ " return jnp.sqrt(jnp.sum((x - y) ** 2))\n",
+ "\n",
+ "points_1 = jnp.array([[1, 2], [3, 4], [5, 6]])\n",
+ "points_2 = jnp.array([[6, 5], [4, 3], [2, 1]])\n",
+ "\n",
+ "distances = vmap(euclidean_distance, in_axes=(0, 0))(points_1, points_2)\n",
+ "\n",
+ "print(distances)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `in_axes=(0, 0)` argument tells `vmap` to apply the function across the first dimension (rows) of both inputs.\n",
+ "\n",
+ "Let's look at nother example of a multidimensional input."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Example: cumulative sum"
- ]
- },
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[10. 15. 20.]\n",
+ " [14. 21. 28.]\n",
+ " [18. 27. 36.]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "def add_and_multiply_scalar(x, y, scalar):\n",
+ " return (x + y) * scalar\n",
+ "\n",
+ "def add_and_multiply_vector(x, y, scalar):\n",
+ " return vmap(add_and_multiply_scalar, (0, 0, None))(x, y, scalar)\n",
+ "\n",
+ "x = jnp.array([1.0, 2.0, 3.0])\n",
+ "y = jnp.array([4.0, 5.0, 6.0])\n",
+ "scalar = jnp.array([2.0, 3.0, 4.0])\n",
+ "\n",
+ "result = add_and_multiply_vector(x, y, scalar)\n",
+ "print(result)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Here `(0, 0, None)` tells `vmap` to apply the function `add_and_multiply_scalar` element-wise across the first dimension of `x` and `y`, while keeping the `scalar` fixed for each corresponding element-wise operation.\n",
+ "\n",
+ " `vmap` applies a function to the elements of th leading axis of an array indendently. What if there is some carry over quantities that we need to use for the next copmutation based on last? Then we use `lax.scan`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## `lax.scan`\n",
+ "\n",
+ "[`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) is a function used for looping over sequences while carrying state across iterations. Its inputs are as follows:\n",
+ "\n",
+ "- an input sequences that you want to iterate over: a list, tuple or an array,\n",
+ "\n",
+ "- an initial state which is carried through each iteration of the loop,\n",
+ "\n",
+ "- a function that takes the current state and an element from the input sequence and returns a tuple: the new state for the next iteration and the output for the current iteration,\n",
+ "\n",
+ "`lax.scan` then iterates over the sequence, applying the function at each step, and updating the state accordingly. It collects the outputs from each iteration into a sequence. It returns both the sequence of outputs and the final state after iterating over all elements in the input sequences."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Example: cumulative sum"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Cumulative Sum: 15\n"
- ]
- }
- ],
- "source": [
- "def cumsum_fn(prev_sum, next_value):\n",
- " return prev_sum + next_value, prev_sum + next_value\n",
- "\n",
- "sequence = jnp.array([1, 2, 3, 4, 5])\n",
- "\n",
- "initial_sum = jnp.array(0)\n",
- "\n",
- "cumulative_sum, _ = lax.scan(cumsum_fn, initial_sum, sequence)\n",
- "\n",
- "print(\"Cumulative Sum:\", cumulative_sum)\n"
- ]
- },
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Cumulative Sum: 15\n"
+ ]
+ }
+ ],
+ "source": [
+ "def cumsum_fn(prev_sum, next_value):\n",
+ " return prev_sum + next_value, prev_sum + next_value\n",
+ "\n",
+ "sequence = jnp.array([1, 2, 3, 4, 5])\n",
+ "\n",
+ "initial_sum = jnp.array(0)\n",
+ "\n",
+ "cumulative_sum, _ = lax.scan(cumsum_fn, initial_sum, sequence)\n",
+ "\n",
+ "print(\"Cumulative Sum:\", cumulative_sum)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Example: moving average"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Example: moving average"
- ]
- },
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Moving averages: [ 3.3333335 6.222223 8.8148155 11.209877 13.473251 ]\n"
+ ]
+ }
+ ],
+ "source": [
+ "def moving_average(state, next_value):\n",
+ " prev_avg, _ = state \n",
+ " new_avg = (prev_avg * 2 + next_value) / 3\n",
+ " return (new_avg, next_value), new_avg\n",
+ "\n",
+ "initial_state = (0.0, 0.0)\n",
+ "\n",
+ "sensor_data = jnp.array([10.0, 12.0, 14.0, 16.0, 18.0])\n",
+ "\n",
+ "final_state, moving_averages = lax.scan(moving_average, initial_state, sensor_data)\n",
+ "\n",
+ "print(\"Moving averages:\", moving_averages)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Random keys with `jax.random.PRNGKey(seed)`\n",
+ "\n",
+ "We have already used at a few instances JAX's [pseudorandom numbers](https://jax.readthedocs.io/en/latest/random-numbers.html) generator. `jax.random.PRNGKey(seed)` in JAX creates a random number generator key using a given seed (an integer). This key is used for reproducible and functional-style random number generation. JAX’s random system is stateless, so you need to manage and pass this key explicitly to any function that uses randomness. You can use this key to generate random numbers or split it into sub-keys:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Moving averages: [ 3.3333335 6.222223 8.8148155 11.209877 13.473251 ]\n"
- ]
- }
- ],
- "source": [
- "def moving_average(state, next_value):\n",
- " prev_avg, _ = state \n",
- " new_avg = (prev_avg * 2 + next_value) / 3\n",
- " return (new_avg, next_value), new_avg\n",
- "\n",
- "initial_state = (0.0, 0.0)\n",
- "\n",
- "sensor_data = jnp.array([10.0, 12.0, 14.0, 16.0, 18.0])\n",
- "\n",
- "final_state, moving_averages = lax.scan(moving_average, initial_state, sensor_data)\n",
- "\n",
- "print(\"Moving averages:\", moving_averages)\n"
- ]
- },
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[ 0 42]\n"
+ ]
+ }
+ ],
+ "source": [
+ "key = jax.random.PRNGKey(42)\n",
+ "print(key)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[2465931498 3679230171] [255383827 267815257]\n"
+ ]
}
- ],
- "metadata": {
- "colab": {
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.18"
+ ],
+ "source": [
+ "subkey1, subkey2 = jax.random.split(key)\n",
+ "print(subkey1, subkey2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Splitting into multiple keys can be done using the `num` argument:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[2765691542 1385194879]\n",
+ " [ 831049250 3807460095]\n",
+ " [3616728933 824333390]\n",
+ " [1482326074 1083977345]\n",
+ " [2713995981 2812206153]]\n"
+ ]
}
+ ],
+ "source": [
+ "keys = jax.random.split(key, num=5) \n",
+ "print(keys)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "aims",
+ "language": "python",
+ "name": "python3"
},
- "nbformat": 4,
- "nbformat_minor": 0
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.19"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}