From 5f2a5a59ff33b612e3f6649527e5b98d5bb161ce Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 11 Sep 2023 10:02:07 +0000 Subject: [PATCH] New database widget refactoring and for new resource registry --- aiidalab_widgets_base/databases.py | 193 ++++++++++++++++++++++++ aiidalab_widgets_base/utils/__init__.py | 1 - notebooks/computational_resources.ipynb | 19 +++ tests/test_databases.py | 35 ++++- 4 files changed, 246 insertions(+), 2 deletions(-) diff --git a/aiidalab_widgets_base/databases.py b/aiidalab_widgets_base/databases.py index 2302e737f..13956861f 100644 --- a/aiidalab_widgets_base/databases.py +++ b/aiidalab_widgets_base/databases.py @@ -380,3 +380,196 @@ def _code_changed(self, _=None): @tl.default("default_calc_job_plugin") def _default_calc_job_plugin(self): return None + + +class NewComputationalResourcesDatabaseWidget(ipw.VBox): + """Extract the setup of a known computer from the AiiDA code registry.""" + + _default_database_source = ( + "https://unkcpz.github.io/aiida-resource-registry/database.json" + ) + + database_source = tl.Unicode(allow_none=True) + + ssh_config = tl.Dict() + computer_setup = tl.Dict() + code_setup = tl.Dict() + + def __init__(self, default_calc_job_plugin=None, database_source=None, **kwargs): + if database_source is None: + database_source = self._default_database_source + + self.default_calc_job_plugin = default_calc_job_plugin + + # Select domain. + self.domain_selector = ipw.Dropdown( + options=[], + description="Domain", + disabled=False, + ) + self.domain_selector.observe(self._domain_changed, names=["value", "options"]) + + # Select computer. + self.computer_selector = ipw.Dropdown( + options=[], + description="Computer:", + disabled=False, + ) + self.computer_selector.observe( + self._computer_changed, names=["value", "options"] + ) + + # Select code. + self.code_selector = ipw.Dropdown( + options=[], + description="Code:", + disabled=False, + ) + self.code_selector.observe(self._code_changed, names=["value", "options"]) + + reset_button = ipw.Button(description="Reset") + reset_button.on_click(self.reset) + + super().__init__( + children=[ + self.domain_selector, + self.computer_selector, + self.code_selector, + reset_button, + ], + **kwargs, + ) + self.database_source = database_source + self.reset() + + def reset(self, _=None): + """Reset widget and traits""" + with self.hold_trait_notifications(): + self.domain_selector.value = None + self.computer_selector.value = None + self.code_selector.value = None + + @tl.observe("database_source") + def _database_source_changed(self, _=None): + self.database = self._database_generator( + self.database_source, self.default_calc_job_plugin + ) + + # Update domain selector. + self.domain_selector.options = self.database.keys() + self.reset() + + def _database_generator(self, database_source, default_calc_job_plugin): + """From database source JSON and default calc job plugin, generate resource database""" + try: + database = requests.get(database_source).json() + except Exception: + database = {} + + if default_calc_job_plugin is None: + return database + + # filter database by default calc job plugin + for domain, domain_value in database.items(): + for computer, computer_value in domain_value.items(): + if computer == "default": + # skip default computer + continue + + for code, code_value in list(computer_value["codes"].items()): + if code_value["default_calc_job_plugin"] != default_calc_job_plugin: + # remove code + del computer_value["codes"][code] + + if len(computer_value["codes"]) == 0: + # remove computer since no codes defined in this computer source + del domain_value[computer] + + if len(domain_value) == 0: + # remove domain since no computers with required codes defined in this domain source + del database[domain] + continue + + if domain_value["default"] not in domain_value: + # make sure default computer is still points to existing computer + domain_value["default"] = sorted(domain_value.keys() - {"default"})[0] + + return database + + def _domain_changed(self, change=None): + """callback when new domain selected""" + if change["new"] is None: + return + else: + selected_domain = self.domain_selector.value + + with self.hold_trait_notifications(): + try: + self.computer_selector.options = [ + key + for key in self.database[selected_domain].keys() + if key != "default" + ] + self.computer_selector.value = self.database[selected_domain]["default"] + except KeyError: + raise + + def _computer_changed(self, change=None): + """callback when new computer selected""" + if change["new"] is None: + self.computer_setup = {} + self.ssh_config = {} + return + else: + selected_computer = self.computer_selector.value + + selected_domain = self.domain_selector.value + + with self.hold_trait_notifications(): + computer_dict = self.database.get(selected_domain, {}).get( + selected_computer, {} + ) + + try: + self.code_selector.options = list(computer_dict.get("codes", {}).keys()) + except KeyError: + raise + + computer_setup = computer_dict.get("computer", {}).get("computer-setup", {}) + computer_configure = computer_dict.get("computer", {}).get( + "computer-configure", {} + ) + + ssh_config = {"hostname": computer_setup["hostname"]} + if "proxy_command" in computer_configure: + ssh_config["proxy_command"] = computer_configure["proxy_command"] + if "proxy_jump" in computer_configure: + ssh_config["proxy_jump"] = computer_configure["proxy_jump"] + + self.ssh_config = ssh_config # To notify the trait change + + self.computer_setup = { + "setup": computer_setup, + "configure": computer_configure, + } + + def _code_changed(self, change=None): + """Update code settings.""" + if change["new"] is None: + self.code_setup = {} + return + else: + selected_code = self.code_selector.value + + selected_domain = self.domain_selector.value + selected_computer = self.computer_selector.value + + try: + self.code_setup = ( + self.database.get(selected_domain, {}) + .get(selected_computer, {}) + .get("codes", {}) + .get(selected_code) + ) + except KeyError: + raise diff --git a/aiidalab_widgets_base/utils/__init__.py b/aiidalab_widgets_base/utils/__init__.py index 042059b7b..69d518986 100644 --- a/aiidalab_widgets_base/utils/__init__.py +++ b/aiidalab_widgets_base/utils/__init__.py @@ -169,7 +169,6 @@ class StatusHTML(_StatusWidgetMixin, ipw.HTML): # for an unknown reason. @traitlets.observe("message") def _observe_message(self, change): - print("!!!", change["new"]) self.show_temporary_message(change["new"]) diff --git a/notebooks/computational_resources.ipynb b/notebooks/computational_resources.ipynb index 418827408..2f6c0eeba 100644 --- a/notebooks/computational_resources.ipynb +++ b/notebooks/computational_resources.ipynb @@ -56,6 +56,25 @@ "display(resources1)\n", "display(resources2)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a302c5f", + "metadata": {}, + "outputs": [], + "source": [ + "db = awb.databases.NewComputationalResourcesDatabaseWidget()\n", + "display(db)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdcce7b2", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/tests/test_databases.py b/tests/test_databases.py index b0f80fe13..1dd044029 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -32,7 +32,7 @@ def test_optimade_query_widget(): assert widget.structure is None -def test_computational_resources_database_widget(): +def test_legacy_computational_resources_database_widget(): """Test the structure browser widget.""" from aiidalab_widgets_base.databases import ComputationalResourcesDatabaseWidget @@ -62,3 +62,36 @@ def test_computational_resources_database_widget(): assert widget.computer_setup == {} assert widget.code_setup == {} assert widget.ssh_config == {} + + +def test_computational_resources_database_widget(): + """Test the structure browser widget.""" + from aiidalab_widgets_base.databases import NewComputationalResourcesDatabaseWidget + + # Initiate the widget with no arguments. + widget = NewComputationalResourcesDatabaseWidget() + assert "daint.cscs.ch" in widget.database + + # Initialize the widget with default_calc_job_plugin="cp2k" + widget = NewComputationalResourcesDatabaseWidget(default_calc_job_plugin="cp2k") + assert ( + "merlin.psi.ch" not in widget.database + ) # Merlin does not have CP2K installed. + + # Select computer/code + widget.domain_selector.value = "daint.cscs.ch" + widget.computer_selector.value = "mc" + widget.code_selector.value = "cp2k-9.1" + + # Check that the configuration is provided. + + assert "label" in widget.computer_setup["setup"] + assert "hostname" in widget.ssh_config + assert "filepath_executable" in widget.code_setup + + # Simulate reset. + widget.reset() + + assert widget.computer_setup == {} + assert widget.code_setup == {} + assert widget.ssh_config == {}