diff --git a/plugins/cluster_scatter_visualization/cluster_scatter_visualization.py b/plugins/cluster_scatter_visualization/cluster_scatter_visualization.py
new file mode 100644
index 000000000..75ce158fb
--- /dev/null
+++ b/plugins/cluster_scatter_visualization/cluster_scatter_visualization.py
@@ -0,0 +1,350 @@
+# Copyright 2023 QHAna plugin runner contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+from http import HTTPStatus
+from io import BytesIO
+import pathlib
+from tempfile import SpooledTemporaryFile
+from typing import Mapping, Optional
+from celery import chain
+import celery
+from celery.utils.log import get_task_logger
+from flask import abort, redirect, send_file
+from flask.app import Flask
+from flask.globals import request
+from flask.helpers import url_for
+from flask.templating import render_template
+from flask.views import MethodView
+from flask.wrappers import Response
+from marshmallow import EXCLUDE
+from requests.exceptions import HTTPError
+
+from qhana_plugin_runner.api.plugin_schemas import (
+ DataMetadata,
+ EntryPoint,
+ InputDataMetadata,
+ PluginMetadata,
+ PluginMetadataSchema,
+ PluginType,
+)
+from qhana_plugin_runner.api.util import (
+ FileUrl,
+ FrontendFormBaseSchema,
+ SecurityBlueprint,
+)
+from qhana_plugin_runner.celery import CELERY
+from qhana_plugin_runner.db.models.tasks import ProcessingTask
+from qhana_plugin_runner.requests import open_url
+from qhana_plugin_runner.storage import STORE
+from qhana_plugin_runner.tasks import save_task_error, save_task_result
+from qhana_plugin_runner.util.plugins import QHAnaPluginBase, plugin_identifier
+from qhana_plugin_runner.db.models.virtual_plugins import DataBlob, PluginState
+
+_plugin_name = "cluster-scatter-visualization"
+__version__ = "v0.0.1"
+_identifier = plugin_identifier(_plugin_name, __version__)
+
+
+VIS_BLP = SecurityBlueprint(
+ _identifier, # blueprint name
+ __name__, # module import name!
+ description="A visualization plugin for cluster scatter data.",
+ template_folder="cluster_scatter_visualization_templates",
+)
+
+
+class CSInputParametersSchema(FrontendFormBaseSchema):
+ entity_url = FileUrl(
+ required=True,
+ allow_none=False,
+ data_input_type="entity/vector",
+ data_content_types=["application/json"],
+ metadata={
+ "label": "Entity Point URL",
+ "description": "URL to a json file containing the points.",
+ "input_type": "text",
+ },
+ )
+ clusters_url = FileUrl(
+ required=True,
+ allow_none=True,
+ data_input_type="entity/vector",
+ data_content_types=["application/json"],
+ metadata={
+ "label": "Cluster URL",
+ "description": "URL to a json file containing the cluster labels.",
+ "input_type": "text",
+ },
+ )
+
+
+@VIS_BLP.route("/")
+class PluginsView(MethodView):
+ """Plugins collection resource."""
+
+ @VIS_BLP.response(HTTPStatus.OK, PluginMetadataSchema())
+ @VIS_BLP.require_jwt("jwt", optional=True)
+ def get(self):
+ """Endpoint returning the plugin metadata."""
+ plugin = CSVisualization.instance
+ if plugin is None:
+ abort(HTTPStatus.INTERNAL_SERVER_ERROR)
+ return PluginMetadata(
+ title=plugin.name,
+ description=plugin.description,
+ name=plugin.name,
+ version=plugin.version,
+ type=PluginType.visualization,
+ entry_point=EntryPoint(
+ href=url_for(f"{VIS_BLP.name}.ProcessView"),
+ ui_href=url_for(f"{VIS_BLP.name}.MicroFrontend"),
+ plugin_dependencies=[],
+ data_input=[
+ InputDataMetadata(
+ data_type="entity/vector",
+ content_type=["application/json"],
+ required=True,
+ parameter="entityUrl",
+ ),
+ InputDataMetadata(
+ data_type="entity/label",
+ content_type=["application/json"],
+ required=True,
+ parameter="clustersUrl",
+ )
+ ],
+ data_output=[
+ DataMetadata(
+ data_type="plot",
+ content_type=["image/svg+xml"],
+ required=True,
+ )
+ ],
+ ),
+ tags=["visualization", "cluster", "scatter"],
+ )
+
+
+@VIS_BLP.route("/ui/")
+class MicroFrontend(MethodView):
+ """Micro frontend for the cluster scatter visualization plugin."""
+
+ @VIS_BLP.html_response(
+ HTTPStatus.OK, description="Micro frontend of the cluster scatter visualization plugin."
+ )
+ @VIS_BLP.arguments(
+ CSInputParametersSchema(
+ partial=True, unknown=EXCLUDE, validate_errors_as_result=True
+ ),
+ location="query",
+ required=False,
+ )
+ @VIS_BLP.require_jwt("jwt", optional=True)
+ def get(self, errors):
+ """Return the micro frontend."""
+ return self.render(request.args, errors, False)
+
+ @VIS_BLP.html_response(
+ HTTPStatus.OK, description="Micro frontend of the cluster scatter visualization plugin."
+ )
+ @VIS_BLP.arguments(
+ CSInputParametersSchema(
+ partial=True, unknown=EXCLUDE, validate_errors_as_result=True
+ ),
+ location="form",
+ required=False,
+ )
+ @VIS_BLP.require_jwt("jwt", optional=True)
+ def post(self, errors):
+ """Return the micro frontend with prerendered inputs."""
+ return self.render(request.form, errors, not errors)
+
+ def render(self, data: Mapping, errors: dict, valid: bool):
+ plugin = CSVisualization.instance
+ if plugin is None:
+ abort(HTTPStatus.INTERNAL_SERVER_ERROR)
+ return Response(
+ render_template(
+ "cluster_scatter_visualization.html",
+ name=plugin.name,
+ version=plugin.version,
+ schema=CSInputParametersSchema(),
+ valid=valid,
+ values=data,
+ errors=errors,
+ example_values=url_for(f"{VIS_BLP.name}.MicroFrontend"),
+ get_circuit_image_url=url_for(f"{VIS_BLP.name}.get_circuit_image"),
+ process=url_for(f"{VIS_BLP.name}.ProcessView"),
+ )
+ )
+
+
+class ImageNotFinishedError(Exception):
+ pass
+
+
+@VIS_BLP.route("/circuits/")
+@VIS_BLP.response(HTTPStatus.OK, description="Circuit image.")
+@VIS_BLP.arguments(
+ CSInputParametersSchema(unknown=EXCLUDE),
+ location="query",
+ required=True,
+)
+@VIS_BLP.require_jwt("jwt", optional=True)
+def get_circuit_image(data: Mapping):
+ url = data.get("data", None)
+ if not url:
+ abort(HTTPStatus.BAD_REQUEST)
+ url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()
+ image = DataBlob.get_value(CSVisualization.instance.identifier, url_hash, None)
+ if image is None:
+ if not (
+ task_id := PluginState.get_value(
+ CSVisualization.instance.identifier, url_hash, None
+ )
+ ):
+ task_result = generate_image.s(url, url_hash).apply_async()
+ PluginState.set_value(
+ CSVisualization.instance.identifier,
+ url_hash,
+ task_result.id,
+ commit=True,
+ )
+ else:
+ task_result = CELERY.AsyncResult(task_id)
+ try:
+ task_result.get(timeout=5)
+ image = DataBlob.get_value(CSVisualization.instance.identifier, url_hash)
+ except celery.exceptions.TimeoutError:
+ return Response("Image not yet created!", HTTPStatus.ACCEPTED)
+ if not image:
+ abort(HTTPStatus.BAD_REQUEST, "Invalid circuit URL!")
+ return send_file(BytesIO(image), mimetype="image/svg+xml")
+
+
+@VIS_BLP.route("/process/")
+class ProcessView(MethodView):
+ """Start a long running processing task."""
+
+ @VIS_BLP.arguments(CSInputParametersSchema(unknown=EXCLUDE), location="form")
+ @VIS_BLP.response(HTTPStatus.SEE_OTHER)
+ @VIS_BLP.require_jwt("jwt", optional=True)
+ def post(self, arguments):
+ circuit_url = arguments.get("data", None)
+ if circuit_url is None:
+ abort(HTTPStatus.BAD_REQUEST)
+ url_hash = hashlib.sha256(circuit_url.encode("utf-8")).hexdigest()
+ db_task = ProcessingTask(task_name=process.name)
+ db_task.save(commit=True)
+
+ # all tasks need to know about db id to load the db entry
+ task: chain = process.s(
+ db_id=db_task.id, url=circuit_url, hash=url_hash
+ ) | save_task_result.s(db_id=db_task.id)
+ # save errors to db
+ task.link_error(save_task_error.s(db_id=db_task.id))
+ task.apply_async(db_id=db_task.id, url=circuit_url, hash=url_hash)
+
+ db_task.save(commit=True)
+
+ return redirect(
+ url_for("tasks-api.TaskView", task_id=str(db_task.id)), HTTPStatus.SEE_OTHER
+ )
+
+
+class CSVisualization(QHAnaPluginBase):
+ name = _plugin_name
+ version = __version__
+ description = "Visualizes cluster data in a scatter plot."
+ tags = ["visualization"]
+
+ def __init__(self, app: Optional[Flask]) -> None:
+ super().__init__(app)
+
+ # create folder for circuit images
+ pathlib.Path(__file__).parent.absolute().joinpath("img").mkdir(
+ parents=True, exist_ok=True
+ )
+
+ def get_api_blueprint(self):
+ return VIS_BLP
+
+ def get_requirements(self) -> str:
+ return "pylatexenc~=2.10\nqiskit~=0.43"
+
+
+TASK_LOGGER = get_task_logger(__name__)
+
+
+@CELERY.task(name=f"{CSVisualization.instance.identifier}.generate_image", bind=True)
+def generate_image(self, url: str, hash: str) -> str:
+ from qiskit import QuantumCircuit
+ import matplotlib
+
+ matplotlib.use("SVG")
+
+ TASK_LOGGER.info(f"Generating image for circuit {url}...")
+ try:
+ with open_url(url) as qasm_response:
+ circuit_qasm = qasm_response.text
+ except HTTPError:
+ TASK_LOGGER.error(f"Invalid circuit URL: {url}")
+ DataBlob.set_value(
+ CSVisualization.instance.identifier,
+ hash,
+ "",
+ )
+ PluginState.delete_value(CSVisualization.instance.identifier, hash, commit=True)
+ return "Invalid circuit URL!"
+
+ circuit = QuantumCircuit.from_qasm_str(circuit_qasm)
+ fig = circuit.draw(output="mpl", interactive=False)
+ figfile = BytesIO()
+ fig.savefig(figfile, format="svg")
+ figfile.seek(0)
+ DataBlob.set_value(CSVisualization.instance.identifier, hash, figfile.getvalue())
+ TASK_LOGGER.info(f"Stored image of circuit {circuit.name}.")
+ PluginState.delete_value(CSVisualization.instance.identifier, hash, commit=True)
+
+ return "Created image of circuit!"
+
+
+@CELERY.task(
+ name=f"{CSVisualization.instance.identifier}.process",
+ bind=True,
+ autoretry_for=(ImageNotFinishedError,),
+ retry_backoff=True,
+ max_retries=None,
+)
+def process(self, db_id: str, url: str, hash: str) -> str:
+ if not (image := DataBlob.get_value(CSVisualization.instance.identifier, hash)):
+ if not (
+ task_id := PluginState.get_value(CSVisualization.instance.identifier, hash)
+ ):
+ task_result = generate_image.s(url, hash).apply_async()
+ PluginState.set_value(
+ CSVisualization.instance.identifier,
+ hash,
+ task_result.id,
+ commit=True,
+ )
+ raise ImageNotFinishedError()
+ with SpooledTemporaryFile() as output:
+ output.write(image)
+ output.seek(0)
+ STORE.persist_task_result(
+ db_id, output, f"circuit_{hash}.svg", "image/svg", "image/svg+xml"
+ )
+ return "Created image of circuit!"
diff --git a/plugins/cluster_scatter_visualization/cluster_scatter_visualization_templates/cluster_scatter_visualization.html b/plugins/cluster_scatter_visualization/cluster_scatter_visualization_templates/cluster_scatter_visualization.html
new file mode 100644
index 000000000..5a754fe20
--- /dev/null
+++ b/plugins/cluster_scatter_visualization/cluster_scatter_visualization_templates/cluster_scatter_visualization.html
@@ -0,0 +1,126 @@
+{% extends "simple_template.html" %}
+
+{% block head %}
+{{ super() }}
+
+{% endblock head %}
+
+{% block content %}
+
+
+ Visualization Options
+ {% call forms.render_form(target="microfrontend") %}
+ {{ forms.render_fields(schema, values=values, errors=errors) }}
+
+ {{ forms.submit("validate", target="microfrontend") }}
+ {{ forms.submit("submit", action=process) }}
+
+ {% endcall %}
+
+
+
+
+
+{% endblock content %}
+
+{% block script %}
+{{ super() }}
+
+{% endblock script %}
\ No newline at end of file
diff --git a/plugins/cluster_svm_visualization/tasks.py b/plugins/cluster_svm_visualization/tasks.py
index 5190cf6c2..ebffbc1b4 100644
--- a/plugins/cluster_svm_visualization/tasks.py
+++ b/plugins/cluster_svm_visualization/tasks.py
@@ -66,6 +66,7 @@ def visualization_task(self, db_id: int) -> str:
pt_x_list = [0 for _ in range(0, len(entity_points))]
pt_y_list = [0 for _ in range(0, len(entity_points))]
+ pt_z_list = [0 for _ in range(0, len(entity_points))]
label_list = [0 for _ in range(0, len(entity_points))]
id_list = [x for x in range(0, len(entity_points))]
size_list = [10 for _ in range(0, len(entity_points))]
@@ -75,6 +76,8 @@ def visualization_task(self, db_id: int) -> str:
idx = int(pt["ID"])
pt_x_list[idx] = pt["dim0"]
pt_y_list[idx] = pt["dim1"]
+ if do_3d:
+ pt_z_list[idx] = pt["dim2"]
for cl in clusters:
label_list[int(cl["ID"])] = cl["label"]
@@ -85,20 +88,33 @@ def visualization_task(self, db_id: int) -> str:
"ID": [f"Point {x}" for x in id_list],
"x": pt_x_list,
"y": pt_y_list,
+ "z": pt_z_list,
"Cluster ID": [str(x) for x in label_list],
"size": size_list,
}
)
- fig = px.scatter(
- df,
- x="x",
- y="y",
- size="size",
- hover_name="ID",
- color="Cluster ID",
- hover_data={"size": False},
- )
+ if not do_3d:
+ fig = px.scatter(
+ df,
+ x="x",
+ y="y",
+ size="size",
+ hover_name="ID",
+ color="Cluster ID",
+ hover_data={"size": False},
+ )
+ else:
+ fig = px.scatter_3d(
+ df,
+ x="x",
+ y="y",
+ z="z",
+ size="size",
+ hover_name="ID",
+ color="Cluster ID",
+ hover_data={"size": False},
+ )
if do_svm:
cluster_list = [[] for _ in range(0, max_cluster + 1)]
diff --git a/plugins/zxcalculus/__init__.py b/plugins/zxcalculus/__init__.py
new file mode 100644
index 000000000..f083564b2
--- /dev/null
+++ b/plugins/zxcalculus/__init__.py
@@ -0,0 +1,57 @@
+# Copyright 2022 QHAna plugin runner contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from flask import Flask
+
+from qhana_plugin_runner.api.util import SecurityBlueprint
+from qhana_plugin_runner.util.plugins import plugin_identifier, QHAnaPluginBase
+
+_plugin_name = "zxcalculus"
+__version__ = "v0.0.4"
+_identifier = plugin_identifier(_plugin_name, __version__)
+
+
+VIS_BLP = SecurityBlueprint(
+ _identifier, # blueprint name
+ __name__, # module import name!
+ description="ZXCalculus API.",
+)
+
+
+class ZXCalculus(QHAnaPluginBase):
+ name = _plugin_name
+ version = __version__
+ description = "Generates a random circuit, visualizes and simplifies it."
+ tags = ["zxcalculus", "circuit"]
+
+ def __init__(self, app: Optional[Flask]) -> None:
+ super().__init__(app)
+
+ def get_api_blueprint(self):
+ return VIS_BLP
+
+ def get_requirements(self) -> str:
+ return "plotly~=5.18.0\npyzx~=0.8.0\nmpld3~=0.5.10"
+
+
+try:
+ # It is important to import the routes **after** COSTUME_LOADER_BLP and CostumeLoader are defined, because they are
+ # accessed as soon as the routes are imported.
+ from . import routes
+except ImportError:
+ # When running `poetry run flask install`, importing the routes will fail, because the dependencies are not
+ # installed yet.
+ pass
diff --git a/plugins/zxcalculus/routes.py b/plugins/zxcalculus/routes.py
new file mode 100644
index 000000000..003fff128
--- /dev/null
+++ b/plugins/zxcalculus/routes.py
@@ -0,0 +1,160 @@
+# Copyright 2023 QHAna plugin runner contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Mapping
+
+from celery.canvas import chain
+from flask import Response, redirect
+from flask.globals import request
+from flask.helpers import url_for
+from flask.templating import render_template
+from flask.views import MethodView
+from marshmallow import EXCLUDE
+
+from qhana_plugin_runner.api.plugin_schemas import (
+ EntryPoint,
+ PluginMetadata,
+ PluginMetadataSchema,
+ PluginType,
+ InputDataMetadata,
+ DataMetadata,
+)
+from qhana_plugin_runner.db.models.tasks import ProcessingTask
+from qhana_plugin_runner.tasks import (
+ save_task_error,
+ save_task_result,
+)
+
+from . import VIS_BLP, ZXCalculus
+from .schemas import InputParametersSchema, TaskResponseSchema
+from .tasks import visualization_task
+
+
+@VIS_BLP.route("/")
+class PluginsView(MethodView):
+ """Plugins collection resource."""
+
+ @VIS_BLP.response(HTTPStatus.OK, PluginMetadataSchema)
+ @VIS_BLP.require_jwt("jwt", optional=True)
+ def get(self):
+ """Endpoint returning the plugin metadata."""
+
+ return PluginMetadata(
+ title="ZXCalculus",
+ description=ZXCalculus.instance.description,
+ name=ZXCalculus.instance.name,
+ version=ZXCalculus.instance.version,
+ type=PluginType.visualization,
+ entry_point=EntryPoint(
+ href=url_for(f"{VIS_BLP.name}.ProcessView"),
+ ui_href=url_for(f"{VIS_BLP.name}.MicroFrontend"),
+ data_input=[],
+ data_output=[
+ DataMetadata(
+ data_type="circuit", content_type=["text/html"], required=True
+ )
+ ],
+ ),
+ tags=ZXCalculus.instance.tags,
+ )
+
+
+@VIS_BLP.route("/ui/")
+class MicroFrontend(MethodView):
+ """Micro frontend for the ZXCalculus plugin."""
+
+ @VIS_BLP.html_response(
+ HTTPStatus.OK, description="Micro frontend for the ZXCalculus plugin."
+ )
+ @VIS_BLP.arguments(
+ InputParametersSchema(
+ partial=True, unknown=EXCLUDE, validate_errors_as_result=True
+ ),
+ location="query",
+ required=False,
+ )
+ @VIS_BLP.require_jwt("jwt", optional=True)
+ def get(self, errors):
+ """Return the micro frontend."""
+ return self.render(request.args, errors)
+
+ @VIS_BLP.html_response(
+ HTTPStatus.OK, description="Micro frontend for ZXCalculus plugin."
+ )
+ @VIS_BLP.arguments(
+ InputParametersSchema(
+ partial=True, unknown=EXCLUDE, validate_errors_as_result=True
+ ),
+ location="form",
+ required=False,
+ )
+ @VIS_BLP.require_jwt("jwt", optional=True)
+ def post(self, errors):
+ """Return the micro frontend with prerendered inputs."""
+ return self.render(request.form, errors)
+
+ def render(self, data: Mapping, errors: dict):
+ schema = InputParametersSchema()
+
+ data_dict = dict(data)
+
+ # define default values
+ default_values = {}
+
+ # overwrite default values with other values if possible
+ default_values.update(data_dict)
+ data_dict = default_values
+
+ return Response(
+ render_template(
+ "simple_template.html",
+ name=ZXCalculus.instance.name,
+ version=ZXCalculus.instance.version,
+ schema=schema,
+ values=data_dict,
+ errors=errors,
+ process=url_for(f"{VIS_BLP.name}.ProcessView"),
+ )
+ )
+
+
+@VIS_BLP.route("/process/")
+class ProcessView(MethodView):
+ """Start a long running processing task."""
+
+ @VIS_BLP.arguments(InputParametersSchema(unknown=EXCLUDE), location="form")
+ @VIS_BLP.response(HTTPStatus.OK, TaskResponseSchema())
+ @VIS_BLP.require_jwt("jwt", optional=True)
+ def post(self, arguments):
+ """Start the visualization task."""
+ db_task = ProcessingTask(
+ task_name=visualization_task.name,
+ parameters=InputParametersSchema().dumps(arguments),
+ )
+ db_task.save(commit=True)
+
+ # all tasks need to know about db id to load the db entry
+ task: chain = visualization_task.s(db_id=db_task.id) | save_task_result.s(
+ db_id=db_task.id
+ )
+ # save errors to db
+ task.link_error(save_task_error.s(db_id=db_task.id))
+ task.apply_async()
+
+ db_task.save(commit=True)
+
+ return redirect(
+ url_for("tasks-api.TaskView", task_id=str(db_task.id)), HTTPStatus.SEE_OTHER
+ )
diff --git a/plugins/zxcalculus/schemas.py b/plugins/zxcalculus/schemas.py
new file mode 100644
index 000000000..3870927e8
--- /dev/null
+++ b/plugins/zxcalculus/schemas.py
@@ -0,0 +1,68 @@
+# Copyright 2023 QHAna plugin runner contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from marshmallow import post_load
+import marshmallow as ma
+from qhana_plugin_runner.api.util import FrontendFormBaseSchema, MaBaseSchema, FileUrl
+from dataclasses import dataclass
+
+
+class TaskResponseSchema(MaBaseSchema):
+ name = ma.fields.String(required=True, allow_none=False, dump_only=True)
+ task_id = ma.fields.String(required=True, allow_none=False, dump_only=True)
+ task_result_url = ma.fields.Url(required=True, allow_none=False, dump_only=True)
+
+
+@dataclass(repr=False)
+class InputParameters:
+ qubits: int
+ depth: int
+ simplify: bool = False
+
+ def __str__(self):
+ return str(self.__dict__)
+
+
+class InputParametersSchema(FrontendFormBaseSchema):
+ qubits = ma.fields.Integer(
+ required=True,
+ allow_none=False,
+ metadata={
+ "label": "No. Qubits",
+ "description": "Determines the number of qubits to generate.",
+ "input_type": "number",
+ },
+ )
+ depth = ma.fields.Integer(
+ required=True,
+ allow_none=False,
+ metadata={
+ "label": "Depth",
+ "description": "Determines the depth of the circuits.",
+ "input_type": "number",
+ },
+ )
+ simplify = ma.fields.Boolean(
+ required=False,
+ allow_none=False,
+ metadata={
+ "label": "Simplify",
+ "description": "Simplify the generated circuit.",
+ "input_type": "checkbox",
+ }
+ )
+
+ @post_load
+ def make_input_params(self, data, **kwargs) -> InputParameters:
+ return InputParameters(**data)
diff --git a/plugins/zxcalculus/tasks.py b/plugins/zxcalculus/tasks.py
new file mode 100644
index 000000000..4ade83ec6
--- /dev/null
+++ b/plugins/zxcalculus/tasks.py
@@ -0,0 +1,79 @@
+# Copyright 2023 QHAna plugin runner contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tempfile import SpooledTemporaryFile
+
+from typing import Optional
+
+from celery.utils.log import get_task_logger
+
+import muid
+from . import ZXCalculus
+from .schemas import InputParameters, InputParametersSchema
+from qhana_plugin_runner.celery import CELERY
+from qhana_plugin_runner.db.models.tasks import ProcessingTask
+from qhana_plugin_runner.requests import open_url, retrieve_filename
+
+from qhana_plugin_runner.storage import STORE
+
+import pyzx as zx
+import matplotlib.pyplot as _, mpld3
+
+TASK_LOGGER = get_task_logger(__name__)
+
+
+def get_readable_hash(s: str) -> str:
+ return muid.pretty(muid.bhash(s.encode("utf-8")), k1=6, k2=5).replace(" ", "-")
+
+
+@CELERY.task(name=f"{ZXCalculus.instance.identifier}.visualization_task", bind=True)
+def visualization_task(self, db_id: int) -> str:
+
+ TASK_LOGGER.info(f"Starting new demo task with db id '{db_id}'")
+ task_data: Optional[ProcessingTask] = ProcessingTask.get_by_id(id_=db_id)
+
+ if task_data is None:
+ msg = f"Could not load task data with id {db_id} to read parameters!"
+ TASK_LOGGER.error(msg)
+ raise KeyError(msg)
+
+ input_params: InputParameters = InputParametersSchema().loads(task_data.parameters)
+
+ qubits = input_params.qubits
+ depth = input_params.depth
+ simplify = input_params.simplify
+
+ circuit = zx.generate.cliffordT(qubits, depth)
+ zx.settings.drawing_backend = 'd3'
+ fig = zx.draw(circuit)
+ html = mpld3.fig_to_html(fig)
+ if simplify:
+ zx.simplify.full_reduce(circuit)
+ circuit.normalize()
+ fig = zx.draw(circuit)
+ html = mpld3.fig_to_html(fig)
+
+ with SpooledTemporaryFile(mode="wt") as output:
+ output.write(html)
+
+ STORE.persist_task_result(
+ db_id,
+ output,
+ f"ZXCalculus_Circuit_{qubits}Qubits_{depth}Depth.html",
+ "circuit",
+ "text/html",
+ )
+
+ return "Result stored in file"
diff --git a/stable_plugins/data_synthesis/data_creator/__init__.py b/stable_plugins/data_synthesis/data_creator/__init__.py
index a75ad9f37..10f3181b3 100644
--- a/stable_plugins/data_synthesis/data_creator/__init__.py
+++ b/stable_plugins/data_synthesis/data_creator/__init__.py
@@ -22,7 +22,7 @@
_plugin_name = "data-creator"
-__version__ = "v0.2.1"
+__version__ = "v0.2.2"
_identifier = plugin_identifier(_plugin_name, __version__)
diff --git a/stable_plugins/data_synthesis/data_creator/backend/datasets.py b/stable_plugins/data_synthesis/data_creator/backend/datasets.py
index bdd820844..4f7aa3784 100644
--- a/stable_plugins/data_synthesis/data_creator/backend/datasets.py
+++ b/stable_plugins/data_synthesis/data_creator/backend/datasets.py
@@ -14,6 +14,7 @@
from enum import Enum
from typing import List, Tuple
+from sklearn.datasets import make_blobs
import numpy as np
@@ -21,6 +22,9 @@
class DataTypeEnum(Enum):
two_spirals = "Two Spirals"
checkerboard = "Checkerboard"
+ blobs = "Blobs"
+ checkerboard_3d = "3D Checkerboard"
+ blobs_3d = "3D Blobs"
def get_data(
self, num_train_points: int, num_test_points: int, **kwargs
@@ -30,6 +34,12 @@ def get_data(
data, labels = twospirals(num_train_points + num_test_points, **kwargs)
elif self == DataTypeEnum.checkerboard:
data, labels = checkerboard(num_train_points + num_test_points, **kwargs)
+ elif self == DataTypeEnum.blobs:
+ data, labels = blobs(num_train_points + num_test_points, **kwargs)
+ elif self == DataTypeEnum.checkerboard_3d:
+ data, labels = checkerboard_3d(num_train_points + num_test_points, **kwargs)
+ elif self == DataTypeEnum.blobs_3d:
+ data, labels = blobs_3d(num_train_points + num_test_points, **kwargs)
else:
raise NotImplementedError
indices = np.arange(len(data))
@@ -68,3 +78,39 @@ def checkerboard(n_points: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
y += 0.2 if y > 0 else -0.2
rand_points[i] = x, y
return rand_points, labels
+
+def blobs(n_points: int, centers: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
+ """Returns the Blobs dataset."""
+ x, y = make_blobs(
+ n_samples=n_points,
+ centers=centers,
+ n_features=2,
+ )
+
+ return x, y
+
+def checkerboard_3d(n_points: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
+ """Returns the 3D checkerboard dataset."""
+ rand_points = (np.random.rand(n_points, 3) * 2) - 1
+ labels = np.zeros(n_points).astype(int)
+
+ for i, (x, y, z) in enumerate(rand_points):
+ # label by quadrant
+ labels[i] = int(not ((x < 0 and y < 0) or (x >= 0 and y >= 0)) != (z < 0))
+ # push away from both axes
+ x += 0.2 if x > 0 else -0.2
+ y += 0.2 if y > 0 else -0.2
+ z += 0.2 if z > 0 else -0.2
+ rand_points[i] = x, y, z
+ return rand_points, labels
+
+
+def blobs_3d(n_points: int, centers:int, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
+ """Returns the 3D Blobs dataset."""
+ x, y = make_blobs(
+ n_samples=n_points,
+ centers=centers,
+ n_features=3,
+ )
+
+ return x, y
\ No newline at end of file
diff --git a/stable_plugins/data_synthesis/data_creator/routes.py b/stable_plugins/data_synthesis/data_creator/routes.py
index 5c0cca62d..9a2fa95f4 100644
--- a/stable_plugins/data_synthesis/data_creator/routes.py
+++ b/stable_plugins/data_synthesis/data_creator/routes.py
@@ -139,6 +139,7 @@ def render(self, data: Mapping, errors: dict):
default_values = {
fields["noise"].data_key: 0.7,
fields["turns"].data_key: 1.52,
+ fields["centers"].data_key: 4,
}
# overwrite default values with other values if possible
diff --git a/stable_plugins/data_synthesis/data_creator/schemas.py b/stable_plugins/data_synthesis/data_creator/schemas.py
index 95c2e6b56..d117086e4 100644
--- a/stable_plugins/data_synthesis/data_creator/schemas.py
+++ b/stable_plugins/data_synthesis/data_creator/schemas.py
@@ -39,6 +39,7 @@ class InputParameters:
num_test_points: int
turns: float = None
noise: float = None
+ centers: int = None
def __str__(self):
return str(self.__dict__.copy())
@@ -93,6 +94,15 @@ class InputParametersSchema(FrontendFormBaseSchema):
"input_type": "text",
},
)
+ centers = ma.fields.Integer(
+ required=False,
+ allow_none=False,
+ metadata={
+ "label": "No. Centers",
+ "description": "Determines the number of Blobs",
+ "input_type": "number",
+ },
+ )
@post_load
def make_input_params(self, data, **kwargs) -> InputParameters:
diff --git a/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html b/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html
index d6f74fad7..f1b9767ed 100644
--- a/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html
+++ b/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html
@@ -4,16 +4,21 @@