diff --git a/plugins/cluster_scatter_visualization/cluster_scatter_visualization.py b/plugins/cluster_scatter_visualization/cluster_scatter_visualization.py new file mode 100644 index 000000000..75ce158fb --- /dev/null +++ b/plugins/cluster_scatter_visualization/cluster_scatter_visualization.py @@ -0,0 +1,350 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +from http import HTTPStatus +from io import BytesIO +import pathlib +from tempfile import SpooledTemporaryFile +from typing import Mapping, Optional +from celery import chain +import celery +from celery.utils.log import get_task_logger +from flask import abort, redirect, send_file +from flask.app import Flask +from flask.globals import request +from flask.helpers import url_for +from flask.templating import render_template +from flask.views import MethodView +from flask.wrappers import Response +from marshmallow import EXCLUDE +from requests.exceptions import HTTPError + +from qhana_plugin_runner.api.plugin_schemas import ( + DataMetadata, + EntryPoint, + InputDataMetadata, + PluginMetadata, + PluginMetadataSchema, + PluginType, +) +from qhana_plugin_runner.api.util import ( + FileUrl, + FrontendFormBaseSchema, + SecurityBlueprint, +) +from qhana_plugin_runner.celery import CELERY +from qhana_plugin_runner.db.models.tasks import ProcessingTask +from qhana_plugin_runner.requests import open_url +from qhana_plugin_runner.storage import STORE +from qhana_plugin_runner.tasks import save_task_error, save_task_result +from qhana_plugin_runner.util.plugins import QHAnaPluginBase, plugin_identifier +from qhana_plugin_runner.db.models.virtual_plugins import DataBlob, PluginState + +_plugin_name = "cluster-scatter-visualization" +__version__ = "v0.0.1" +_identifier = plugin_identifier(_plugin_name, __version__) + + +VIS_BLP = SecurityBlueprint( + _identifier, # blueprint name + __name__, # module import name! + description="A visualization plugin for cluster scatter data.", + template_folder="cluster_scatter_visualization_templates", +) + + +class CSInputParametersSchema(FrontendFormBaseSchema): + entity_url = FileUrl( + required=True, + allow_none=False, + data_input_type="entity/vector", + data_content_types=["application/json"], + metadata={ + "label": "Entity Point URL", + "description": "URL to a json file containing the points.", + "input_type": "text", + }, + ) + clusters_url = FileUrl( + required=True, + allow_none=True, + data_input_type="entity/vector", + data_content_types=["application/json"], + metadata={ + "label": "Cluster URL", + "description": "URL to a json file containing the cluster labels.", + "input_type": "text", + }, + ) + + +@VIS_BLP.route("/") +class PluginsView(MethodView): + """Plugins collection resource.""" + + @VIS_BLP.response(HTTPStatus.OK, PluginMetadataSchema()) + @VIS_BLP.require_jwt("jwt", optional=True) + def get(self): + """Endpoint returning the plugin metadata.""" + plugin = CSVisualization.instance + if plugin is None: + abort(HTTPStatus.INTERNAL_SERVER_ERROR) + return PluginMetadata( + title=plugin.name, + description=plugin.description, + name=plugin.name, + version=plugin.version, + type=PluginType.visualization, + entry_point=EntryPoint( + href=url_for(f"{VIS_BLP.name}.ProcessView"), + ui_href=url_for(f"{VIS_BLP.name}.MicroFrontend"), + plugin_dependencies=[], + data_input=[ + InputDataMetadata( + data_type="entity/vector", + content_type=["application/json"], + required=True, + parameter="entityUrl", + ), + InputDataMetadata( + data_type="entity/label", + content_type=["application/json"], + required=True, + parameter="clustersUrl", + ) + ], + data_output=[ + DataMetadata( + data_type="plot", + content_type=["image/svg+xml"], + required=True, + ) + ], + ), + tags=["visualization", "cluster", "scatter"], + ) + + +@VIS_BLP.route("/ui/") +class MicroFrontend(MethodView): + """Micro frontend for the cluster scatter visualization plugin.""" + + @VIS_BLP.html_response( + HTTPStatus.OK, description="Micro frontend of the cluster scatter visualization plugin." + ) + @VIS_BLP.arguments( + CSInputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="query", + required=False, + ) + @VIS_BLP.require_jwt("jwt", optional=True) + def get(self, errors): + """Return the micro frontend.""" + return self.render(request.args, errors, False) + + @VIS_BLP.html_response( + HTTPStatus.OK, description="Micro frontend of the cluster scatter visualization plugin." + ) + @VIS_BLP.arguments( + CSInputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="form", + required=False, + ) + @VIS_BLP.require_jwt("jwt", optional=True) + def post(self, errors): + """Return the micro frontend with prerendered inputs.""" + return self.render(request.form, errors, not errors) + + def render(self, data: Mapping, errors: dict, valid: bool): + plugin = CSVisualization.instance + if plugin is None: + abort(HTTPStatus.INTERNAL_SERVER_ERROR) + return Response( + render_template( + "cluster_scatter_visualization.html", + name=plugin.name, + version=plugin.version, + schema=CSInputParametersSchema(), + valid=valid, + values=data, + errors=errors, + example_values=url_for(f"{VIS_BLP.name}.MicroFrontend"), + get_circuit_image_url=url_for(f"{VIS_BLP.name}.get_circuit_image"), + process=url_for(f"{VIS_BLP.name}.ProcessView"), + ) + ) + + +class ImageNotFinishedError(Exception): + pass + + +@VIS_BLP.route("/circuits/") +@VIS_BLP.response(HTTPStatus.OK, description="Circuit image.") +@VIS_BLP.arguments( + CSInputParametersSchema(unknown=EXCLUDE), + location="query", + required=True, +) +@VIS_BLP.require_jwt("jwt", optional=True) +def get_circuit_image(data: Mapping): + url = data.get("data", None) + if not url: + abort(HTTPStatus.BAD_REQUEST) + url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest() + image = DataBlob.get_value(CSVisualization.instance.identifier, url_hash, None) + if image is None: + if not ( + task_id := PluginState.get_value( + CSVisualization.instance.identifier, url_hash, None + ) + ): + task_result = generate_image.s(url, url_hash).apply_async() + PluginState.set_value( + CSVisualization.instance.identifier, + url_hash, + task_result.id, + commit=True, + ) + else: + task_result = CELERY.AsyncResult(task_id) + try: + task_result.get(timeout=5) + image = DataBlob.get_value(CSVisualization.instance.identifier, url_hash) + except celery.exceptions.TimeoutError: + return Response("Image not yet created!", HTTPStatus.ACCEPTED) + if not image: + abort(HTTPStatus.BAD_REQUEST, "Invalid circuit URL!") + return send_file(BytesIO(image), mimetype="image/svg+xml") + + +@VIS_BLP.route("/process/") +class ProcessView(MethodView): + """Start a long running processing task.""" + + @VIS_BLP.arguments(CSInputParametersSchema(unknown=EXCLUDE), location="form") + @VIS_BLP.response(HTTPStatus.SEE_OTHER) + @VIS_BLP.require_jwt("jwt", optional=True) + def post(self, arguments): + circuit_url = arguments.get("data", None) + if circuit_url is None: + abort(HTTPStatus.BAD_REQUEST) + url_hash = hashlib.sha256(circuit_url.encode("utf-8")).hexdigest() + db_task = ProcessingTask(task_name=process.name) + db_task.save(commit=True) + + # all tasks need to know about db id to load the db entry + task: chain = process.s( + db_id=db_task.id, url=circuit_url, hash=url_hash + ) | save_task_result.s(db_id=db_task.id) + # save errors to db + task.link_error(save_task_error.s(db_id=db_task.id)) + task.apply_async(db_id=db_task.id, url=circuit_url, hash=url_hash) + + db_task.save(commit=True) + + return redirect( + url_for("tasks-api.TaskView", task_id=str(db_task.id)), HTTPStatus.SEE_OTHER + ) + + +class CSVisualization(QHAnaPluginBase): + name = _plugin_name + version = __version__ + description = "Visualizes cluster data in a scatter plot." + tags = ["visualization"] + + def __init__(self, app: Optional[Flask]) -> None: + super().__init__(app) + + # create folder for circuit images + pathlib.Path(__file__).parent.absolute().joinpath("img").mkdir( + parents=True, exist_ok=True + ) + + def get_api_blueprint(self): + return VIS_BLP + + def get_requirements(self) -> str: + return "pylatexenc~=2.10\nqiskit~=0.43" + + +TASK_LOGGER = get_task_logger(__name__) + + +@CELERY.task(name=f"{CSVisualization.instance.identifier}.generate_image", bind=True) +def generate_image(self, url: str, hash: str) -> str: + from qiskit import QuantumCircuit + import matplotlib + + matplotlib.use("SVG") + + TASK_LOGGER.info(f"Generating image for circuit {url}...") + try: + with open_url(url) as qasm_response: + circuit_qasm = qasm_response.text + except HTTPError: + TASK_LOGGER.error(f"Invalid circuit URL: {url}") + DataBlob.set_value( + CSVisualization.instance.identifier, + hash, + "", + ) + PluginState.delete_value(CSVisualization.instance.identifier, hash, commit=True) + return "Invalid circuit URL!" + + circuit = QuantumCircuit.from_qasm_str(circuit_qasm) + fig = circuit.draw(output="mpl", interactive=False) + figfile = BytesIO() + fig.savefig(figfile, format="svg") + figfile.seek(0) + DataBlob.set_value(CSVisualization.instance.identifier, hash, figfile.getvalue()) + TASK_LOGGER.info(f"Stored image of circuit {circuit.name}.") + PluginState.delete_value(CSVisualization.instance.identifier, hash, commit=True) + + return "Created image of circuit!" + + +@CELERY.task( + name=f"{CSVisualization.instance.identifier}.process", + bind=True, + autoretry_for=(ImageNotFinishedError,), + retry_backoff=True, + max_retries=None, +) +def process(self, db_id: str, url: str, hash: str) -> str: + if not (image := DataBlob.get_value(CSVisualization.instance.identifier, hash)): + if not ( + task_id := PluginState.get_value(CSVisualization.instance.identifier, hash) + ): + task_result = generate_image.s(url, hash).apply_async() + PluginState.set_value( + CSVisualization.instance.identifier, + hash, + task_result.id, + commit=True, + ) + raise ImageNotFinishedError() + with SpooledTemporaryFile() as output: + output.write(image) + output.seek(0) + STORE.persist_task_result( + db_id, output, f"circuit_{hash}.svg", "image/svg", "image/svg+xml" + ) + return "Created image of circuit!" diff --git a/plugins/cluster_scatter_visualization/cluster_scatter_visualization_templates/cluster_scatter_visualization.html b/plugins/cluster_scatter_visualization/cluster_scatter_visualization_templates/cluster_scatter_visualization.html new file mode 100644 index 000000000..5a754fe20 --- /dev/null +++ b/plugins/cluster_scatter_visualization/cluster_scatter_visualization_templates/cluster_scatter_visualization.html @@ -0,0 +1,126 @@ +{% extends "simple_template.html" %} + +{% block head %} +{{ super() }} + +{% endblock head %} + +{% block content %} +
+
+ Visualization Options + {% call forms.render_form(target="microfrontend") %} + {{ forms.render_fields(schema, values=values, errors=errors) }} +
+ {{ forms.submit("validate", target="microfrontend") }} + {{ forms.submit("submit", action=process) }} +
+ {% endcall %} +
+
+ + +
+
+
+{% endblock content %} + +{% block script %} +{{ super() }} + +{% endblock script %} \ No newline at end of file diff --git a/plugins/cluster_svm_visualization/tasks.py b/plugins/cluster_svm_visualization/tasks.py index 5190cf6c2..ebffbc1b4 100644 --- a/plugins/cluster_svm_visualization/tasks.py +++ b/plugins/cluster_svm_visualization/tasks.py @@ -66,6 +66,7 @@ def visualization_task(self, db_id: int) -> str: pt_x_list = [0 for _ in range(0, len(entity_points))] pt_y_list = [0 for _ in range(0, len(entity_points))] + pt_z_list = [0 for _ in range(0, len(entity_points))] label_list = [0 for _ in range(0, len(entity_points))] id_list = [x for x in range(0, len(entity_points))] size_list = [10 for _ in range(0, len(entity_points))] @@ -75,6 +76,8 @@ def visualization_task(self, db_id: int) -> str: idx = int(pt["ID"]) pt_x_list[idx] = pt["dim0"] pt_y_list[idx] = pt["dim1"] + if do_3d: + pt_z_list[idx] = pt["dim2"] for cl in clusters: label_list[int(cl["ID"])] = cl["label"] @@ -85,20 +88,33 @@ def visualization_task(self, db_id: int) -> str: "ID": [f"Point {x}" for x in id_list], "x": pt_x_list, "y": pt_y_list, + "z": pt_z_list, "Cluster ID": [str(x) for x in label_list], "size": size_list, } ) - fig = px.scatter( - df, - x="x", - y="y", - size="size", - hover_name="ID", - color="Cluster ID", - hover_data={"size": False}, - ) + if not do_3d: + fig = px.scatter( + df, + x="x", + y="y", + size="size", + hover_name="ID", + color="Cluster ID", + hover_data={"size": False}, + ) + else: + fig = px.scatter_3d( + df, + x="x", + y="y", + z="z", + size="size", + hover_name="ID", + color="Cluster ID", + hover_data={"size": False}, + ) if do_svm: cluster_list = [[] for _ in range(0, max_cluster + 1)] diff --git a/plugins/zxcalculus/__init__.py b/plugins/zxcalculus/__init__.py new file mode 100644 index 000000000..f083564b2 --- /dev/null +++ b/plugins/zxcalculus/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2022 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from flask import Flask + +from qhana_plugin_runner.api.util import SecurityBlueprint +from qhana_plugin_runner.util.plugins import plugin_identifier, QHAnaPluginBase + +_plugin_name = "zxcalculus" +__version__ = "v0.0.4" +_identifier = plugin_identifier(_plugin_name, __version__) + + +VIS_BLP = SecurityBlueprint( + _identifier, # blueprint name + __name__, # module import name! + description="ZXCalculus API.", +) + + +class ZXCalculus(QHAnaPluginBase): + name = _plugin_name + version = __version__ + description = "Generates a random circuit, visualizes and simplifies it." + tags = ["zxcalculus", "circuit"] + + def __init__(self, app: Optional[Flask]) -> None: + super().__init__(app) + + def get_api_blueprint(self): + return VIS_BLP + + def get_requirements(self) -> str: + return "plotly~=5.18.0\npyzx~=0.8.0\nmpld3~=0.5.10" + + +try: + # It is important to import the routes **after** COSTUME_LOADER_BLP and CostumeLoader are defined, because they are + # accessed as soon as the routes are imported. + from . import routes +except ImportError: + # When running `poetry run flask install`, importing the routes will fail, because the dependencies are not + # installed yet. + pass diff --git a/plugins/zxcalculus/routes.py b/plugins/zxcalculus/routes.py new file mode 100644 index 000000000..003fff128 --- /dev/null +++ b/plugins/zxcalculus/routes.py @@ -0,0 +1,160 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Mapping + +from celery.canvas import chain +from flask import Response, redirect +from flask.globals import request +from flask.helpers import url_for +from flask.templating import render_template +from flask.views import MethodView +from marshmallow import EXCLUDE + +from qhana_plugin_runner.api.plugin_schemas import ( + EntryPoint, + PluginMetadata, + PluginMetadataSchema, + PluginType, + InputDataMetadata, + DataMetadata, +) +from qhana_plugin_runner.db.models.tasks import ProcessingTask +from qhana_plugin_runner.tasks import ( + save_task_error, + save_task_result, +) + +from . import VIS_BLP, ZXCalculus +from .schemas import InputParametersSchema, TaskResponseSchema +from .tasks import visualization_task + + +@VIS_BLP.route("/") +class PluginsView(MethodView): + """Plugins collection resource.""" + + @VIS_BLP.response(HTTPStatus.OK, PluginMetadataSchema) + @VIS_BLP.require_jwt("jwt", optional=True) + def get(self): + """Endpoint returning the plugin metadata.""" + + return PluginMetadata( + title="ZXCalculus", + description=ZXCalculus.instance.description, + name=ZXCalculus.instance.name, + version=ZXCalculus.instance.version, + type=PluginType.visualization, + entry_point=EntryPoint( + href=url_for(f"{VIS_BLP.name}.ProcessView"), + ui_href=url_for(f"{VIS_BLP.name}.MicroFrontend"), + data_input=[], + data_output=[ + DataMetadata( + data_type="circuit", content_type=["text/html"], required=True + ) + ], + ), + tags=ZXCalculus.instance.tags, + ) + + +@VIS_BLP.route("/ui/") +class MicroFrontend(MethodView): + """Micro frontend for the ZXCalculus plugin.""" + + @VIS_BLP.html_response( + HTTPStatus.OK, description="Micro frontend for the ZXCalculus plugin." + ) + @VIS_BLP.arguments( + InputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="query", + required=False, + ) + @VIS_BLP.require_jwt("jwt", optional=True) + def get(self, errors): + """Return the micro frontend.""" + return self.render(request.args, errors) + + @VIS_BLP.html_response( + HTTPStatus.OK, description="Micro frontend for ZXCalculus plugin." + ) + @VIS_BLP.arguments( + InputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="form", + required=False, + ) + @VIS_BLP.require_jwt("jwt", optional=True) + def post(self, errors): + """Return the micro frontend with prerendered inputs.""" + return self.render(request.form, errors) + + def render(self, data: Mapping, errors: dict): + schema = InputParametersSchema() + + data_dict = dict(data) + + # define default values + default_values = {} + + # overwrite default values with other values if possible + default_values.update(data_dict) + data_dict = default_values + + return Response( + render_template( + "simple_template.html", + name=ZXCalculus.instance.name, + version=ZXCalculus.instance.version, + schema=schema, + values=data_dict, + errors=errors, + process=url_for(f"{VIS_BLP.name}.ProcessView"), + ) + ) + + +@VIS_BLP.route("/process/") +class ProcessView(MethodView): + """Start a long running processing task.""" + + @VIS_BLP.arguments(InputParametersSchema(unknown=EXCLUDE), location="form") + @VIS_BLP.response(HTTPStatus.OK, TaskResponseSchema()) + @VIS_BLP.require_jwt("jwt", optional=True) + def post(self, arguments): + """Start the visualization task.""" + db_task = ProcessingTask( + task_name=visualization_task.name, + parameters=InputParametersSchema().dumps(arguments), + ) + db_task.save(commit=True) + + # all tasks need to know about db id to load the db entry + task: chain = visualization_task.s(db_id=db_task.id) | save_task_result.s( + db_id=db_task.id + ) + # save errors to db + task.link_error(save_task_error.s(db_id=db_task.id)) + task.apply_async() + + db_task.save(commit=True) + + return redirect( + url_for("tasks-api.TaskView", task_id=str(db_task.id)), HTTPStatus.SEE_OTHER + ) diff --git a/plugins/zxcalculus/schemas.py b/plugins/zxcalculus/schemas.py new file mode 100644 index 000000000..3870927e8 --- /dev/null +++ b/plugins/zxcalculus/schemas.py @@ -0,0 +1,68 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from marshmallow import post_load +import marshmallow as ma +from qhana_plugin_runner.api.util import FrontendFormBaseSchema, MaBaseSchema, FileUrl +from dataclasses import dataclass + + +class TaskResponseSchema(MaBaseSchema): + name = ma.fields.String(required=True, allow_none=False, dump_only=True) + task_id = ma.fields.String(required=True, allow_none=False, dump_only=True) + task_result_url = ma.fields.Url(required=True, allow_none=False, dump_only=True) + + +@dataclass(repr=False) +class InputParameters: + qubits: int + depth: int + simplify: bool = False + + def __str__(self): + return str(self.__dict__) + + +class InputParametersSchema(FrontendFormBaseSchema): + qubits = ma.fields.Integer( + required=True, + allow_none=False, + metadata={ + "label": "No. Qubits", + "description": "Determines the number of qubits to generate.", + "input_type": "number", + }, + ) + depth = ma.fields.Integer( + required=True, + allow_none=False, + metadata={ + "label": "Depth", + "description": "Determines the depth of the circuits.", + "input_type": "number", + }, + ) + simplify = ma.fields.Boolean( + required=False, + allow_none=False, + metadata={ + "label": "Simplify", + "description": "Simplify the generated circuit.", + "input_type": "checkbox", + } + ) + + @post_load + def make_input_params(self, data, **kwargs) -> InputParameters: + return InputParameters(**data) diff --git a/plugins/zxcalculus/tasks.py b/plugins/zxcalculus/tasks.py new file mode 100644 index 000000000..4ade83ec6 --- /dev/null +++ b/plugins/zxcalculus/tasks.py @@ -0,0 +1,79 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tempfile import SpooledTemporaryFile + +from typing import Optional + +from celery.utils.log import get_task_logger + +import muid +from . import ZXCalculus +from .schemas import InputParameters, InputParametersSchema +from qhana_plugin_runner.celery import CELERY +from qhana_plugin_runner.db.models.tasks import ProcessingTask +from qhana_plugin_runner.requests import open_url, retrieve_filename + +from qhana_plugin_runner.storage import STORE + +import pyzx as zx +import matplotlib.pyplot as _, mpld3 + +TASK_LOGGER = get_task_logger(__name__) + + +def get_readable_hash(s: str) -> str: + return muid.pretty(muid.bhash(s.encode("utf-8")), k1=6, k2=5).replace(" ", "-") + + +@CELERY.task(name=f"{ZXCalculus.instance.identifier}.visualization_task", bind=True) +def visualization_task(self, db_id: int) -> str: + + TASK_LOGGER.info(f"Starting new demo task with db id '{db_id}'") + task_data: Optional[ProcessingTask] = ProcessingTask.get_by_id(id_=db_id) + + if task_data is None: + msg = f"Could not load task data with id {db_id} to read parameters!" + TASK_LOGGER.error(msg) + raise KeyError(msg) + + input_params: InputParameters = InputParametersSchema().loads(task_data.parameters) + + qubits = input_params.qubits + depth = input_params.depth + simplify = input_params.simplify + + circuit = zx.generate.cliffordT(qubits, depth) + zx.settings.drawing_backend = 'd3' + fig = zx.draw(circuit) + html = mpld3.fig_to_html(fig) + if simplify: + zx.simplify.full_reduce(circuit) + circuit.normalize() + fig = zx.draw(circuit) + html = mpld3.fig_to_html(fig) + + with SpooledTemporaryFile(mode="wt") as output: + output.write(html) + + STORE.persist_task_result( + db_id, + output, + f"ZXCalculus_Circuit_{qubits}Qubits_{depth}Depth.html", + "circuit", + "text/html", + ) + + return "Result stored in file" diff --git a/stable_plugins/data_synthesis/data_creator/__init__.py b/stable_plugins/data_synthesis/data_creator/__init__.py index a75ad9f37..10f3181b3 100644 --- a/stable_plugins/data_synthesis/data_creator/__init__.py +++ b/stable_plugins/data_synthesis/data_creator/__init__.py @@ -22,7 +22,7 @@ _plugin_name = "data-creator" -__version__ = "v0.2.1" +__version__ = "v0.2.2" _identifier = plugin_identifier(_plugin_name, __version__) diff --git a/stable_plugins/data_synthesis/data_creator/backend/datasets.py b/stable_plugins/data_synthesis/data_creator/backend/datasets.py index bdd820844..4f7aa3784 100644 --- a/stable_plugins/data_synthesis/data_creator/backend/datasets.py +++ b/stable_plugins/data_synthesis/data_creator/backend/datasets.py @@ -14,6 +14,7 @@ from enum import Enum from typing import List, Tuple +from sklearn.datasets import make_blobs import numpy as np @@ -21,6 +22,9 @@ class DataTypeEnum(Enum): two_spirals = "Two Spirals" checkerboard = "Checkerboard" + blobs = "Blobs" + checkerboard_3d = "3D Checkerboard" + blobs_3d = "3D Blobs" def get_data( self, num_train_points: int, num_test_points: int, **kwargs @@ -30,6 +34,12 @@ def get_data( data, labels = twospirals(num_train_points + num_test_points, **kwargs) elif self == DataTypeEnum.checkerboard: data, labels = checkerboard(num_train_points + num_test_points, **kwargs) + elif self == DataTypeEnum.blobs: + data, labels = blobs(num_train_points + num_test_points, **kwargs) + elif self == DataTypeEnum.checkerboard_3d: + data, labels = checkerboard_3d(num_train_points + num_test_points, **kwargs) + elif self == DataTypeEnum.blobs_3d: + data, labels = blobs_3d(num_train_points + num_test_points, **kwargs) else: raise NotImplementedError indices = np.arange(len(data)) @@ -68,3 +78,39 @@ def checkerboard(n_points: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]: y += 0.2 if y > 0 else -0.2 rand_points[i] = x, y return rand_points, labels + +def blobs(n_points: int, centers: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]: + """Returns the Blobs dataset.""" + x, y = make_blobs( + n_samples=n_points, + centers=centers, + n_features=2, + ) + + return x, y + +def checkerboard_3d(n_points: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]: + """Returns the 3D checkerboard dataset.""" + rand_points = (np.random.rand(n_points, 3) * 2) - 1 + labels = np.zeros(n_points).astype(int) + + for i, (x, y, z) in enumerate(rand_points): + # label by quadrant + labels[i] = int(not ((x < 0 and y < 0) or (x >= 0 and y >= 0)) != (z < 0)) + # push away from both axes + x += 0.2 if x > 0 else -0.2 + y += 0.2 if y > 0 else -0.2 + z += 0.2 if z > 0 else -0.2 + rand_points[i] = x, y, z + return rand_points, labels + + +def blobs_3d(n_points: int, centers:int, **kwargs) -> Tuple[np.ndarray, np.ndarray]: + """Returns the 3D Blobs dataset.""" + x, y = make_blobs( + n_samples=n_points, + centers=centers, + n_features=3, + ) + + return x, y \ No newline at end of file diff --git a/stable_plugins/data_synthesis/data_creator/routes.py b/stable_plugins/data_synthesis/data_creator/routes.py index 5c0cca62d..9a2fa95f4 100644 --- a/stable_plugins/data_synthesis/data_creator/routes.py +++ b/stable_plugins/data_synthesis/data_creator/routes.py @@ -139,6 +139,7 @@ def render(self, data: Mapping, errors: dict): default_values = { fields["noise"].data_key: 0.7, fields["turns"].data_key: 1.52, + fields["centers"].data_key: 4, } # overwrite default values with other values if possible diff --git a/stable_plugins/data_synthesis/data_creator/schemas.py b/stable_plugins/data_synthesis/data_creator/schemas.py index 95c2e6b56..d117086e4 100644 --- a/stable_plugins/data_synthesis/data_creator/schemas.py +++ b/stable_plugins/data_synthesis/data_creator/schemas.py @@ -39,6 +39,7 @@ class InputParameters: num_test_points: int turns: float = None noise: float = None + centers: int = None def __str__(self): return str(self.__dict__.copy()) @@ -93,6 +94,15 @@ class InputParametersSchema(FrontendFormBaseSchema): "input_type": "text", }, ) + centers = ma.fields.Integer( + required=False, + allow_none=False, + metadata={ + "label": "No. Centers", + "description": "Determines the number of Blobs", + "input_type": "number", + }, + ) @post_load def make_input_params(self, data, **kwargs) -> InputParameters: diff --git a/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html b/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html index d6f74fad7..f1b9767ed 100644 --- a/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html +++ b/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html @@ -4,16 +4,21 @@