diff --git a/MED3pa/med3pa/comparaison.py b/MED3pa/med3pa/comparaison.py index 0f318c7..a5d5748 100644 --- a/MED3pa/med3pa/comparaison.py +++ b/MED3pa/med3pa/comparaison.py @@ -27,6 +27,7 @@ def __init__(self, results1_path: str, results2_path: str) -> None: self.config_file = {} self.compare_profiles = False self.compare_detectron = False + self.mode = "" self._check_experiment_name() def _check_experiment_name(self) -> None: @@ -79,6 +80,12 @@ def is_comparable(self) -> bool: elif base_model_different and not datasets_different and not params_different: can_compare = True + if can_compare: + if self.compare_detectron: + self.mode = self.config_file['med3pa_detectron_params']['med3pa_detectron_params1']['med3pa_params']['mode'] + else: + self.mode = self.config_file['med3pa_params']['med3pa_params1']['mode'] + return can_compare def _check_experiment_tree(self) -> None: @@ -281,11 +288,13 @@ def compare_experiments(self): raise ValueError("The two experiments cannot be compared based on the provided criteria.") self.compare_global_metrics() - self._check_experiment_tree() - if self.compare_profiles: - self.compare_profiles_metrics() - if self.compare_detectron: - self.compare_profiles_detectron_results() + + if self.mode in ['apc', 'mpc']: + self._check_experiment_tree() + if self.compare_profiles: + self.compare_profiles_metrics() + if self.compare_detectron: + self.compare_profiles_detectron_results() self.compare_config() self.compare_models_evaluation() diff --git a/tutorials/med3pa_tutorials.ipynb b/tutorials/med3pa_tutorials.ipynb index ca0b447..a179bdf 100644 --- a/tutorials/med3pa_tutorials.ipynb +++ b/tutorials/med3pa_tutorials.ipynb @@ -27,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -68,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -95,9 +95,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running MED3pa Experiment on the reference set:\n", + "IPC Model training complete.\n", + "IPC Model optimization complete.\n", + "Individualized confidence scores calculated.\n", + "Running MED3pa Experiment on the test set:\n", + "IPC Model training complete.\n", + "IPC Model optimization complete.\n", + "Individualized confidence scores calculated.\n", + "Running MED3pa Experiment on the reference set:\n", + "IPC Model training complete.\n", + "IPC Model optimization complete.\n", + "Individualized confidence scores calculated.\n", + "Running MED3pa Experiment on the test set:\n", + "IPC Model training complete.\n", + "IPC Model optimization complete.\n", + "Individualized confidence scores calculated.\n" + ] + } + ], "source": [ "from MED3pa.med3pa import Med3paExperiment\n", "\n", @@ -137,6 +160,7 @@ " samples_ratio_step=5,\n", " med3pa_metrics=med3pa_metrics,\n", " evaluate_models=True,\n", + " mode='ipc',\n", " )\n", "\n", "BaseModelManager.reset()\n", @@ -162,6 +186,7 @@ " samples_ratio_step=5,\n", " med3pa_metrics=med3pa_metrics,\n", " evaluate_models=True,\n", + " mode='ipc',\n", " )" ] }, @@ -175,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -193,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -278,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -289,7 +314,7 @@ "\n", "from MED3pa.med3pa.comparaison import Med3paComparison\n", "\n", - "comparaison = Med3paComparison('./med3pa_experiment_results_pretrained', './med3pa_experiment_results_2_pretrained')\n", + "comparaison = Med3paComparison('./med3pa_experiment_results', './med3pa_experiment_results_2')\n", "comparaison.compare_experiments()\n", "comparaison.save('./med3pa_comparaison_results')" ]