Skip to content

Commit

Permalink
fixed med3pa comparison case of ipc model
Browse files Browse the repository at this point in the history
  • Loading branch information
lyna1404 committed Aug 14, 2024
1 parent 99f9de6 commit bf1da2f
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 13 deletions.
19 changes: 14 additions & 5 deletions MED3pa/med3pa/comparaison.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, results1_path: str, results2_path: str) -> None:
self.config_file = {}
self.compare_profiles = False
self.compare_detectron = False
self.mode = ""
self._check_experiment_name()

def _check_experiment_name(self) -> None:
Expand Down Expand Up @@ -79,6 +80,12 @@ def is_comparable(self) -> bool:
elif base_model_different and not datasets_different and not params_different:
can_compare = True

if can_compare:
if self.compare_detectron:
self.mode = self.config_file['med3pa_detectron_params']['med3pa_detectron_params1']['med3pa_params']['mode']
else:
self.mode = self.config_file['med3pa_params']['med3pa_params1']['mode']

return can_compare

def _check_experiment_tree(self) -> None:
Expand Down Expand Up @@ -281,11 +288,13 @@ def compare_experiments(self):
raise ValueError("The two experiments cannot be compared based on the provided criteria.")

self.compare_global_metrics()
self._check_experiment_tree()
if self.compare_profiles:
self.compare_profiles_metrics()
if self.compare_detectron:
self.compare_profiles_detectron_results()

if self.mode in ['apc', 'mpc']:
self._check_experiment_tree()
if self.compare_profiles:
self.compare_profiles_metrics()
if self.compare_detectron:
self.compare_profiles_detectron_results()

self.compare_config()
self.compare_models_evaluation()
Expand Down
41 changes: 33 additions & 8 deletions tutorials/med3pa_tutorials.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -68,7 +68,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -95,9 +95,32 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running MED3pa Experiment on the reference set:\n",
"IPC Model training complete.\n",
"IPC Model optimization complete.\n",
"Individualized confidence scores calculated.\n",
"Running MED3pa Experiment on the test set:\n",
"IPC Model training complete.\n",
"IPC Model optimization complete.\n",
"Individualized confidence scores calculated.\n",
"Running MED3pa Experiment on the reference set:\n",
"IPC Model training complete.\n",
"IPC Model optimization complete.\n",
"Individualized confidence scores calculated.\n",
"Running MED3pa Experiment on the test set:\n",
"IPC Model training complete.\n",
"IPC Model optimization complete.\n",
"Individualized confidence scores calculated.\n"
]
}
],
"source": [
"from MED3pa.med3pa import Med3paExperiment\n",
"\n",
Expand Down Expand Up @@ -137,6 +160,7 @@
" samples_ratio_step=5,\n",
" med3pa_metrics=med3pa_metrics,\n",
" evaluate_models=True,\n",
" mode='ipc',\n",
" )\n",
"\n",
"BaseModelManager.reset()\n",
Expand All @@ -162,6 +186,7 @@
" samples_ratio_step=5,\n",
" med3pa_metrics=med3pa_metrics,\n",
" evaluate_models=True,\n",
" mode='ipc',\n",
" )"
]
},
Expand All @@ -175,7 +200,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -193,7 +218,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -278,7 +303,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -289,7 +314,7 @@
"\n",
"from MED3pa.med3pa.comparaison import Med3paComparison\n",
"\n",
"comparaison = Med3paComparison('./med3pa_experiment_results_pretrained', './med3pa_experiment_results_2_pretrained')\n",
"comparaison = Med3paComparison('./med3pa_experiment_results', './med3pa_experiment_results_2')\n",
"comparaison.compare_experiments()\n",
"comparaison.save('./med3pa_comparaison_results')"
]
Expand Down

0 comments on commit bf1da2f

Please sign in to comment.