diff --git a/src/qrisp/examples/jasp_demo.ipynb b/src/qrisp/examples/jasp_demo.ipynb new file mode 100644 index 00000000..00818003 --- /dev/null +++ b/src/qrisp/examples/jasp_demo.ipynb @@ -0,0 +1,723 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "c2001150-1fef-4784-b47e-4ab75a18d68d", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install git+https://github.com/eclipse-qrisp/Qrisp.git@catalyst_integration" + ] + }, + { + "cell_type": "markdown", + "id": "9918bf03-b930-4c7f-adec-a9bafcc102de", + "metadata": {}, + "source": [ + "Jasp, a dynamic Pythonic low-level IR\n", + "-------------------------------------\n", + "\n", + "Within this notebook, we demonstrate the latest feature of the Jax integration.\n", + "\n", + "We introduce Jasp, a new IR that represents hybrid programs embedded into the Jaxpr IR.\n", + "\n", + "Creating a Jasp program is simple:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c770249c-137c-4b45-9263-7f77726571bc", + "metadata": {}, + "outputs": [], + "source": [ + "from qrisp import *\n", + "from qrisp.jasp import *\n", + "from jax import make_jaxpr\n", + "\n", + "\n", + "def main(i):\n", + " qf = QuantumFloat(i)\n", + " h(qf[0])\n", + " cx(qf[0], qf[1])\n", + "\n", + " meas_float = measure(qf)\n", + "\n", + " return meas_float\n", + " \n", + "\n", + "jaspr = make_jaspr(main)(5)\n", + "\n", + "print(jaspr)" + ] + }, + { + "cell_type": "markdown", + "id": "644dbc01-036b-4d33-8ac3-c85a41b0e252", + "metadata": {}, + "source": [ + "Jasp programs can be executed with the Jasp interpreter by calling them like a function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b02aeffb-7e93-4ec0-bca6-7e226fadcfd5", + "metadata": {}, + "outputs": [], + "source": [ + "print(jaspr(5))" + ] + }, + { + "cell_type": "markdown", + "id": "f67a2eaf-0f1a-4d0d-98e5-370cd181c84e", + "metadata": {}, + "source": [ + "A quicker way to do this is to use the ``jaspify`` decorator. This decorator automatically transforms the function into a Jaspr and calls the simulator." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27a9feb7-d81a-41fe-9090-fbb23078fd9e", + "metadata": {}, + "outputs": [], + "source": [ + "@jaspify\n", + "def main(i):\n", + " qf = QuantumFloat(i)\n", + " h(qf[0])\n", + " cx(qf[0], qf[1])\n", + "\n", + " meas_float = measure(qf)\n", + "\n", + " return meas_float\n", + "\n", + "print(main(5))" + ] + }, + { + "cell_type": "markdown", + "id": "987eed43-b8ea-4fdd-a63d-65872016e285", + "metadata": {}, + "source": [ + "Jasp programs can be compiled to QIR, which is one of the most popular low-level representations for quantum computers. For this you need Catalyst installed (available only on macOS and Linux)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0782f1ef-d361-4bc4-887e-699cdd3309f8", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " import catalyst\n", + "except ImportError:\n", + " !pip install pennylane-catalyst" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b491510a-a302-4bde-b253-e6012ff2a775", + "metadata": {}, + "outputs": [], + "source": [ + "qir_string = jaspr.to_qir()\n", + "print(qir_string[:2500])" + ] + }, + { + "cell_type": "markdown", + "id": "2223d289-965f-4f91-a508-d07332433978", + "metadata": {}, + "source": [ + "The Qache decorator\n", + "-------------------\n", + "\n", + "One of the most powerful features of this IR is that it is fully dynamic, allowing many functions to be cached and reused.
For this we have the ``qache`` decorator. Qached functions are only executed once (per calling signature) and otherwise retrieved from the cache." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0a49866-6d82-4906-941e-36b646102079", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "@qache\n", + "def inner_function(qv, i):\n", + " cx(qv[0], qv[1])\n", + " h(qv[i])\n", + " # Complicated compilation that takes a lot of time\n", + " time.sleep(1)\n", + "\n", + "def main(i):\n", + " qv = QuantumFloat(i)\n", + "\n", + " inner_function(qv, 0)\n", + " inner_function(qv, 1)\n", + " inner_function(qv, 2)\n", + "\n", + " return measure(qv)\n", + "\n", + "\n", + "t0 = time.time()\n", + "jaspr = make_jaspr(main)(5)\n", + "print(time.time() - t0)" + ] + }, + { + "cell_type": "markdown", + "id": "a62a49f1-d431-4d17-b6b2-bf17716f60b5", + "metadata": {}, + "source": [ + "If a cached function is called with a different type (classical or quantum), the function will not be retrieved from the cache but retraced instead. If called with the same signature, the appropriate implementation will be retrieved from the cache." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ed87e8c-61af-4cc6-8369-d05eeca5ac4c", + "metadata": {}, + "outputs": [], + "source": [ + "@qache\n", + "def inner_function(qv):\n", + " x(qv)\n", + " time.sleep(1)\n", + "\n", + "def main():\n", + " qf = QuantumFloat(5)\n", + " qbl = QuantumBool()\n", + "\n", + " inner_function(qf)\n", + " inner_function(qf)\n", + " inner_function(qbl)\n", + " inner_function(qbl)\n", + "\n", + " return measure(qf)\n", + "\n", + "t0 = time.time()\n", + "jaspr = make_jaspr(main)()\n", + "print(time.time() - t0)" + ] + }, + { + "cell_type": "markdown", + "id": "82fcdc22-737e-475a-bc68-1837e1663a5d", + "metadata": {}, + "source": [ + "We see 2 seconds now because the ``inner_function`` has been traced twice: once for the ``QuantumFloat`` and once for the ``QuantumBool``.\n", + "\n", + "Another important concept is dynamic values. Dynamic values are values that are only known at runtime (i.e. when the program is actually executed), for instance because the value comes from a quantum measurement. Every QuantumVariable and its ``.size`` attribute are dynamic. Furthermore, classical values can also be dynamic. For classical values, we can use a Python-native ``isinstance`` check against the ``jax.core.Tracer`` class to determine whether a variable is dynamic. Note that even though ``QuantumVariables`` behave dynamically, they are not tracers themselves."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51997b89-b3fc-428c-ba95-d7e7a63014f5", + "metadata": {}, + "outputs": [], + "source": [ + "from jax.core import Tracer\n", + "\n", + "def main(i):\n", + " print(\"i is dynamic?: \", isinstance(i, Tracer))\n", + " \n", + " qf = QuantumFloat(5)\n", + " j = qf.size\n", + " print(\"j is dynamic?: \", isinstance(j, Tracer))\n", + " \n", + " h(qf)\n", + " k = measure(qf)\n", + " print(\"k is dynamic?: \", isinstance(k, Tracer))\n", + "\n", + " # Regular Python integers are not dynamic\n", + " l = 5\n", + " print(\"l is dynamic?: \", isinstance(l, Tracer))\n", + "\n", + " # Arbitrary Python objects can be used within Jasp\n", + " # but they are not dynamic\n", + " import networkx as nx\n", + " G = nx.DiGraph()\n", + " G.add_edge(1,2)\n", + " print(\"G is dynamic?: \", isinstance(G, Tracer))\n", + " \n", + " return k\n", + "\n", + "jaspr = make_jaspr(main)(5)\n" + ] + }, + { + "cell_type": "markdown", + "id": "34f80144-9f8e-4412-82dd-b6049fac1993", + "metadata": {}, + "source": [ + "What is the advantage of dynamic values? Dynamic code is scale-invariant! For this we can use the ``jrange`` iterator, which allows you to execute a dynamic number of loop iterations. Some restrictions apply however (check the docs to see which)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce8e8278-b779-4ee7-b042-b29fe196f304", + "metadata": {}, + "outputs": [], + "source": [ + "@jaspify\n", + "def main(k):\n", + "\n", + " a = QuantumFloat(k)\n", + " b = QuantumFloat(k)\n", + "\n", + " # Brings a into uniform superposition via Hadamard\n", + " h(a)\n", + "\n", + " c = measure(a)\n", + "\n", + " # Executes c iterations (i.e. depending on the measurement outcome)\n", + " for i in jrange(c):\n", + "\n", + " # Performs a quantum incrementation on b based on the measurement outcome\n", + " b += c//5\n", + "\n", + " return measure(b)\n", + "\n", + "print(main(5))" + ] + }, + { + "cell_type": "markdown", + "id": "801b7f65-1445-4820-90c1-91669a76ebf1", + "metadata": {}, + "source": [ + "It is possible to execute a multi-controlled X gate with a dynamic number of controls." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2516cbf8-e5e8-4a38-825e-2ee64f73ea55", + "metadata": {}, + "outputs": [], + "source": [ + "@jaspify\n", + "def main(i, j, k):\n", + "\n", + " a = QuantumFloat(5)\n", + " a[:] = i\n", + " \n", + " qbl = QuantumBool()\n", + "\n", + " # a[:j] is a dynamic number of controls\n", + " mcx(a[:j], qbl[0], ctrl_state = k)\n", + "\n", + " return measure(qbl)" + ] + }, + { + "cell_type": "markdown", + "id": "cf7c6987-9aaa-496d-a2ac-6836685d9a38", + "metadata": {}, + "source": [ + "This function encodes the integer ``i`` into a ``QuantumFloat`` and subsequently performs an MCX gate controlled on the first ``j`` qubits with control state ``k``. Therefore, we expect the function to return ``True`` if ``i == k`` and ``j > 5``." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0d41b31-fa4e-4299-b850-ac3303a3aa01", + "metadata": {}, + "outputs": [], + "source": [ + "print(main(1, 6, 1))\n", + "print(main(3, 6, 1))\n", + "print(main(2, 1, 1))" + ] + }, + { + "cell_type": "markdown", + "id": "19fbf3f3-1979-4d58-9e9f-fe36669bb21e", + "metadata": {}, + "source": [ + "Classical control flow\n", + "----------------------\n", + "\n", + "Jasp code can be conditioned on classically known values. For that we simply use the ``control`` feature from base-Qrisp but with dynamic, classical bools.
Some restrictions apply (check the docs for more details)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c721b7f-50a7-4442-88a5-38ad99976c42", + "metadata": {}, + "outputs": [], + "source": [ + "@jaspify\n", + "def main():\n", + "\n", + " qf = QuantumFloat(3)\n", + " h(qf)\n", + "\n", + " # This is a classical, dynamic int\n", + " meas_res = measure(qf)\n", + "\n", + " # This is a classical, dynamic bool\n", + " ctrl_bl = meas_res >= 4\n", + " \n", + " with control(ctrl_bl):\n", + " qf -= 4\n", + "\n", + " return measure(qf)\n", + "\n", + "for i in range(5):\n", + " print(main())" + ] + }, + { + "cell_type": "markdown", + "id": "a638411b-b6c7-4ed6-b07c-6d03636d90bd", + "metadata": {}, + "source": [ + "The RUS decorator\n", + "-----------------\n", + "\n", + "RUS stands for Repeat-Until-Success and is an essential component of many quantum algorithms such as HHL or LCU. As the name suggests, the RUS component repeats a certain subroutine until a measurement yields ``True``. The RUS decorator should be applied to a ``trial_function``, which returns a classical bool as its first return value, followed by arbitrary other values. The trial function will be repeated until the classical bool is ``True``.\n", + "\n", + "To demonstrate the RUS behavior, we initialize a GHZ state\n", + "\n", + "$\ket{\psi} = \frac{1}{\sqrt{2}} (\ket{00000} + \ket{11111})$\n", + "\n", + "and measure the first qubit into a boolean value, which serves as the cancelation value. The measurement collapses the GHZ state into either $\ket{00000}$ (which causes a new repetition) or $\ket{11111} = \ket{31}$, which cancels the loop. After the repetition is canceled, we are therefore guaranteed to obtain the latter state.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09fd5295-68a9-4de0-aed6-bf17d44ea95a", + "metadata": {}, + "outputs": [], + "source": [ + "from qrisp.jasp import RUS, make_jaspr\n", + "from qrisp import QuantumFloat, h, cx, measure\n", + "\n", + "def init_GHZ(qf):\n", + " h(qf[0])\n", + " for i in jrange(1, qf.size):\n", + " cx(qf[0], qf[i])\n", + "\n", + "@RUS\n", + "def rus_trial_function():\n", + " qf = QuantumFloat(5)\n", + "\n", + " init_GHZ(qf)\n", + " \n", + " cancelation_bool = measure(qf[0])\n", + " \n", + " return cancelation_bool, qf\n", + "\n", + "@jaspify\n", + "def main():\n", + "\n", + " qf = rus_trial_function()\n", + "\n", + " return measure(qf)\n", + "\n", + "print(main())" + ] + }, + { + "cell_type": "markdown", + "id": "236354d8-3d21-4efc-b640-ffaa4eab8a1e", + "metadata": {}, + "source": [ + "Terminal sampling\n", + "-----------------\n", + "\n", + "The ``jaspify`` decorator executes one \"shot\". For many quantum algorithms, however, we need the distribution of shots. In principle we could execute a batch of \"jaspified\" function calls, but this does not scale well. For this situation we have the ``terminal_sampling`` decorator. To use this decorator we need a function that returns a ``QuantumVariable`` (instead of a classical measurement result). The decorator will then perform a (hybrid) simulation of the given script and subsequently sample from the distribution at the end."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "851a0a3f-64ef-453d-8bab-e6e6b7d8ffe8", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "@RUS\n", + "def rus_trial_function():\n", + " qf = QuantumFloat(5)\n", + "\n", + " init_GHZ(qf)\n", + " \n", + " cancelation_bool = measure(qf[0])\n", + " \n", + " return cancelation_bool, qf\n", + "\n", + "@terminal_sampling\n", + "def main():\n", + "\n", + " qf = rus_trial_function()\n", + " h(qf[0])\n", + "\n", + " return qf\n", + "\n", + "print(main())" + ] + }, + { + "cell_type": "markdown", + "id": "25f96fde-2bb4-410e-95c6-28ac6c1173c4", + "metadata": {}, + "source": [ + "The ``terminal_sampling`` decorator requires some care, however. Remember that it only samples from the distribution at the end of the algorithm. This distribution can depend on random events that happened during the execution. We demonstrate faulty use in the following example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad3c678d-b524-4784-9cde-04966897c7b0", + "metadata": {}, + "outputs": [], + "source": [ + "from qrisp import QuantumBool, measure, control\n", + "\n", + "@terminal_sampling\n", + "def main():\n", + "\n", + " qbl = QuantumBool()\n", + " qf = QuantumFloat(4)\n", + "\n", + " # Bring qbl into superposition\n", + " h(qbl)\n", + "\n", + " # Perform a measurement\n", + " cl_bl = measure(qbl)\n", + "\n", + " # Perform a conditional operation based on the measurement outcome\n", + " with control(cl_bl):\n", + " qf[:] = 1\n", + " h(qf[2])\n", + "\n", + " return qf\n", + "\n", + "for i in range(5):\n", + " print(main())\n", + "# Yields either {0.0: 1.0} or {1.0: 0.5, 5.0: 0.5} (with a 50/50 probability)" + ] + }, + { + "cell_type": "markdown", + "id": "54c72f4d-80dc-4131-b5c9-456a004ff2ef", + "metadata": {}, + "source": [ + "Boolean simulation\n", + "------------------\n", + "\n", + "The tight Jax integration of Jasp enables some powerful features such as a highly performant simulator of purely boolean circuits. This simulator works by transforming Jaspr objects that contain only X, CX, MCX etc. into boolean Jax logic. Subsequently, this is inserted into the Jax pipeline, which yields a highly scalable simulator for purely classical Jasp functions.\n", + "\n", + "To call this simulator, we simply use the ``boolean_simulation`` decorator like we did with the ``jaspify`` decorator." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d47dfe1f-6d04-4e1e-9b4a-05faf20da12c", + "metadata": {}, + "outputs": [], + "source": [ + "from qrisp import *\n", + "from qrisp.jasp import *\n", + "\n", + "def quantum_mult(a, b):\n", + " return a*b\n", + "\n", + "@boolean_simulation(bit_array_padding = 2**10)\n", + "def main(i, j, iterations):\n", + "\n", + " a = QuantumFloat(10)\n", + " b = QuantumFloat(10)\n", + "\n", + " a[:] = i\n", + " b[:] = j\n", + "\n", + " c = QuantumFloat(30)\n", + "\n", + " for i in jrange(iterations): \n", + "\n", + " # Compute the quantum product\n", + " temp = quantum_mult(a,b)\n", + "\n", + " # Add the product into c\n", + " c += temp\n", + "\n", + " # Uncompute the quantum product\n", + " with invert():\n", + " # The << operator \"injects\" the quantum variable into\n", + " # the function.
This means that the quantum_mult\n", + " # function, which was originally out-of-place, is\n", + " # now an in-place function operating on temp.\n", + "\n", + " # It can therefore be used for uncomputation.\n", + " # Automatic uncomputation is not yet available within Jasp.\n", + " (temp << quantum_mult)(a, b)\n", + "\n", + " # Delete temp\n", + " temp.delete()\n", + "\n", + " return measure(c)\n" + ] + }, + { + "cell_type": "markdown", + "id": "e059a479-dcdc-482c-8729-9fddc518b72e", + "metadata": {}, + "source": [ + "The first call needs some time for compilation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2793d8d1-73df-4d9b-bfe0-de4ca15953fe", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "t0 = time.time()\n", + "main(1, 2, 5)\n", + "print(time.time()-t0)" + ] + }, + { + "cell_type": "markdown", + "id": "818a9d8d-fcf5-4fdb-af32-6f6d03714de6", + "metadata": {}, + "source": [ + "Any subsequent call is very fast." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e197d7f3-c954-400c-b4f6-98a515106865", + "metadata": {}, + "outputs": [], + "source": [ + "t0 = time.time()\n", + "print(main(3, 4, 120)) # Expected to be 3*4*120 = 1440\n", + "print(f\"Took {time.time()-t0} to simulate 120 iterations\")" + ] + }, + { + "cell_type": "markdown", + "id": "176243ba-c0a7-4a17-bad5-9b52744507e2", + "metadata": {}, + "source": [ + "Compile and simulate A MILLION QFLOPs!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93a7e489-ea30-43ce-a195-5d56b4a666c8", + "metadata": {}, + "outputs": [], + "source": [ + "print(main(532, 233, 1000000))" + ] + }, + { + "cell_type": "markdown", + "id": "f6e97d60-b1a5-49dd-8698-c9c9996b4346", + "metadata": {}, + "source": [ + "Letting a classical neural network decide when to stop\n", + "-------------------------------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "342f1bdc-4b9c-408b-b267-fee3e02b919f", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import grad, jit\n", + "import optax\n", + "\n", + "# Define the model\n", + "def model(params, x):\n", + " W, b = params\n", + " return jax.nn.sigmoid(jnp.dot(x, W) + b)\n", + "\n", + "# Define the loss function (binary cross-entropy)\n", + "def loss_fn(params, x, y):\n", + " # Flatten the (N, 1) model output to match the shape of y\n", + " preds = model(params, x).flatten()\n", + " return -jnp.mean(y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds))\n", + "\n", + "# Initialize parameters\n", + "key = jax.random.PRNGKey(0)\n", + "W = jax.random.normal(key, (2, 1))\n", + "b = jax.random.normal(key, (1,))\n", + "params = (W, b)\n", + "\n", + "# Create optimizer\n", + "optimizer = optax.adam(learning_rate=0.01)\n", + "opt_state = optimizer.init(params)\n", + "\n", + "# Define training step\n", + "@jit\n", + "def train_step(params, opt_state, x, y):\n", + " loss, grads = jax.value_and_grad(loss_fn)(params, x, y)\n", + " updates, opt_state = optimizer.update(grads, opt_state)\n", + " params = optax.apply_updates(params, updates)\n", + " return params, opt_state, loss\n", + "\n", + "# Generate some dummy data\n", + "key = jax.random.PRNGKey(0)\n", + "X = jax.random.normal(key, (1000, 2))\n", + "y = jnp.sum(X > 0, axis=1) % 2\n", + "\n", + "# Training loop\n", + "for epoch in range(100):\n", + " params, opt_state, loss = train_step(params, opt_state, X, y)\n", + " if epoch % 10 == 0:\n", + " print(f\"Epoch {epoch}, Loss: {loss}\")\n", + "\n", + "# Make predictions\n", + "predictions = model(params, X).flatten()\n", + "accuracy = jnp.mean((predictions > 0.5) == y)\n", + "print(f\"Final accuracy: {accuracy}\")\n" + ] + },
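+ { + "cell_type": "markdown", + "id": "7c1de3a2-5b4f-4c8e-9a1d-2f3b4c5d6e7f", + "metadata": {}, + "source": [ + "What follows is only a sketch of how the trained classifier could be plugged into the ``RUS`` decorator, so that the classical network decides when the repetition stops. It assumes that the Jax-based ``model`` can be evaluated on dynamic measurement results inside a trial function; the helper ``nn_trial_function`` is hypothetical and not part of the Qrisp API." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d2ef4b3-6c5a-4d9f-8b2e-3a4c5d6e7f80", + "metadata": {}, + "outputs": [], + "source": [ + "# Sketch (not part of the original demo): the trained classical network\n", + "# decides whether the repetition is canceled. We assume model and params\n", + "# from the previous cell and that the Jax model traces on dynamic\n", + "# measurement results; nn_trial_function is a hypothetical helper.\n", + "\n", + "@RUS\n", + "def nn_trial_function():\n", + " qf = QuantumFloat(2)\n", + " h(qf)\n", + "\n", + " # Feed the two measurement results to the classical network\n", + " features = jnp.array([measure(qf[0]), measure(qf[1])], dtype=jnp.float32)\n", + "\n", + " # The network output decides whether we stop repeating\n", + " cancelation_bool = model(params, features)[0] > 0.5\n", + "\n", + " return cancelation_bool, qf\n", + "\n", + "@jaspify\n", + "def main():\n", + " qf = nn_trial_function()\n", + " return measure(qf)\n", + "\n", + "print(main())" + ] + }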
+ ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/qrisp/jasp/interpreter_tools/abstract_interpreter.py b/src/qrisp/jasp/interpreter_tools/abstract_interpreter.py index fa1d2c14..0f57e5c3 100644 --- a/src/qrisp/jasp/interpreter_tools/abstract_interpreter.py +++ b/src/qrisp/jasp/interpreter_tools/abstract_interpreter.py @@ -67,6 +67,8 @@ def jaxpr_evaluator(*args): temp_var_list = jaxpr.invars + jaxpr.constvars if len(temp_var_list) != len(args): + print(args) + print(temp_var_list) raise Exception("Tried to evaluate jaxpr with insufficient arguments") context_dic = ContextDict({temp_var_list[i] : args[i] for i in range(len(args))}) diff --git a/src/qrisp/jasp/jasp_expression/centerclass.py b/src/qrisp/jasp/jasp_expression/centerclass.py index 36a02338..bcda2bee 100644 --- a/src/qrisp/jasp/jasp_expression/centerclass.py +++ b/src/qrisp/jasp/jasp_expression/centerclass.py @@ -429,9 +429,10 @@ def __call__(self, *args): if len(self.outvars) == 1: return None + from jax.tree_util import tree_flatten from qrisp.simulator import BufferedQuantumState - args = [BufferedQuantumState()] + list(args) - + args = [BufferedQuantumState()] + list(tree_flatten(args)[0]) + from qrisp.jasp import extract_invalues, insert_outvalues, eval_jaxpr flattened_jaspr = self @@ -1120,10 +1121,22 @@ def jitted_function(*args): return jitted_function +from jax.tree_util import tree_flatten, tree_unflatten def jaspify(func): + + treedef_container = [] + def tracing_function(*args): + res = func(*args) + flattened_values, tree_def = tree_flatten(res) + treedef_container.append(tree_def) + return flattened_values + def return_function(*args): - jaspr = make_jaspr(func)(*args) - return jaspr(*args) + jaspr = make_jaspr(tracing_function)(*args) + jaspr_res = jaspr(*args) + if isinstance(jaspr_res, tuple): + jaspr_res = tree_unflatten(treedef_container[0], jaspr_res) + return jaspr_res return return_function def check_aval_equivalence(invars_1, invars_2):