From 05e70325595657e66792f13e736043415d60173c Mon Sep 17 00:00:00 2001 From: B3at Date: Mon, 19 Aug 2024 20:41:12 +0200 Subject: [PATCH 1/7] Added Simple Bar Diagram Added Simple Bar Diagram Plugin, which uses Cluster Data generated by for example k-means to show the amount per cluster in a Bar Diagram. --- plugins/bar_simple/__init__.py | 56 +++++++++++ plugins/bar_simple/routes.py | 170 +++++++++++++++++++++++++++++++++ plugins/bar_simple/schemas.py | 51 ++++++++++ plugins/bar_simple/tasks.py | 98 +++++++++++++++++++ 4 files changed, 375 insertions(+) create mode 100644 plugins/bar_simple/__init__.py create mode 100644 plugins/bar_simple/routes.py create mode 100644 plugins/bar_simple/schemas.py create mode 100644 plugins/bar_simple/tasks.py diff --git a/plugins/bar_simple/__init__.py b/plugins/bar_simple/__init__.py new file mode 100644 index 000000000..a2550eebd --- /dev/null +++ b/plugins/bar_simple/__init__.py @@ -0,0 +1,56 @@ +# Copyright 2022 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from flask import Flask + +from qhana_plugin_runner.api.util import SecurityBlueprint +from qhana_plugin_runner.util.plugins import plugin_identifier, QHAnaPluginBase + +_plugin_name = "bar-diagram" +__version__ = "v0.0.6" +_identifier = plugin_identifier(_plugin_name, __version__) + + +BAR_BLP = SecurityBlueprint( + _identifier, # blueprint name + __name__, # module import name! 
+ description="Simple Bar Diagram API.", +) + +class BarDiagram(QHAnaPluginBase): + name = _plugin_name + version = __version__ + description = "Prints provided Data into a simple Bar Diagram." + tags = ["bar-diagram"] + + def __init__(self, app: Optional[Flask]) -> None: + super().__init__(app) + + def get_api_blueprint(self): + return BAR_BLP + + def get_requirements(self) -> str: + return "plotly~=5.18.0\npandas~=1.5.0\nmuid~=0.5.3" + + +try: + # It is important to import the routes **after** BAR_BLP and BarDiagram are defined, because they are + # accessed as soon as the routes are imported. + from . import routes +except ImportError: + # When running `poetry run flask install`, importing the routes will fail, because the dependencies are not + # installed yet. + pass diff --git a/plugins/bar_simple/routes.py b/plugins/bar_simple/routes.py new file mode 100644 index 000000000..bca50e4f6 --- /dev/null +++ b/plugins/bar_simple/routes.py @@ -0,0 +1,170 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from http import HTTPStatus +from json import dumps +from typing import Mapping + +from celery.canvas import chain +from celery.utils.log import get_task_logger +from flask import Response, redirect, abort +from flask.globals import current_app, request +from flask.helpers import url_for +from flask.templating import render_template +from flask.views import MethodView +from marshmallow import EXCLUDE + +from qhana_plugin_runner.api.plugin_schemas import ( + EntryPoint, + PluginMetadata, + PluginMetadataSchema, + PluginType, + InputDataMetadata, + DataMetadata +) +from qhana_plugin_runner.db.models.tasks import ProcessingTask +from qhana_plugin_runner.tasks import ( + TASK_STEPS_CHANGED, + add_step, + save_task_error, + save_task_result, +) + +from . import BAR_BLP, BarDiagram +from .schemas import InputParametersSchema, TaskResponseSchema +from .tasks import visualization_task + +@BAR_BLP.route("/") +class PluginsView(MethodView): + """Plugins collection resource.""" + + @BAR_BLP.response(HTTPStatus.OK, PluginMetadataSchema) + @BAR_BLP.require_jwt("jwt", optional=True) + def get(self): + """Endpoint returning the plugin metadata.""" + + return PluginMetadata( + title="Bar Diagram", + description=BarDiagram.instance.description, + name=BarDiagram.instance.name, + version=BarDiagram.instance.version, + type=PluginType.visualization, + entry_point=EntryPoint( + href=url_for(f"{BAR_BLP.name}.ProcessView"), + ui_href=url_for(f"{BAR_BLP.name}.MicroFrontend"), + data_input=[ + InputDataMetadata( + data_type="entity/label", + content_type=["application/json"], + required=True, + parameter="clustersUrl", + ) + ], + data_output=[ + DataMetadata( + data_type="plot", + content_type=["text/html"], + required=True + ) + ], + ), + tags=BarDiagram.instance.tags, + ) + + +@BAR_BLP.route("/ui/") +class MicroFrontend(MethodView): + """Micro frontend for the Simple Bar Diagram plugin.""" + + @BAR_BLP.html_response( + HTTPStatus.OK, description="Micro frontend for the Simple Bar 
Diagram plugin." + ) + @BAR_BLP.arguments( + InputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="query", + required=False, + ) + @BAR_BLP.require_jwt("jwt", optional=True) + def get(self, errors): + """Return the micro frontend.""" + return self.render(request.args, errors) + + @BAR_BLP.html_response( + HTTPStatus.OK, description="Micro frontend for the Simple Bar Diagram plugin." + ) + @BAR_BLP.arguments( + InputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="form", + required=False, + ) + @BAR_BLP.require_jwt("jwt", optional=True) + def post(self, errors): + """Return the micro frontend with prerendered inputs.""" + return self.render(request.form, errors) + + def render(self, data: Mapping, errors: dict): + schema = InputParametersSchema() + + data_dict = dict(data) + + # define default values + default_values = {} + + # overwrite default values with other values if possible + default_values.update(data_dict) + data_dict = default_values + + return Response( + render_template( + "simple_template.html", + name=BarDiagram.instance.name, + version=BarDiagram.instance.version, + schema=schema, + values=data_dict, + errors=errors, + process=url_for(f"{BAR_BLP.name}.ProcessView"), + ) + ) + + +@BAR_BLP.route("/process/") +class ProcessView(MethodView): + """Start a long running processing task.""" + + @BAR_BLP.arguments(InputParametersSchema(unknown=EXCLUDE), location="form") + @BAR_BLP.response(HTTPStatus.OK, TaskResponseSchema()) + @BAR_BLP.require_jwt("jwt", optional=True) + def post(self, arguments): + """Start the visualization task.""" + db_task = ProcessingTask( + task_name=visualization_task.name, + parameters=InputParametersSchema().dumps(arguments) + ) + db_task.save(commit=True) + + # all tasks need to know about db id to load the db entry + task: chain = visualization_task.s(db_id=db_task.id) | save_task_result.s(db_id=db_task.id) + # save errors to db + 
task.link_error(save_task_error.s(db_id=db_task.id)) + task.apply_async() + + db_task.save(commit=True) + + return redirect( + url_for("tasks-api.TaskView", task_id=str(db_task.id)), HTTPStatus.SEE_OTHER + ) diff --git a/plugins/bar_simple/schemas.py b/plugins/bar_simple/schemas.py new file mode 100644 index 000000000..716274fb8 --- /dev/null +++ b/plugins/bar_simple/schemas.py @@ -0,0 +1,51 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from marshmallow import post_load +import marshmallow as ma +from qhana_plugin_runner.api.util import ( + FrontendFormBaseSchema, + MaBaseSchema, + FileUrl +) +from dataclasses import dataclass + +class TaskResponseSchema(MaBaseSchema): + name = ma.fields.String(required=True, allow_none=False, dump_only=True) + task_id = ma.fields.String(required=True, allow_none=False, dump_only=True) + task_result_url = ma.fields.Url(required=True, allow_none=False, dump_only=True) + +@dataclass(repr=False) +class InputParameters: + clusters_url : str + + def __str__(self): + return str(self.__dict__) + +class InputParametersSchema(FrontendFormBaseSchema): + clusters_url = FileUrl( + required=True, + allow_none=False, + data_input_type="entity/label", + data_content_types=["application/json"], + metadata={ + "label": "Cluster points URL", + "description": "URL to a json file with the cluster points.", + "input_type": "text", + }, + ) + + @post_load + def make_input_params(self, data, **kwargs) -> InputParameters: + return InputParameters(**data) \ No newline at end of file diff --git a/plugins/bar_simple/tasks.py b/plugins/bar_simple/tasks.py new file mode 100644 index 000000000..0963a7881 --- /dev/null +++ b/plugins/bar_simple/tasks.py @@ -0,0 +1,98 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from tempfile import SpooledTemporaryFile + +from typing import Optional +from json import loads + +from celery.utils.log import get_task_logger + +import muid +from . import BarDiagram +from .schemas import InputParameters, InputParametersSchema +from qhana_plugin_runner.celery import CELERY +from qhana_plugin_runner.db.models.tasks import ProcessingTask +from qhana_plugin_runner.requests import open_url, retrieve_filename + +from qhana_plugin_runner.storage import STORE + +TASK_LOGGER = get_task_logger(__name__) + +def get_readable_hash(s: str) -> str: + return muid.pretty(muid.bhash(s.encode("utf-8")), k1=6, k2=5).replace(" ", "-") + +@CELERY.task(name=f"{BarDiagram.instance.identifier}.visualization_task", bind=True) +def visualization_task(self, db_id: int) -> str: + import pandas as pd + import plotly.express as px + + TASK_LOGGER.info(f"Starting new demo task with db id '{db_id}'") + task_data: Optional[ProcessingTask] = ProcessingTask.get_by_id(id_=db_id) + + if task_data is None: + msg = f"Could not load task data with id {db_id} to read parameters!" 
+ TASK_LOGGER.error(msg) + raise KeyError(msg) + + input_params: InputParameters = InputParametersSchema().loads(task_data.parameters) + + clusters_url = input_params.clusters_url + TASK_LOGGER.info( + f"Loaded input parameters from db: clusters_url='{clusters_url}'" + ) + + # load data from file + + clusters = open_url(clusters_url).json() + + cluster_list = [] + amount_list = [] + + for cl in clusters: + label = cl["label"] + if not label in cluster_list: + cluster_list.append(label) + amount_list.append(1) + else: + amount_list[cluster_list.index(label)] += 1 + + df = pd.DataFrame( + { + "Cluster ID": cluster_list, + "Amount": amount_list, + } + ) + + fig = px.bar(df, x='Cluster ID', y='Amount', color='Cluster ID') + fig.update_coloraxes(showscale=False) + + filenames_hash = get_readable_hash(retrieve_filename(clusters_url)) + + info_str = f"_bar-diagram_{filenames_hash}" + + with SpooledTemporaryFile(mode="wt") as output: + html = fig.to_html(include_plotlyjs='cdn') + output.write(html) + + STORE.persist_task_result( + db_id, + output, + f"plot{info_str}.html", + "plot", + "text/html", + ) + + return "Result stored in file" From 5300614b772bb14f314e266d56ac1d9340471220 Mon Sep 17 00:00:00 2001 From: B3at Date: Mon, 19 Aug 2024 22:23:56 +0200 Subject: [PATCH 2/7] Bar Diagram Lint Tried to change formatting so Lint accepts it --- plugins/bar_simple/__init__.py | 1 + plugins/bar_simple/routes.py | 17 +++++++++-------- plugins/bar_simple/schemas.py | 14 +++++++------- plugins/bar_simple/tasks.py | 10 +++++----- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/plugins/bar_simple/__init__.py b/plugins/bar_simple/__init__.py index a2550eebd..0e09cbb53 100644 --- a/plugins/bar_simple/__init__.py +++ b/plugins/bar_simple/__init__.py @@ -24,6 +24,7 @@ _identifier = plugin_identifier(_plugin_name, __version__) + BAR_BLP = SecurityBlueprint( _identifier, # blueprint name __name__, # module import name! 
diff --git a/plugins/bar_simple/routes.py b/plugins/bar_simple/routes.py index bca50e4f6..2a45d6baa 100644 --- a/plugins/bar_simple/routes.py +++ b/plugins/bar_simple/routes.py @@ -31,7 +31,7 @@ PluginMetadataSchema, PluginType, InputDataMetadata, - DataMetadata + DataMetadata, ) from qhana_plugin_runner.db.models.tasks import ProcessingTask from qhana_plugin_runner.tasks import ( @@ -45,6 +45,7 @@ from .schemas import InputParametersSchema, TaskResponseSchema from .tasks import visualization_task + @BAR_BLP.route("/") class PluginsView(MethodView): """Plugins collection resource.""" @@ -73,9 +74,7 @@ def get(self): ], data_output=[ DataMetadata( - data_type="plot", - content_type=["text/html"], - required=True + data_type="plot", content_type=["text/html"], required=True ) ], ), @@ -152,13 +151,15 @@ class ProcessView(MethodView): def post(self, arguments): """Start the visualization task.""" db_task = ProcessingTask( - task_name=visualization_task.name, - parameters=InputParametersSchema().dumps(arguments) - ) + task_name=visualization_task.name, + parameters=InputParametersSchema().dumps(arguments), + ) db_task.save(commit=True) # all tasks need to know about db id to load the db entry - task: chain = visualization_task.s(db_id=db_task.id) | save_task_result.s(db_id=db_task.id) + task: chain = visualization_task.s(db_id=db_task.id) | save_task_result.s( + db_id=db_task.id + ) # save errors to db task.link_error(save_task_error.s(db_id=db_task.id)) task.apply_async() diff --git a/plugins/bar_simple/schemas.py b/plugins/bar_simple/schemas.py index 716274fb8..320411690 100644 --- a/plugins/bar_simple/schemas.py +++ b/plugins/bar_simple/schemas.py @@ -14,25 +14,24 @@ from marshmallow import post_load import marshmallow as ma -from qhana_plugin_runner.api.util import ( - FrontendFormBaseSchema, - MaBaseSchema, - FileUrl -) +from qhana_plugin_runner.api.util import FrontendFormBaseSchema, MaBaseSchema, FileUrl from dataclasses import dataclass + class 
TaskResponseSchema(MaBaseSchema): name = ma.fields.String(required=True, allow_none=False, dump_only=True) task_id = ma.fields.String(required=True, allow_none=False, dump_only=True) task_result_url = ma.fields.Url(required=True, allow_none=False, dump_only=True) + @dataclass(repr=False) class InputParameters: - clusters_url : str + clusters_url: str def __str__(self): return str(self.__dict__) + class InputParametersSchema(FrontendFormBaseSchema): clusters_url = FileUrl( required=True, @@ -48,4 +47,5 @@ class InputParametersSchema(FrontendFormBaseSchema): @post_load def make_input_params(self, data, **kwargs) -> InputParameters: - return InputParameters(**data) \ No newline at end of file + return InputParameters(**data) + \ No newline at end of file diff --git a/plugins/bar_simple/tasks.py b/plugins/bar_simple/tasks.py index 0963a7881..e8bc8b6d1 100644 --- a/plugins/bar_simple/tasks.py +++ b/plugins/bar_simple/tasks.py @@ -31,9 +31,11 @@ TASK_LOGGER = get_task_logger(__name__) + def get_readable_hash(s: str) -> str: return muid.pretty(muid.bhash(s.encode("utf-8")), k1=6, k2=5).replace(" ", "-") + @CELERY.task(name=f"{BarDiagram.instance.identifier}.visualization_task", bind=True) def visualization_task(self, db_id: int) -> str: import pandas as pd @@ -50,9 +52,7 @@ def visualization_task(self, db_id: int) -> str: input_params: InputParameters = InputParametersSchema().loads(task_data.parameters) clusters_url = input_params.clusters_url - TASK_LOGGER.info( - f"Loaded input parameters from db: clusters_url='{clusters_url}'" - ) + TASK_LOGGER.info(f"Loaded input parameters from db: clusters_url='{clusters_url}'") # load data from file @@ -76,7 +76,7 @@ def visualization_task(self, db_id: int) -> str: } ) - fig = px.bar(df, x='Cluster ID', y='Amount', color='Cluster ID') + fig = px.bar(df, x="Cluster ID", y="Amount", color="Cluster ID") fig.update_coloraxes(showscale=False) filenames_hash = get_readable_hash(retrieve_filename(clusters_url)) @@ -84,7 +84,7 @@ def 
visualization_task(self, db_id: int) -> str: info_str = f"_bar-diagram_{filenames_hash}" with SpooledTemporaryFile(mode="wt") as output: - html = fig.to_html(include_plotlyjs='cdn') + html = fig.to_html(include_plotlyjs="cdn") output.write(html) STORE.persist_task_result( From 66a110cf01c00845b119724ae244cd565fc4499a Mon Sep 17 00:00:00 2001 From: B3at Date: Mon, 19 Aug 2024 22:27:36 +0200 Subject: [PATCH 3/7] Lint try 2 --- plugins/bar_simple/__init__.py | 4 ++-- plugins/bar_simple/routes.py | 6 +++--- plugins/bar_simple/schemas.py | 1 - plugins/bar_simple/tasks.py | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/plugins/bar_simple/__init__.py b/plugins/bar_simple/__init__.py index 0e09cbb53..554aadca4 100644 --- a/plugins/bar_simple/__init__.py +++ b/plugins/bar_simple/__init__.py @@ -24,13 +24,13 @@ _identifier = plugin_identifier(_plugin_name, __version__) - BAR_BLP = SecurityBlueprint( _identifier, # blueprint name __name__, # module import name! description="Simple Bar Diagram API.", ) + class BarDiagram(QHAnaPluginBase): name = _plugin_name version = __version__ @@ -42,7 +42,7 @@ def __init__(self, app: Optional[Flask]) -> None: def get_api_blueprint(self): return BAR_BLP - + def get_requirements(self) -> str: return "plotly~=5.18.0\npandas~=1.5.0\nmuid~=0.5.3" diff --git a/plugins/bar_simple/routes.py b/plugins/bar_simple/routes.py index 2a45d6baa..fdaabb9c6 100644 --- a/plugins/bar_simple/routes.py +++ b/plugins/bar_simple/routes.py @@ -54,7 +54,7 @@ class PluginsView(MethodView): @BAR_BLP.require_jwt("jwt", optional=True) def get(self): """Endpoint returning the plugin metadata.""" - + return PluginMetadata( title="Bar Diagram", description=BarDiagram.instance.description, @@ -118,7 +118,7 @@ def post(self, errors): def render(self, data: Mapping, errors: dict): schema = InputParametersSchema() - + data_dict = dict(data) # define default values @@ -127,7 +127,7 @@ def render(self, data: Mapping, errors: dict): # overwrite default 
values with other values if possible default_values.update(data_dict) data_dict = default_values - + return Response( render_template( "simple_template.html", diff --git a/plugins/bar_simple/schemas.py b/plugins/bar_simple/schemas.py index 320411690..cec196a68 100644 --- a/plugins/bar_simple/schemas.py +++ b/plugins/bar_simple/schemas.py @@ -48,4 +48,3 @@ class InputParametersSchema(FrontendFormBaseSchema): @post_load def make_input_params(self, data, **kwargs) -> InputParameters: return InputParameters(**data) - \ No newline at end of file diff --git a/plugins/bar_simple/tasks.py b/plugins/bar_simple/tasks.py index e8bc8b6d1..2b5a641ed 100644 --- a/plugins/bar_simple/tasks.py +++ b/plugins/bar_simple/tasks.py @@ -48,7 +48,7 @@ def visualization_task(self, db_id: int) -> str: msg = f"Could not load task data with id {db_id} to read parameters!" TASK_LOGGER.error(msg) raise KeyError(msg) - + input_params: InputParameters = InputParametersSchema().loads(task_data.parameters) clusters_url = input_params.clusters_url From bdd8308e08301606be0315991368a595e4f92874 Mon Sep 17 00:00:00 2001 From: B3at Date: Tue, 20 Aug 2024 20:44:58 +0200 Subject: [PATCH 4/7] Added Cluster SVM Plugin Can currently visualize already calculated clusters in 2D and calculate and add SVM lines to them. 
3D in development --- plugins/bar_simple/routes.py | 8 +- plugins/bar_simple/tasks.py | 1 - plugins/cluster_svm_visualization/__init__.py | 57 ++++++ plugins/cluster_svm_visualization/routes.py | 176 ++++++++++++++++++ plugins/cluster_svm_visualization/schemas.py | 83 +++++++++ plugins/cluster_svm_visualization/tasks.py | 158 ++++++++++++++++ .../scikit_ml/classical_k_means/tasks.py | 2 +- 7 files changed, 477 insertions(+), 8 deletions(-) create mode 100644 plugins/cluster_svm_visualization/__init__.py create mode 100644 plugins/cluster_svm_visualization/routes.py create mode 100644 plugins/cluster_svm_visualization/schemas.py create mode 100644 plugins/cluster_svm_visualization/tasks.py diff --git a/plugins/bar_simple/routes.py b/plugins/bar_simple/routes.py index fdaabb9c6..34a6f0021 100644 --- a/plugins/bar_simple/routes.py +++ b/plugins/bar_simple/routes.py @@ -13,13 +13,11 @@ # limitations under the License. from http import HTTPStatus -from json import dumps from typing import Mapping from celery.canvas import chain -from celery.utils.log import get_task_logger -from flask import Response, redirect, abort -from flask.globals import current_app, request +from flask import Response, redirect +from flask.globals import request from flask.helpers import url_for from flask.templating import render_template from flask.views import MethodView @@ -35,8 +33,6 @@ ) from qhana_plugin_runner.db.models.tasks import ProcessingTask from qhana_plugin_runner.tasks import ( - TASK_STEPS_CHANGED, - add_step, save_task_error, save_task_result, ) diff --git a/plugins/bar_simple/tasks.py b/plugins/bar_simple/tasks.py index 2b5a641ed..9292daf1d 100644 --- a/plugins/bar_simple/tasks.py +++ b/plugins/bar_simple/tasks.py @@ -16,7 +16,6 @@ from tempfile import SpooledTemporaryFile from typing import Optional -from json import loads from celery.utils.log import get_task_logger diff --git a/plugins/cluster_svm_visualization/__init__.py b/plugins/cluster_svm_visualization/__init__.py new 
file mode 100644 index 000000000..c4c79cdd5 --- /dev/null +++ b/plugins/cluster_svm_visualization/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2022 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from flask import Flask + +from qhana_plugin_runner.api.util import SecurityBlueprint +from qhana_plugin_runner.util.plugins import plugin_identifier, QHAnaPluginBase + +_plugin_name = "cluster-svm-visualization" +__version__ = "v0.0.1" +_identifier = plugin_identifier(_plugin_name, __version__) + + +VIS_BLP = SecurityBlueprint( + _identifier, # blueprint name + __name__, # module import name! + description="Cluster Visualization API with added SVM calculation.", +) + + +class ClusterSVM(QHAnaPluginBase): + name = _plugin_name + version = __version__ + description = "Prints provided Data into a Scatter Plot with SVM." + tags = ["cluster","SVM", "visualization"] + + def __init__(self, app: Optional[Flask]) -> None: + super().__init__(app) + + def get_api_blueprint(self): + return VIS_BLP + + def get_requirements(self) -> str: + return "plotly~=5.18.0\npandas~=1.5.0\nmuid~=0.5.3\nscikit-learn~=1.1" + + +try: + # It is important to import the routes **after** VIS_BLP and ClusterSVM are defined, because they are + # accessed as soon as the routes are imported. + from . 
import routes +except ImportError: + # When running `poetry run flask install`, importing the routes will fail, because the dependencies are not + # installed yet. + pass diff --git a/plugins/cluster_svm_visualization/routes.py b/plugins/cluster_svm_visualization/routes.py new file mode 100644 index 000000000..d646840cf --- /dev/null +++ b/plugins/cluster_svm_visualization/routes.py @@ -0,0 +1,176 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Mapping + +from celery.canvas import chain +from flask import Response, redirect +from flask.globals import request +from flask.helpers import url_for +from flask.templating import render_template +from flask.views import MethodView +from marshmallow import EXCLUDE + +from qhana_plugin_runner.api.plugin_schemas import ( + EntryPoint, + PluginMetadata, + PluginMetadataSchema, + PluginType, + InputDataMetadata, + DataMetadata, +) +from qhana_plugin_runner.db.models.tasks import ProcessingTask +from qhana_plugin_runner.tasks import ( + save_task_error, + save_task_result, +) + +from . 
import VIS_BLP, ClusterSVM +from .schemas import InputParametersSchema, TaskResponseSchema +from .tasks import visualization_task + + +@VIS_BLP.route("/") +class PluginsView(MethodView): + """Plugins collection resource.""" + + @VIS_BLP.response(HTTPStatus.OK, PluginMetadataSchema) + @VIS_BLP.require_jwt("jwt", optional=True) + def get(self): + """Endpoint returning the plugin metadata.""" + + return PluginMetadata( + title="Cluster SVM Visualization", + description=ClusterSVM.instance.description, + name=ClusterSVM.instance.name, + version=ClusterSVM.instance.version, + type=PluginType.visualization, + entry_point=EntryPoint( + href=url_for(f"{VIS_BLP.name}.ProcessView"), + ui_href=url_for(f"{VIS_BLP.name}.MicroFrontend"), + data_input=[ + InputDataMetadata( + data_type="entity/vector", + content_type=["application/json"], + required=True, + parameter="entityUrl", + ), + InputDataMetadata( + data_type="entity/label", + content_type=["application/json"], + required=True, + parameter="clustersUrl", + ), + ], + data_output=[ + DataMetadata( + data_type="plot", content_type=["text/html"], required=True + ), + DataMetadata( + data_type="plot3d", content_type=["text/html"], required=False + ), + ], + ), + tags=ClusterSVM.instance.tags, + ) + + +@VIS_BLP.route("/ui/") +class MicroFrontend(MethodView): + """Micro frontend for the Cluster SVM Visualization plugin.""" + + @VIS_BLP.html_response( + HTTPStatus.OK, description="Micro frontend for the Cluster SVM Visualization plugin." + ) + @VIS_BLP.arguments( + InputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="query", + required=False, + ) + @VIS_BLP.require_jwt("jwt", optional=True) + def get(self, errors): + """Return the micro frontend.""" + return self.render(request.args, errors) + + @VIS_BLP.html_response( + HTTPStatus.OK, description="Micro frontend for the Cluster SVM Visualization plugin." 
+ ) + @VIS_BLP.arguments( + InputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="form", + required=False, + ) + @VIS_BLP.require_jwt("jwt", optional=True) + def post(self, errors): + """Return the micro frontend with prerendered inputs.""" + return self.render(request.form, errors) + + def render(self, data: Mapping, errors: dict): + schema = InputParametersSchema() + + data_dict = dict(data) + + # define default values + default_values = {} + + # overwrite default values with other values if possible + default_values.update(data_dict) + data_dict = default_values + + return Response( + render_template( + "simple_template.html", + name=ClusterSVM.instance.name, + version=ClusterSVM.instance.version, + schema=schema, + values=data_dict, + errors=errors, + process=url_for(f"{VIS_BLP.name}.ProcessView"), + ) + ) + + +@VIS_BLP.route("/process/") +class ProcessView(MethodView): + """Start a long running processing task.""" + + @VIS_BLP.arguments(InputParametersSchema(unknown=EXCLUDE), location="form") + @VIS_BLP.response(HTTPStatus.OK, TaskResponseSchema()) + @VIS_BLP.require_jwt("jwt", optional=True) + def post(self, arguments): + """Start the visualization task.""" + db_task = ProcessingTask( + task_name=visualization_task.name, + parameters=InputParametersSchema().dumps(arguments), + ) + db_task.save(commit=True) + + # all tasks need to know about db id to load the db entry + task: chain = visualization_task.s(db_id=db_task.id) | save_task_result.s( + db_id=db_task.id + ) + # save errors to db + task.link_error(save_task_error.s(db_id=db_task.id)) + task.apply_async() + + db_task.save(commit=True) + + return redirect( + url_for("tasks-api.TaskView", task_id=str(db_task.id)), HTTPStatus.SEE_OTHER + ) diff --git a/plugins/cluster_svm_visualization/schemas.py b/plugins/cluster_svm_visualization/schemas.py new file mode 100644 index 000000000..0d628c72e --- /dev/null +++ b/plugins/cluster_svm_visualization/schemas.py @@ 
-0,0 +1,83 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from marshmallow import post_load +import marshmallow as ma +from qhana_plugin_runner.api.util import FrontendFormBaseSchema, MaBaseSchema, FileUrl +from dataclasses import dataclass + + +class TaskResponseSchema(MaBaseSchema): + name = ma.fields.String(required=True, allow_none=False, dump_only=True) + task_id = ma.fields.String(required=True, allow_none=False, dump_only=True) + task_result_url = ma.fields.Url(required=True, allow_none=False, dump_only=True) + + +@dataclass(repr=False) +class InputParameters: + entity_url: str + clusters_url: str + do_svm: bool = False + do_3d: bool = False + + def __str__(self): + return str(self.__dict__) + + +class InputParametersSchema(FrontendFormBaseSchema): + entity_url = FileUrl( + required=True, + allow_none=False, + data_input_type="entity/vector", + data_content_types=["application/json"], + metadata={ + "label": "Entity Point URL", + "description": "URL to a json file containing the points.", + "input_type": "text", + }, + ) + clusters_url = FileUrl( + required=True, + allow_none=False, + data_input_type="entity/label", + data_content_types=["application/json"], + metadata={ + "label": "Cluster points URL", + "description": "URL to a json file containing the cluster points.", + "input_type": "text", + }, + ) + do_svm = ma.fields.Boolean( + required=False, + allow_none=False, + metadata={ + "label": "SVM", + 
"description": "Calculate and plot Support Vector Machine.", + "input_type": "checkbox", + }, + ) + do_3d = ma.fields.Boolean( + required=False, + allow_none=False, + metadata={ + "label": "3D", + "description": "Plot the Data additionally in 3D.", + "input_type": "checkbox", + }, + ) + + + @post_load + def make_input_params(self, data, **kwargs) -> InputParameters: + return InputParameters(**data) diff --git a/plugins/cluster_svm_visualization/tasks.py b/plugins/cluster_svm_visualization/tasks.py new file mode 100644 index 000000000..2783ce1c4 --- /dev/null +++ b/plugins/cluster_svm_visualization/tasks.py @@ -0,0 +1,158 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tempfile import SpooledTemporaryFile + +from typing import Optional + +from celery.utils.log import get_task_logger + +import muid +from sklearn import svm +from . 
import ClusterSVM +from .schemas import InputParameters, InputParametersSchema +from qhana_plugin_runner.celery import CELERY +from qhana_plugin_runner.db.models.tasks import ProcessingTask +from qhana_plugin_runner.requests import open_url, retrieve_filename + +from qhana_plugin_runner.storage import STORE + +TASK_LOGGER = get_task_logger(__name__) + + +def get_readable_hash(s: str) -> str: + return muid.pretty(muid.bhash(s.encode("utf-8")), k1=6, k2=5).replace(" ", "-") + + +@CELERY.task(name=f"{ClusterSVM.instance.identifier}.visualization_task", bind=True) +def visualization_task(self, db_id: int) -> str: + import pandas as pd + import plotly.express as px + import plotly.graph_objects as go + + TASK_LOGGER.info(f"Starting new demo task with db id '{db_id}'") + task_data: Optional[ProcessingTask] = ProcessingTask.get_by_id(id_=db_id) + + if task_data is None: + msg = f"Could not load task data with id {db_id} to read parameters!" + TASK_LOGGER.error(msg) + raise KeyError(msg) + + input_params: InputParameters = InputParametersSchema().loads(task_data.parameters) + + entity_url = input_params.entity_url + clusters_url = input_params.clusters_url + do_svm = input_params.do_svm + do_3d = input_params.do_3d + TASK_LOGGER.info(f"Loaded input parameters from db: '{str(input_params)}'") + + # load data from file + entity_points = open_url(entity_url).json() + clusters = open_url(clusters_url).json() + + print(entity_url) + + pt_x_list = [0 for _ in range(0, len(entity_points))] + pt_y_list = [0 for _ in range(0, len(entity_points))] + label_list = [0 for _ in range(0, len(entity_points))] + id_list = [x for x in range(0, len(entity_points))] + size_list = [10 for _ in range(0, len(entity_points))] + max_cluster = 0 + + for pt in entity_points: + idx = int(pt["ID"]) + pt_x_list[idx] = pt["dim0"] + pt_y_list[idx] = pt["dim1"] + + for cl in clusters: + label_list[int(cl["ID"])] = cl["label"] + max_cluster = max(max_cluster, cl["label"]) + + df = pd.DataFrame( + { + "ID": 
[f"Point {x}" for x in id_list], + "x": pt_x_list, + "y": pt_y_list, + "Cluster ID": [str(x) for x in label_list], + "size": size_list, + } + ) + + fig = px.scatter( + df, + x="x", + y="y", + size="size", + hover_name="ID", + color="Cluster ID", + hover_data={"size":False} + ) + + if do_svm: + cluster_list = [[] for _ in range(0, max_cluster+1)] + for idx, label in enumerate(label_list): + cluster_list[label].append([pt_x_list[idx], pt_y_list[idx]]) + + for i, cl1 in enumerate(cluster_list): + for j, cl2 in enumerate(cluster_list): + if i >= j: + continue + + cluster_label = [i for _ in range(0, len(cl1))] + cluster_label = cluster_label + [j for _ in range(0, len(cl2))] + cl = cl1 + cl2 + + clf = svm.SVC(kernel='linear') + clf.fit(cl, cluster_label) + + a = -clf.coef_[0][0] / clf.coef_[0][1] + b = clf.intercept_[0] / clf.coef_[0][1] + + x_range = [min(pt_x_list), max(pt_x_list)] + + fig.add_trace( + go.Scatter( + x=x_range, + y=[a * x - b for x in x_range], + mode='lines', + name=f'SVM for cluster {i} and {j}', + hoveron="fills", + ) + ) + + y_max = max(pt_y_list) + y_min = min(pt_y_list) + padding = (y_max - y_min) * 0.05 + fig.update_layout(yaxis=dict(range=[y_min - padding, y_max + padding])) + + # fig.update_layout(showlegend=False) + + filenames_hash = get_readable_hash(retrieve_filename(clusters_url)) + + info_str = f"_cluster-svm_{filenames_hash}" + + with SpooledTemporaryFile(mode="wt") as output: + html = fig.to_html(include_plotlyjs="cdn") + output.write(html) + + STORE.persist_task_result( + db_id, + output, + f"plot{info_str}.html", + "plot", + "text/html", + ) + + return "Result stored in file" diff --git a/stable_plugins/classical_ml/scikit_ml/classical_k_means/tasks.py b/stable_plugins/classical_ml/scikit_ml/classical_k_means/tasks.py index 5d1b43cc8..7003bace9 100644 --- a/stable_plugins/classical_ml/scikit_ml/classical_k_means/tasks.py +++ b/stable_plugins/classical_ml/scikit_ml/classical_k_means/tasks.py @@ -100,7 +100,7 @@ def 
calculation_task(self, db_id: int) -> str: if fig is not None: with SpooledTemporaryFile(mode="wt") as output: - html = fig.to_html() + html = fig.to_html(include_plotlyjs="cdn") output.write(html) STORE.persist_task_result( From 8eaf9527f4665166101c8cb064a157b98073f0ea Mon Sep 17 00:00:00 2001 From: B3at Date: Tue, 20 Aug 2024 20:49:41 +0200 Subject: [PATCH 5/7] Cluster SVM Lint Fix --- plugins/cluster_svm_visualization/__init__.py | 2 +- plugins/cluster_svm_visualization/routes.py | 6 ++++-- plugins/cluster_svm_visualization/schemas.py | 1 - plugins/cluster_svm_visualization/tasks.py | 16 ++++++++-------- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/plugins/cluster_svm_visualization/__init__.py b/plugins/cluster_svm_visualization/__init__.py index c4c79cdd5..9e01e8214 100644 --- a/plugins/cluster_svm_visualization/__init__.py +++ b/plugins/cluster_svm_visualization/__init__.py @@ -35,7 +35,7 @@ class ClusterSVM(QHAnaPluginBase): name = _plugin_name version = __version__ description = "Prints provided Data into a Scatter Plot with SVM." - tags = ["cluster","SVM", "visualization"] + tags = ["cluster", "SVM", "visualization"] def __init__(self, app: Optional[Flask]) -> None: super().__init__(app) diff --git a/plugins/cluster_svm_visualization/routes.py b/plugins/cluster_svm_visualization/routes.py index d646840cf..61e442582 100644 --- a/plugins/cluster_svm_visualization/routes.py +++ b/plugins/cluster_svm_visualization/routes.py @@ -92,7 +92,8 @@ class MicroFrontend(MethodView): """Micro frontend for the Cluster SVM Visualization plugin.""" @VIS_BLP.html_response( - HTTPStatus.OK, description="Micro frontend for the Cluster SVM Visualization plugin." 
+ HTTPStatus.OK, + description="Micro frontend for the Cluster SVM Visualization plugin.", ) @VIS_BLP.arguments( InputParametersSchema( @@ -107,7 +108,8 @@ def get(self, errors): return self.render(request.args, errors) @VIS_BLP.html_response( - HTTPStatus.OK, description="Micro frontend for the Cluster SVM Visualization plugin." + HTTPStatus.OK, + description="Micro frontend for the Cluster SVM Visualization plugin.", ) @VIS_BLP.arguments( InputParametersSchema( diff --git a/plugins/cluster_svm_visualization/schemas.py b/plugins/cluster_svm_visualization/schemas.py index 0d628c72e..8f170f23c 100644 --- a/plugins/cluster_svm_visualization/schemas.py +++ b/plugins/cluster_svm_visualization/schemas.py @@ -77,7 +77,6 @@ class InputParametersSchema(FrontendFormBaseSchema): }, ) - @post_load def make_input_params(self, data, **kwargs) -> InputParameters: return InputParameters(**data) diff --git a/plugins/cluster_svm_visualization/tasks.py b/plugins/cluster_svm_visualization/tasks.py index 2783ce1c4..99e8822dd 100644 --- a/plugins/cluster_svm_visualization/tasks.py +++ b/plugins/cluster_svm_visualization/tasks.py @@ -97,13 +97,13 @@ def visualization_task(self, db_id: int) -> str: size="size", hover_name="ID", color="Cluster ID", - hover_data={"size":False} - ) + hover_data={"size":False}, + ) if do_svm: - cluster_list = [[] for _ in range(0, max_cluster+1)] + cluster_list = [[] for _ in range(0, max_cluster + 1)] for idx, label in enumerate(label_list): - cluster_list[label].append([pt_x_list[idx], pt_y_list[idx]]) + cluster_list[label].append([pt_x_list[idx], pt_y_list[idx]]) for i, cl1 in enumerate(cluster_list): for j, cl2 in enumerate(cluster_list): @@ -114,9 +114,9 @@ def visualization_task(self, db_id: int) -> str: cluster_label = cluster_label + [j for _ in range(0, len(cl2))] cl = cl1 + cl2 - clf = svm.SVC(kernel='linear') + clf = svm.SVC(kernel="linear") clf.fit(cl, cluster_label) - + a = -clf.coef_[0][0] / clf.coef_[0][1] b = clf.intercept_[0] / 
clf.coef_[0][1] @@ -126,8 +126,8 @@ def visualization_task(self, db_id: int) -> str: go.Scatter( x=x_range, y=[a * x - b for x in x_range], - mode='lines', - name=f'SVM for cluster {i} and {j}', + mode="lines", + name=f"SVM for cluster {i} and {j}", hoveron="fills", ) ) From f513eac8ddd3f92133652e136896438c896503cf Mon Sep 17 00:00:00 2001 From: B3at Date: Tue, 20 Aug 2024 20:50:38 +0200 Subject: [PATCH 6/7] Cluster SVM Lint Fix 2 --- plugins/cluster_svm_visualization/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/cluster_svm_visualization/tasks.py b/plugins/cluster_svm_visualization/tasks.py index 99e8822dd..5190cf6c2 100644 --- a/plugins/cluster_svm_visualization/tasks.py +++ b/plugins/cluster_svm_visualization/tasks.py @@ -97,7 +97,7 @@ def visualization_task(self, db_id: int) -> str: size="size", hover_name="ID", color="Cluster ID", - hover_data={"size":False}, + hover_data={"size": False}, ) if do_svm: From ea966149616aaf8eff2b86900812c9169c1fe08a Mon Sep 17 00:00:00 2001 From: B3at Date: Wed, 9 Oct 2024 15:16:10 +0200 Subject: [PATCH 7/7] Data Creator 3D update and ZXCalculus added Also work began on Cluster Scatter Visualization Plugin --- .../cluster_scatter_visualization.py | 350 ++++++++++++++++++ .../cluster_scatter_visualization.html | 126 +++++++ plugins/cluster_svm_visualization/tasks.py | 34 +- plugins/zxcalculus/__init__.py | 57 +++ plugins/zxcalculus/routes.py | 160 ++++++++ plugins/zxcalculus/schemas.py | 68 ++++ plugins/zxcalculus/tasks.py | 79 ++++ .../data_synthesis/data_creator/__init__.py | 2 +- .../data_creator/backend/datasets.py | 46 +++ .../data_synthesis/data_creator/routes.py | 1 + .../data_synthesis/data_creator/schemas.py | 10 + .../templates/data_creator_template.html | 5 + 12 files changed, 928 insertions(+), 10 deletions(-) create mode 100644 plugins/cluster_scatter_visualization/cluster_scatter_visualization.py create mode 100644 
plugins/cluster_scatter_visualization/cluster_scatter_visualization_templates/cluster_scatter_visualization.html create mode 100644 plugins/zxcalculus/__init__.py create mode 100644 plugins/zxcalculus/routes.py create mode 100644 plugins/zxcalculus/schemas.py create mode 100644 plugins/zxcalculus/tasks.py diff --git a/plugins/cluster_scatter_visualization/cluster_scatter_visualization.py b/plugins/cluster_scatter_visualization/cluster_scatter_visualization.py new file mode 100644 index 000000000..75ce158fb --- /dev/null +++ b/plugins/cluster_scatter_visualization/cluster_scatter_visualization.py @@ -0,0 +1,350 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import hashlib +from http import HTTPStatus +from io import BytesIO +import pathlib +from tempfile import SpooledTemporaryFile +from typing import Mapping, Optional +from celery import chain +import celery +from celery.utils.log import get_task_logger +from flask import abort, redirect, send_file +from flask.app import Flask +from flask.globals import request +from flask.helpers import url_for +from flask.templating import render_template +from flask.views import MethodView +from flask.wrappers import Response +from marshmallow import EXCLUDE +from requests.exceptions import HTTPError + +from qhana_plugin_runner.api.plugin_schemas import ( + DataMetadata, + EntryPoint, + InputDataMetadata, + PluginMetadata, + PluginMetadataSchema, + PluginType, +) +from qhana_plugin_runner.api.util import ( + FileUrl, + FrontendFormBaseSchema, + SecurityBlueprint, +) +from qhana_plugin_runner.celery import CELERY +from qhana_plugin_runner.db.models.tasks import ProcessingTask +from qhana_plugin_runner.requests import open_url +from qhana_plugin_runner.storage import STORE +from qhana_plugin_runner.tasks import save_task_error, save_task_result +from qhana_plugin_runner.util.plugins import QHAnaPluginBase, plugin_identifier +from qhana_plugin_runner.db.models.virtual_plugins import DataBlob, PluginState + +_plugin_name = "cluster-scatter-visualization" +__version__ = "v0.0.1" +_identifier = plugin_identifier(_plugin_name, __version__) + + +VIS_BLP = SecurityBlueprint( + _identifier, # blueprint name + __name__, # module import name! 
+ description="A visualization plugin for cluster scatter data.", + template_folder="cluster_scatter_visualization_templates", +) + + +class CSInputParametersSchema(FrontendFormBaseSchema): + entity_url = FileUrl( + required=True, + allow_none=False, + data_input_type="entity/vector", + data_content_types=["application/json"], + metadata={ + "label": "Entity Point URL", + "description": "URL to a json file containing the points.", + "input_type": "text", + }, + ) + clusters_url = FileUrl( + required=True, + allow_none=True, + data_input_type="entity/vector", + data_content_types=["application/json"], + metadata={ + "label": "Cluster URL", + "description": "URL to a json file containing the cluster labels.", + "input_type": "text", + }, + ) + + +@VIS_BLP.route("/") +class PluginsView(MethodView): + """Plugins collection resource.""" + + @VIS_BLP.response(HTTPStatus.OK, PluginMetadataSchema()) + @VIS_BLP.require_jwt("jwt", optional=True) + def get(self): + """Endpoint returning the plugin metadata.""" + plugin = CSVisualization.instance + if plugin is None: + abort(HTTPStatus.INTERNAL_SERVER_ERROR) + return PluginMetadata( + title=plugin.name, + description=plugin.description, + name=plugin.name, + version=plugin.version, + type=PluginType.visualization, + entry_point=EntryPoint( + href=url_for(f"{VIS_BLP.name}.ProcessView"), + ui_href=url_for(f"{VIS_BLP.name}.MicroFrontend"), + plugin_dependencies=[], + data_input=[ + InputDataMetadata( + data_type="entity/vector", + content_type=["application/json"], + required=True, + parameter="entityUrl", + ), + InputDataMetadata( + data_type="entity/label", + content_type=["application/json"], + required=True, + parameter="clustersUrl", + ) + ], + data_output=[ + DataMetadata( + data_type="plot", + content_type=["image/svg+xml"], + required=True, + ) + ], + ), + tags=["visualization", "cluster", "scatter"], + ) + + +@VIS_BLP.route("/ui/") +class MicroFrontend(MethodView): + """Micro frontend for the cluster scatter 
visualization plugin.""" + + @VIS_BLP.html_response( + HTTPStatus.OK, description="Micro frontend of the cluster scatter visualization plugin." + ) + @VIS_BLP.arguments( + CSInputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="query", + required=False, + ) + @VIS_BLP.require_jwt("jwt", optional=True) + def get(self, errors): + """Return the micro frontend.""" + return self.render(request.args, errors, False) + + @VIS_BLP.html_response( + HTTPStatus.OK, description="Micro frontend of the cluster scatter visualization plugin." + ) + @VIS_BLP.arguments( + CSInputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="form", + required=False, + ) + @VIS_BLP.require_jwt("jwt", optional=True) + def post(self, errors): + """Return the micro frontend with prerendered inputs.""" + return self.render(request.form, errors, not errors) + + def render(self, data: Mapping, errors: dict, valid: bool): + plugin = CSVisualization.instance + if plugin is None: + abort(HTTPStatus.INTERNAL_SERVER_ERROR) + return Response( + render_template( + "cluster_scatter_visualization.html", + name=plugin.name, + version=plugin.version, + schema=CSInputParametersSchema(), + valid=valid, + values=data, + errors=errors, + example_values=url_for(f"{VIS_BLP.name}.MicroFrontend"), + get_circuit_image_url=url_for(f"{VIS_BLP.name}.get_circuit_image"), + process=url_for(f"{VIS_BLP.name}.ProcessView"), + ) + ) + + +class ImageNotFinishedError(Exception): + pass + + +@VIS_BLP.route("/circuits/") +@VIS_BLP.response(HTTPStatus.OK, description="Circuit image.") +@VIS_BLP.arguments( + CSInputParametersSchema(unknown=EXCLUDE), + location="query", + required=True, +) +@VIS_BLP.require_jwt("jwt", optional=True) +def get_circuit_image(data: Mapping): + url = data.get("data", None) + if not url: + abort(HTTPStatus.BAD_REQUEST) + url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest() + image = 
DataBlob.get_value(CSVisualization.instance.identifier, url_hash, None) + if image is None: + if not ( + task_id := PluginState.get_value( + CSVisualization.instance.identifier, url_hash, None + ) + ): + task_result = generate_image.s(url, url_hash).apply_async() + PluginState.set_value( + CSVisualization.instance.identifier, + url_hash, + task_result.id, + commit=True, + ) + else: + task_result = CELERY.AsyncResult(task_id) + try: + task_result.get(timeout=5) + image = DataBlob.get_value(CSVisualization.instance.identifier, url_hash) + except celery.exceptions.TimeoutError: + return Response("Image not yet created!", HTTPStatus.ACCEPTED) + if not image: + abort(HTTPStatus.BAD_REQUEST, "Invalid circuit URL!") + return send_file(BytesIO(image), mimetype="image/svg+xml") + + +@VIS_BLP.route("/process/") +class ProcessView(MethodView): + """Start a long running processing task.""" + + @VIS_BLP.arguments(CSInputParametersSchema(unknown=EXCLUDE), location="form") + @VIS_BLP.response(HTTPStatus.SEE_OTHER) + @VIS_BLP.require_jwt("jwt", optional=True) + def post(self, arguments): + circuit_url = arguments.get("data", None) + if circuit_url is None: + abort(HTTPStatus.BAD_REQUEST) + url_hash = hashlib.sha256(circuit_url.encode("utf-8")).hexdigest() + db_task = ProcessingTask(task_name=process.name) + db_task.save(commit=True) + + # all tasks need to know about db id to load the db entry + task: chain = process.s( + db_id=db_task.id, url=circuit_url, hash=url_hash + ) | save_task_result.s(db_id=db_task.id) + # save errors to db + task.link_error(save_task_error.s(db_id=db_task.id)) + task.apply_async(db_id=db_task.id, url=circuit_url, hash=url_hash) + + db_task.save(commit=True) + + return redirect( + url_for("tasks-api.TaskView", task_id=str(db_task.id)), HTTPStatus.SEE_OTHER + ) + + +class CSVisualization(QHAnaPluginBase): + name = _plugin_name + version = __version__ + description = "Visualizes cluster data in a scatter plot." 
+ tags = ["visualization"] + + def __init__(self, app: Optional[Flask]) -> None: + super().__init__(app) + + # create folder for circuit images + pathlib.Path(__file__).parent.absolute().joinpath("img").mkdir( + parents=True, exist_ok=True + ) + + def get_api_blueprint(self): + return VIS_BLP + + def get_requirements(self) -> str: + return "pylatexenc~=2.10\nqiskit~=0.43" + + +TASK_LOGGER = get_task_logger(__name__) + + +@CELERY.task(name=f"{CSVisualization.instance.identifier}.generate_image", bind=True) +def generate_image(self, url: str, hash: str) -> str: + from qiskit import QuantumCircuit + import matplotlib + + matplotlib.use("SVG") + + TASK_LOGGER.info(f"Generating image for circuit {url}...") + try: + with open_url(url) as qasm_response: + circuit_qasm = qasm_response.text + except HTTPError: + TASK_LOGGER.error(f"Invalid circuit URL: {url}") + DataBlob.set_value( + CSVisualization.instance.identifier, + hash, + "", + ) + PluginState.delete_value(CSVisualization.instance.identifier, hash, commit=True) + return "Invalid circuit URL!" + + circuit = QuantumCircuit.from_qasm_str(circuit_qasm) + fig = circuit.draw(output="mpl", interactive=False) + figfile = BytesIO() + fig.savefig(figfile, format="svg") + figfile.seek(0) + DataBlob.set_value(CSVisualization.instance.identifier, hash, figfile.getvalue()) + TASK_LOGGER.info(f"Stored image of circuit {circuit.name}.") + PluginState.delete_value(CSVisualization.instance.identifier, hash, commit=True) + + return "Created image of circuit!" 
+ + +@CELERY.task( + name=f"{CSVisualization.instance.identifier}.process", + bind=True, + autoretry_for=(ImageNotFinishedError,), + retry_backoff=True, + max_retries=None, +) +def process(self, db_id: str, url: str, hash: str) -> str: + if not (image := DataBlob.get_value(CSVisualization.instance.identifier, hash)): + if not ( + task_id := PluginState.get_value(CSVisualization.instance.identifier, hash) + ): + task_result = generate_image.s(url, hash).apply_async() + PluginState.set_value( + CSVisualization.instance.identifier, + hash, + task_result.id, + commit=True, + ) + raise ImageNotFinishedError() + with SpooledTemporaryFile() as output: + output.write(image) + output.seek(0) + STORE.persist_task_result( + db_id, output, f"circuit_{hash}.svg", "image/svg", "image/svg+xml" + ) + return "Created image of circuit!" diff --git a/plugins/cluster_scatter_visualization/cluster_scatter_visualization_templates/cluster_scatter_visualization.html b/plugins/cluster_scatter_visualization/cluster_scatter_visualization_templates/cluster_scatter_visualization.html new file mode 100644 index 000000000..5a754fe20 --- /dev/null +++ b/plugins/cluster_scatter_visualization/cluster_scatter_visualization_templates/cluster_scatter_visualization.html @@ -0,0 +1,126 @@ +{% extends "simple_template.html" %} + +{% block head %} +{{ super() }} + +{% endblock head %} + +{% block content %} +
+
+ Visualization Options + {% call forms.render_form(target="microfrontend") %} + {{ forms.render_fields(schema, values=values, errors=errors) }} +
+ {{ forms.submit("validate", target="microfrontend") }} + {{ forms.submit("submit", action=process) }} +
+ {% endcall %} +
+
+ + +
+
+
+{% endblock content %} + +{% block script %} +{{ super() }} + +{% endblock script %} \ No newline at end of file diff --git a/plugins/cluster_svm_visualization/tasks.py b/plugins/cluster_svm_visualization/tasks.py index 5190cf6c2..ebffbc1b4 100644 --- a/plugins/cluster_svm_visualization/tasks.py +++ b/plugins/cluster_svm_visualization/tasks.py @@ -66,6 +66,7 @@ def visualization_task(self, db_id: int) -> str: pt_x_list = [0 for _ in range(0, len(entity_points))] pt_y_list = [0 for _ in range(0, len(entity_points))] + pt_z_list = [0 for _ in range(0, len(entity_points))] label_list = [0 for _ in range(0, len(entity_points))] id_list = [x for x in range(0, len(entity_points))] size_list = [10 for _ in range(0, len(entity_points))] @@ -75,6 +76,8 @@ def visualization_task(self, db_id: int) -> str: idx = int(pt["ID"]) pt_x_list[idx] = pt["dim0"] pt_y_list[idx] = pt["dim1"] + if do_3d: + pt_z_list[idx] = pt["dim2"] for cl in clusters: label_list[int(cl["ID"])] = cl["label"] @@ -85,20 +88,33 @@ def visualization_task(self, db_id: int) -> str: "ID": [f"Point {x}" for x in id_list], "x": pt_x_list, "y": pt_y_list, + "z": pt_z_list, "Cluster ID": [str(x) for x in label_list], "size": size_list, } ) - fig = px.scatter( - df, - x="x", - y="y", - size="size", - hover_name="ID", - color="Cluster ID", - hover_data={"size": False}, - ) + if not do_3d: + fig = px.scatter( + df, + x="x", + y="y", + size="size", + hover_name="ID", + color="Cluster ID", + hover_data={"size": False}, + ) + else: + fig = px.scatter_3d( + df, + x="x", + y="y", + z="z", + size="size", + hover_name="ID", + color="Cluster ID", + hover_data={"size": False}, + ) if do_svm: cluster_list = [[] for _ in range(0, max_cluster + 1)] diff --git a/plugins/zxcalculus/__init__.py b/plugins/zxcalculus/__init__.py new file mode 100644 index 000000000..f083564b2 --- /dev/null +++ b/plugins/zxcalculus/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2022 QHAna plugin runner contributors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from flask import Flask + +from qhana_plugin_runner.api.util import SecurityBlueprint +from qhana_plugin_runner.util.plugins import plugin_identifier, QHAnaPluginBase + +_plugin_name = "zxcalculus" +__version__ = "v0.0.4" +_identifier = plugin_identifier(_plugin_name, __version__) + + +VIS_BLP = SecurityBlueprint( + _identifier, # blueprint name + __name__, # module import name! + description="ZXCalculus API.", +) + + +class ZXCalculus(QHAnaPluginBase): + name = _plugin_name + version = __version__ + description = "Generates a random circuit, visualizes and simplifies it." + tags = ["zxcalculus", "circuit"] + + def __init__(self, app: Optional[Flask]) -> None: + super().__init__(app) + + def get_api_blueprint(self): + return VIS_BLP + + def get_requirements(self) -> str: + return "plotly~=5.18.0\npyzx~=0.8.0\nmpld3~=0.5.10" + + +try: + # It is important to import the routes **after** VIS_BLP and ZXCalculus are defined, because they are + # accessed as soon as the routes are imported. + from . import routes +except ImportError: + # When running `poetry run flask install`, importing the routes will fail, because the dependencies are not + # installed yet. 
+ pass diff --git a/plugins/zxcalculus/routes.py b/plugins/zxcalculus/routes.py new file mode 100644 index 000000000..003fff128 --- /dev/null +++ b/plugins/zxcalculus/routes.py @@ -0,0 +1,160 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Mapping + +from celery.canvas import chain +from flask import Response, redirect +from flask.globals import request +from flask.helpers import url_for +from flask.templating import render_template +from flask.views import MethodView +from marshmallow import EXCLUDE + +from qhana_plugin_runner.api.plugin_schemas import ( + EntryPoint, + PluginMetadata, + PluginMetadataSchema, + PluginType, + InputDataMetadata, + DataMetadata, +) +from qhana_plugin_runner.db.models.tasks import ProcessingTask +from qhana_plugin_runner.tasks import ( + save_task_error, + save_task_result, +) + +from . 
import VIS_BLP, ZXCalculus +from .schemas import InputParametersSchema, TaskResponseSchema +from .tasks import visualization_task + + +@VIS_BLP.route("/") +class PluginsView(MethodView): + """Plugins collection resource.""" + + @VIS_BLP.response(HTTPStatus.OK, PluginMetadataSchema) + @VIS_BLP.require_jwt("jwt", optional=True) + def get(self): + """Endpoint returning the plugin metadata.""" + + return PluginMetadata( + title="ZXCalculus", + description=ZXCalculus.instance.description, + name=ZXCalculus.instance.name, + version=ZXCalculus.instance.version, + type=PluginType.visualization, + entry_point=EntryPoint( + href=url_for(f"{VIS_BLP.name}.ProcessView"), + ui_href=url_for(f"{VIS_BLP.name}.MicroFrontend"), + data_input=[], + data_output=[ + DataMetadata( + data_type="circuit", content_type=["text/html"], required=True + ) + ], + ), + tags=ZXCalculus.instance.tags, + ) + + +@VIS_BLP.route("/ui/") +class MicroFrontend(MethodView): + """Micro frontend for the ZXCalculus plugin.""" + + @VIS_BLP.html_response( + HTTPStatus.OK, description="Micro frontend for the ZXCalculus plugin." + ) + @VIS_BLP.arguments( + InputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="query", + required=False, + ) + @VIS_BLP.require_jwt("jwt", optional=True) + def get(self, errors): + """Return the micro frontend.""" + return self.render(request.args, errors) + + @VIS_BLP.html_response( + HTTPStatus.OK, description="Micro frontend for ZXCalculus plugin." 
+ ) + @VIS_BLP.arguments( + InputParametersSchema( + partial=True, unknown=EXCLUDE, validate_errors_as_result=True + ), + location="form", + required=False, + ) + @VIS_BLP.require_jwt("jwt", optional=True) + def post(self, errors): + """Return the micro frontend with prerendered inputs.""" + return self.render(request.form, errors) + + def render(self, data: Mapping, errors: dict): + schema = InputParametersSchema() + + data_dict = dict(data) + + # define default values + default_values = {} + + # overwrite default values with other values if possible + default_values.update(data_dict) + data_dict = default_values + + return Response( + render_template( + "simple_template.html", + name=ZXCalculus.instance.name, + version=ZXCalculus.instance.version, + schema=schema, + values=data_dict, + errors=errors, + process=url_for(f"{VIS_BLP.name}.ProcessView"), + ) + ) + + +@VIS_BLP.route("/process/") +class ProcessView(MethodView): + """Start a long running processing task.""" + + @VIS_BLP.arguments(InputParametersSchema(unknown=EXCLUDE), location="form") + @VIS_BLP.response(HTTPStatus.OK, TaskResponseSchema()) + @VIS_BLP.require_jwt("jwt", optional=True) + def post(self, arguments): + """Start the visualization task.""" + db_task = ProcessingTask( + task_name=visualization_task.name, + parameters=InputParametersSchema().dumps(arguments), + ) + db_task.save(commit=True) + + # all tasks need to know about db id to load the db entry + task: chain = visualization_task.s(db_id=db_task.id) | save_task_result.s( + db_id=db_task.id + ) + # save errors to db + task.link_error(save_task_error.s(db_id=db_task.id)) + task.apply_async() + + db_task.save(commit=True) + + return redirect( + url_for("tasks-api.TaskView", task_id=str(db_task.id)), HTTPStatus.SEE_OTHER + ) diff --git a/plugins/zxcalculus/schemas.py b/plugins/zxcalculus/schemas.py new file mode 100644 index 000000000..3870927e8 --- /dev/null +++ b/plugins/zxcalculus/schemas.py @@ -0,0 +1,68 @@ +# Copyright 2023 QHAna plugin 
runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from marshmallow import post_load +import marshmallow as ma +from qhana_plugin_runner.api.util import FrontendFormBaseSchema, MaBaseSchema, FileUrl +from dataclasses import dataclass + + +class TaskResponseSchema(MaBaseSchema): + name = ma.fields.String(required=True, allow_none=False, dump_only=True) + task_id = ma.fields.String(required=True, allow_none=False, dump_only=True) + task_result_url = ma.fields.Url(required=True, allow_none=False, dump_only=True) + + +@dataclass(repr=False) +class InputParameters: + qubits: int + depth: int + simplify: bool = False + + def __str__(self): + return str(self.__dict__) + + +class InputParametersSchema(FrontendFormBaseSchema): + qubits = ma.fields.Integer( + required=True, + allow_none=False, + metadata={ + "label": "No. 
Qubits", + "description": "Determines the number of qubits to generate.", + "input_type": "number", + }, + ) + depth = ma.fields.Integer( + required=True, + allow_none=False, + metadata={ + "label": "Depth", + "description": "Determines the depth of the circuits.", + "input_type": "number", + }, + ) + simplify = ma.fields.Boolean( + required=False, + allow_none=False, + metadata={ + "label": "Simplify", + "description": "Simplify the generated circuit.", + "input_type": "checkbox", + } + ) + + @post_load + def make_input_params(self, data, **kwargs) -> InputParameters: + return InputParameters(**data) diff --git a/plugins/zxcalculus/tasks.py b/plugins/zxcalculus/tasks.py new file mode 100644 index 000000000..4ade83ec6 --- /dev/null +++ b/plugins/zxcalculus/tasks.py @@ -0,0 +1,79 @@ +# Copyright 2023 QHAna plugin runner contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tempfile import SpooledTemporaryFile + +from typing import Optional + +from celery.utils.log import get_task_logger + +import muid +from . 
import ZXCalculus +from .schemas import InputParameters, InputParametersSchema +from qhana_plugin_runner.celery import CELERY +from qhana_plugin_runner.db.models.tasks import ProcessingTask +from qhana_plugin_runner.requests import open_url, retrieve_filename + +from qhana_plugin_runner.storage import STORE + +import pyzx as zx +import matplotlib.pyplot as _, mpld3 + +TASK_LOGGER = get_task_logger(__name__) + + +def get_readable_hash(s: str) -> str: + return muid.pretty(muid.bhash(s.encode("utf-8")), k1=6, k2=5).replace(" ", "-") + + +@CELERY.task(name=f"{ZXCalculus.instance.identifier}.visualization_task", bind=True) +def visualization_task(self, db_id: int) -> str: + + TASK_LOGGER.info(f"Starting new demo task with db id '{db_id}'") + task_data: Optional[ProcessingTask] = ProcessingTask.get_by_id(id_=db_id) + + if task_data is None: + msg = f"Could not load task data with id {db_id} to read parameters!" + TASK_LOGGER.error(msg) + raise KeyError(msg) + + input_params: InputParameters = InputParametersSchema().loads(task_data.parameters) + + qubits = input_params.qubits + depth = input_params.depth + simplify = input_params.simplify + + circuit = zx.generate.cliffordT(qubits, depth) + zx.settings.drawing_backend = 'd3' + fig = zx.draw(circuit) + html = mpld3.fig_to_html(fig) + if simplify: + zx.simplify.full_reduce(circuit) + circuit.normalize() + fig = zx.draw(circuit) + html = mpld3.fig_to_html(fig) + + with SpooledTemporaryFile(mode="wt") as output: + output.write(html) + + STORE.persist_task_result( + db_id, + output, + f"ZXCalculus_Circuit_{qubits}Qubits_{depth}Depth.html", + "circuit", + "text/html", + ) + + return "Result stored in file" diff --git a/stable_plugins/data_synthesis/data_creator/__init__.py b/stable_plugins/data_synthesis/data_creator/__init__.py index a75ad9f37..10f3181b3 100644 --- a/stable_plugins/data_synthesis/data_creator/__init__.py +++ b/stable_plugins/data_synthesis/data_creator/__init__.py @@ -22,7 +22,7 @@ _plugin_name = 
"data-creator" -__version__ = "v0.2.1" +__version__ = "v0.2.2" _identifier = plugin_identifier(_plugin_name, __version__) diff --git a/stable_plugins/data_synthesis/data_creator/backend/datasets.py b/stable_plugins/data_synthesis/data_creator/backend/datasets.py index bdd820844..4f7aa3784 100644 --- a/stable_plugins/data_synthesis/data_creator/backend/datasets.py +++ b/stable_plugins/data_synthesis/data_creator/backend/datasets.py @@ -14,6 +14,7 @@ from enum import Enum from typing import List, Tuple +from sklearn.datasets import make_blobs import numpy as np @@ -21,6 +22,9 @@ class DataTypeEnum(Enum): two_spirals = "Two Spirals" checkerboard = "Checkerboard" + blobs = "Blobs" + checkerboard_3d = "3D Checkerboard" + blobs_3d = "3D Blobs" def get_data( self, num_train_points: int, num_test_points: int, **kwargs @@ -30,6 +34,12 @@ def get_data( data, labels = twospirals(num_train_points + num_test_points, **kwargs) elif self == DataTypeEnum.checkerboard: data, labels = checkerboard(num_train_points + num_test_points, **kwargs) + elif self == DataTypeEnum.blobs: + data, labels = blobs(num_train_points + num_test_points, **kwargs) + elif self == DataTypeEnum.checkerboard_3d: + data, labels = checkerboard_3d(num_train_points + num_test_points, **kwargs) + elif self == DataTypeEnum.blobs_3d: + data, labels = blobs_3d(num_train_points + num_test_points, **kwargs) else: raise NotImplementedError indices = np.arange(len(data)) @@ -68,3 +78,39 @@ def checkerboard(n_points: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]: y += 0.2 if y > 0 else -0.2 rand_points[i] = x, y return rand_points, labels + +def blobs(n_points: int, centers: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]: + """Returns the Blobs dataset.""" + x, y = make_blobs( + n_samples=n_points, + centers=centers, + n_features=2, + ) + + return x, y + +def checkerboard_3d(n_points: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]: + """Returns the 3D checkerboard dataset.""" + rand_points = 
(np.random.rand(n_points, 3) * 2) - 1 + labels = np.zeros(n_points).astype(int) + + for i, (x, y, z) in enumerate(rand_points): + # label by quadrant + labels[i] = int(not ((x < 0 and y < 0) or (x >= 0 and y >= 0)) != (z < 0)) + # push away from both axes + x += 0.2 if x > 0 else -0.2 + y += 0.2 if y > 0 else -0.2 + z += 0.2 if z > 0 else -0.2 + rand_points[i] = x, y, z + return rand_points, labels + + +def blobs_3d(n_points: int, centers:int, **kwargs) -> Tuple[np.ndarray, np.ndarray]: + """Returns the 3D Blobs dataset.""" + x, y = make_blobs( + n_samples=n_points, + centers=centers, + n_features=3, + ) + + return x, y \ No newline at end of file diff --git a/stable_plugins/data_synthesis/data_creator/routes.py b/stable_plugins/data_synthesis/data_creator/routes.py index 5c0cca62d..9a2fa95f4 100644 --- a/stable_plugins/data_synthesis/data_creator/routes.py +++ b/stable_plugins/data_synthesis/data_creator/routes.py @@ -139,6 +139,7 @@ def render(self, data: Mapping, errors: dict): default_values = { fields["noise"].data_key: 0.7, fields["turns"].data_key: 1.52, + fields["centers"].data_key: 4, } # overwrite default values with other values if possible diff --git a/stable_plugins/data_synthesis/data_creator/schemas.py b/stable_plugins/data_synthesis/data_creator/schemas.py index 95c2e6b56..d117086e4 100644 --- a/stable_plugins/data_synthesis/data_creator/schemas.py +++ b/stable_plugins/data_synthesis/data_creator/schemas.py @@ -39,6 +39,7 @@ class InputParameters: num_test_points: int turns: float = None noise: float = None + centers: int = None def __str__(self): return str(self.__dict__.copy()) @@ -93,6 +94,15 @@ class InputParametersSchema(FrontendFormBaseSchema): "input_type": "text", }, ) + centers = ma.fields.Integer( + required=False, + allow_none=False, + metadata={ + "label": "No. 
Centers", + "description": "Determines the number of Blobs", + "input_type": "number", + }, + ) @post_load def make_input_params(self, data, **kwargs) -> InputParameters: diff --git a/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html b/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html index d6f74fad7..f1b9767ed 100644 --- a/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html +++ b/stable_plugins/data_synthesis/data_creator/templates/data_creator_template.html @@ -4,16 +4,21 @@