From dcc6a4832040f768d8b82f1c965ccf41a0f49082 Mon Sep 17 00:00:00 2001 From: positr0nium Date: Thu, 9 Jan 2025 12:43:37 +0100 Subject: [PATCH] updated the jasp_demo notebook --- src/qrisp/examples/jasp_demo.ipynb | 326 +++++++++++++++++++++++++--- tests/jax_tests/test_rus - Kopie.py | 42 ++++ 2 files changed, 336 insertions(+), 32 deletions(-) create mode 100644 tests/jax_tests/test_rus - Kopie.py diff --git a/src/qrisp/examples/jasp_demo.ipynb b/src/qrisp/examples/jasp_demo.ipynb index 00818003..b2b383d4 100644 --- a/src/qrisp/examples/jasp_demo.ipynb +++ b/src/qrisp/examples/jasp_demo.ipynb @@ -2,10 +2,28 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "c2001150-1fef-4784-b47e-4ab75a18d68d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting git+https://github.com/eclipse-qrisp/Qrisp.git@catalyst_integration\n", + " Cloning https://github.com/eclipse-qrisp/Qrisp.git (to revision catalyst_integration) to /tmp/pip-req-build-s03df8r3\n", + " Running command git clone --filter=blob:none --quiet https://github.com/eclipse-qrisp/Qrisp.git /tmp/pip-req-build-s03df8r3\n", + " error: RPC failed; curl 92 HTTP/2 stream 0 was not closed cleanly: CANCEL (err 8)\n", + " error: 7516 bytes of body are still expected\n", + " fetch-pack: unexpected disconnect while reading sideband packet\n", + " fatal: early EOF\n", + " fatal: fetch-pack: invalid index-pack output\n", + "^C\n", + "\u001b[31mERROR: Operation cancelled by user\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], "source": [ "!pip install git+https://github.com/eclipse-qrisp/Qrisp.git@catalyst_integration" ] @@ -27,10 +45,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "c770249c-137c-4b45-9263-7f77726571bc", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{ lambda ; a:QuantumCircuit b:i64[]. let\n", + " c:QuantumCircuit d:QubitArray = jasp.create_qubits a b\n", + " e:Qubit = jasp.get_qubit d 0\n", + " f:QuantumCircuit = jasp.h c e\n", + " g:Qubit = jasp.get_qubit d 1\n", + " h:QuantumCircuit = jasp.cx f e g\n", + " i:QuantumCircuit j:i64[] = jasp.measure h d\n", + " k:QuantumCircuit = jasp.reset i d\n", + " l:QuantumCircuit = jasp.delete_qubits k d\n", + " in (l, j) }\n" + ] + } + ], "source": [ "from qrisp import *\n", "from qrisp.jasp import *\n", @@ -62,10 +97,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "b02aeffb-7e93-4ec0-bca6-7e226fadcfd5", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 \u001b[2K\n" + ] + } + ], "source": [ "print(jaspr(5))" ] @@ -80,10 +130,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "27a9feb7-d81a-41fe-9090-fbb23078fd9e", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 \u001b[2K\n" + ] + } + ], "source": [ "@jaspify\n", "def main(i):\n", @@ -108,7 +166,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "0782f1ef-d361-4bc4-887e-699cdd3309f8", "metadata": {}, "outputs": [], @@ -121,10 +179,71 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "b491510a-a302-4bde-b253-e6012ff2a775", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "; ModuleID = 'LLVMDialectModule'\n", + "source_filename = \"LLVMDialectModule\"\n", + "target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128\"\n", + "target triple = \"x86_64-unknown-linux-gnu\"\n", + "\n", + "@\"{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}\" = internal constant [66 x i8] c\"{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}\\00\"\n", + "@LightningSimulator = internal constant [19 x i8] c\"LightningSimulator\\00\"\n", + "@\"/home/positr0nium/miniconda3/envs/qrisp/lib/python3.10/site-packages/catalyst/utils/../lib/librtd_lightning.so\" = internal constant [111 x i8] c\"/home/positr0nium/miniconda3/envs/qrisp/lib/python3.10/site-packages/catalyst/utils/../lib/librtd_lightning.so\\00\"\n", + "\n", + "declare void @__catalyst__rt__finalize()\n", + "\n", + "declare void @__catalyst__rt__initialize(ptr)\n", + "\n", + "declare void @__catalyst__qis__PauliX(ptr, ptr)\n", + "\n", + "declare ptr @__catalyst__qis__Measure(ptr, i32)\n", + "\n", + "declare void @__catalyst__qis__CNOT(ptr, ptr, ptr)\n", + "\n", + "declare void @__catalyst__qis__Hadamard(ptr, ptr)\n", + "\n", + "declare ptr @__catalyst__rt__array_get_element_ptr_1d(ptr, i64)\n", + "\n", + "declare ptr @__catalyst__rt__qubit_allocate_array(i64)\n", + "\n", + "declare void @__catalyst__rt__device_init(ptr, ptr, ptr)\n", + "\n", + "declare ptr @_mlir_memref_to_llvm_alloc(i64)\n", + "\n", + "define { ptr, ptr, i64 } @jit_jaspr_function(ptr %0, ptr %1, i64 %2) {\n", + " %4 = load i64, ptr %1, align 4\n", + " call void @__catalyst__rt__device_init(ptr @\"/home/positr0nium/miniconda3/envs/qrisp/lib/python3.10/site-packages/catalyst/utils/../lib/librtd_lightning.so\", ptr @LightningSimulator, ptr @\"{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}\")\n", + " %5 = call ptr @__catalyst__rt__qubit_allocate_array(i64 20)\n", + " %6 = call ptr @__catalyst__rt__array_get_element_ptr_1d(ptr %5, i64 0)\n", + " %7 = load ptr, ptr %6, align 8\n", + " call void @__catalyst__qis__Hadamard(ptr %7, ptr null)\n", + " %8 = call ptr @__catalyst__rt__array_get_element_ptr_1d(ptr %5, i64 0)\n", + " %9 = load ptr, ptr %8, align 8\n", + " %10 = call ptr @__catalyst__rt__array_get_element_ptr_1d(ptr %5, i64 1)\n", + " %11 = load ptr, ptr %10, align 8\n", + " call void @__catalyst__qis__CNOT(ptr %9, ptr %11, ptr null)\n", + " br label %12\n", + "\n", + "12: ; preds = %17, %3\n", + " %13 = phi i64 [ %31, %17 ], [ 0, %3 ]\n", + " %14 = phi i64 [ %30, %17 ], [ 0, %3 ]\n", + " %15 = phi ptr [ %20, %17 ], [ %5, %3 ]\n", + " %16 = icmp slt i64 %13, %4\n", + " br i1 %16, label %17, label %32\n", + "\n", + "17: ; preds = %12\n", + " %18 = phi i64 [ %13, %12 ]\n", + " %19 = phi i64 [ %14, %12 ]\n", + " %\n" + ] + } + ], "source": [ "qir_string = jaspr.to_qir()\n", "print(qir_string[:2500])" @@ -143,10 +262,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "b0a49866-6d82-4906-941e-36b646102079", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0149471759796143\n" + ] + } + ], "source": [ "import time\n", "\n", @@ -182,10 +309,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "0ed87e8c-61af-4cc6-8369-d05eeca5ac4c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.033618688583374\n" + ] + } + ], "source": [ "@qache\n", "def inner_function(qv):\n", @@ -220,10 +355,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "51997b89-b3fc-428c-ba95-d7e7a63014f5", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "i is dynamic?: True\n", + "j is dynamic?: True\n", + "k is dynamic?: True\n", + "l is dynamic?: False\n", + "G is dynamic?: False\n" + ] + } + ], "source": [ "from jax.core import Tracer\n", "\n", @@ -264,10 +411,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "ce8e8278-b779-4ee7-b042-b29fe196f304", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 \u001b[2K\n" + ] + } + ], "source": [ "@jaspify\n", "def main(k):\n", @@ -301,7 +456,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "2516cbf8-e5e8-4a38-825e-2ee64f73ea55", "metadata": {}, "outputs": [], @@ -330,10 +485,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "e0d41b31-fa4e-4299-b850-ac3303a3aa01", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True \u001b[2K\n", + "False \u001b[2K\n", + "False \u001b[2K\n" + ] + } + ], "source": [ "print(main(1, 6, 1))\n", "print(main(3, 6, 1))\n", @@ -353,10 +518,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "0c721b7f-50a7-4442-88a5-38ad99976c42", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 \u001b[2K\n", + "0 \u001b[2K\u001b[2K\n", + "0 \u001b[2K\n", + "0 \u001b[2K\n", + "3 \u001b[2K\n" + ] + } + ], "source": [ "@jaspify\n", "def main():\n", @@ -398,10 +575,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "09fd5295-68a9-4de0-aed6-bf17d44ea95a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "31.0 \u001b[2K\n" + ] + } + ], "source": [ "from qrisp.jasp import RUS, make_jaspr\n", "from qrisp import QuantumFloat, h, cx, measure\n", @@ -444,10 +629,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "851a0a3f-64ef-453d-8bab-e6e6b7d8ffe8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{30.0: 0.5, 31.0: 0.5} \u001b[2K\n" + ] + } + ], "source": [ "\n", "@RUS\n", @@ -481,10 +674,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "ad3c678d-b524-4784-9cde-04966897c7b0", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{0.0: 1.0} \u001b[2K\n", + "{1.0: 0.5, 5.0: 0.5} \u001b[2K\n", + "{1.0: 0.5, 5.0: 0.5} \u001b[2K\n", + "{1.0: 0.5, 5.0: 0.5} \u001b[2K\n", + "{1.0: 0.5, 5.0: 0.5} \u001b[2K\n" + ] + } + ], "source": [ "from qrisp import QuantumBool, measure, control\n", "\n", @@ -527,7 +732,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "d47dfe1f-6d04-4e1e-9b4a-05faf20da12c", "metadata": {}, "outputs": [], @@ -639,7 +844,9 @@ "metadata": {}, "source": [ "Letting a classical, neural network decide when to stop\n", - "-------------------------------------------------------" + "-------------------------------------------------------\n", + "\n", + "The following example showcases how a simple neural network can decide (in real-time) whether to go on or break the RUS iteration. For that we create a simple binary classifier and train it on dummy data (disclaimer: ML code by ChatGPT). This is code is not really useful in anyway and the classifier is classifying random data, but it shows how such an algorithm can be constructed and evaluated." ] }, { @@ -649,6 +856,7 @@ "metadata": {}, "outputs": [], "source": [ + "import jax\n", "import jax.numpy as jnp\n", "from jax import grad, jit\n", "import optax\n", @@ -697,6 +905,60 @@ "accuracy = jnp.mean((predictions > 0.5) == y)\n", "print(f\"Final accuracy: {accuracy}\")\n" ] + }, + { + "cell_type": "markdown", + "id": "c40e15a9-7e63-44c2-b554-842eb846be96", + "metadata": {}, + "source": [ + "We can now use the ``model`` function to evaluate the classifier. Since this function is Jax-based it integrates seamlessly into Jasp." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c36ba0fb-b9e6-4b01-bd66-7c06dda7b1af", + "metadata": {}, + "outputs": [], + "source": [ + "from qrisp.jasp import *\n", + "from qrisp import *\n", + " \n", + "@RUS\n", + "def rus_trial_function(params):\n", + "\n", + " # Sample data from two QuantumFloats.\n", + " # This is a placeholder for an arbitrary quantum algorithm.\n", + " qf_0 = QuantumFloat(5)\n", + " h(qf_0)\n", + "\n", + " qf_1 = QuantumFloat(5)\n", + " h(qf_1)\n", + "\n", + " meas_res_0 = measure(qf_0)\n", + " meas_res_1 = measure(qf_1)\n", + "\n", + " # Turn the data into a Jax array\n", + " X = jnp.array([meas_res_0,meas_res_1])/2**qf_0.size\n", + "\n", + " # Evaluate the model\n", + " model_res = model(params, X)\n", + "\n", + " # Determine the cancelation\n", + " cancelation_bool = (model_res > 0.5)[0]\n", + " \n", + " return cancelation_bool, qf_0\n", + "\n", + "@jaspify\n", + "def main(params):\n", + "\n", + " qf = rus_trial_function(params)\n", + " h(qf[0])\n", + "\n", + " return measure(qf)\n", + "\n", + "print(main(params))" + ] } ], "metadata": { diff --git a/tests/jax_tests/test_rus - Kopie.py b/tests/jax_tests/test_rus - Kopie.py new file mode 100644 index 00000000..9745f7e8 --- /dev/null +++ b/tests/jax_tests/test_rus - Kopie.py @@ -0,0 +1,42 @@ +""" +\******************************************************************************** +* Copyright (c) 2023 the Qrisp authors +* +* This program and the accompanying materials are made available under the +* terms of the Eclipse Public License 2.0 which is available at +* http://www.eclipse.org/legal/epl-2.0. +* +* This Source Code may also be made available under the following Secondary +* Licenses when the conditions for such availability set forth in the Eclipse +* Public License, v. 2.0 are satisfied: GNU General Public License, version 2 +* with the GNU Classpath Exception which is +* available at https://www.gnu.org/software/classpath/license.html. +* +* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0 +********************************************************************************/ +""" + +from qrisp import * +from qrisp.jasp import * +from jax import make_jaxpr + +def test_injection_operator(): + + @jaspify + def main(i): + + a = QuantumFloat(i) + b = QuantumFloat(i) + + h(a) + h(b) + + s = a*b + + with invert(): + (s << (lambda a, b : a*b))(a,b) + + return measure(s) + + for i in range(2, 6): + assert main(i) == 0