Skip to content

Commit

Permalink
fixed an issue regarding the seed of the base model
Browse files Browse the repository at this point in the history
  • Loading branch information
lyna1404 committed Aug 14, 2024
1 parent 3e8c0ca commit fdcb264
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 48 deletions.
8 changes: 5 additions & 3 deletions MED3pa/detectron/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
The ensemble leverages a base model, provided by ``BaseModelManager``, to generate models that are designed to systematically disagree with it in a controlled fashion.
"""
import numpy as np
import copy

from tqdm import tqdm

Expand Down Expand Up @@ -128,10 +129,11 @@ def evaluate_ensemble(self,
model_id = i

# update the training params with the current seed which is the model id
if training_params is not None :
training_params.update({'seed':i})
cdc_training_params = copy.deepcopy(training_params)
if cdc_training_params is not None :
cdc_training_params.update({'seed': i})
else:
training_params = {'seed':i}
cdc_training_params={'seed': i}

# train this cdc to disagree
cdc.train_to_disagree(x_train=training_data.get_observations(), y_train=training_data.get_true_labels(),
Expand Down
24 changes: 2 additions & 22 deletions MED3pa/detectron/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class MannWhitneyStrategy(DetectronStrategy):
Implements a strategy to detect disagreement based on the Mann-Whitney U test, assessing the dissimilarity of results
from calibration runs and test runs.
"""
def execute(calibration_records: DetectronRecordsManager, test_records:DetectronRecordsManager, trim_data=True, proportion_to_cut=0.05):
def execute(calibration_records: DetectronRecordsManager, test_records:DetectronRecordsManager, trim_data=True):
"""
Executes the disagreement detection strategy using the Mann-Whitney U test.
Expand All @@ -101,16 +101,6 @@ def execute(calibration_records: DetectronRecordsManager, test_records:Detectron
if len(cal_counts) < 2 or len(test_counts) == 0:
raise ValueError("Not enough records to perform the statistical test.")

def trim_dataset(data, proportion_to_cut):
if not 0 <= proportion_to_cut < 0.5:
raise ValueError("proportion_to_cut must be between 0 and 0.5")

data_sorted = np.sort(data)
n = len(data)
trim_count = int(n * proportion_to_cut)

return data_sorted[trim_count:n - trim_count]

def remove_outliers_based_on_iqr(arr1, arr2):
# Calculate Q1 (25th percentile) and Q3 (75th percentile)
Q1 = np.percentile(arr1, 25)
Expand Down Expand Up @@ -248,7 +238,7 @@ class EnhancedDisagreementStrategy(DetectronStrategy):
Implements a strategy to detect disagreement based on the z-score mean difference between calibration and test datasets.
This strategy calculates the probability of a shift based on the counts where test rejected counts are compared to calibration rejected counts.
"""
def execute(calibration_records: DetectronRecordsManager, test_records: DetectronRecordsManager, trim_data=True, proportion_to_cut=0.05):
def execute(calibration_records: DetectronRecordsManager, test_records: DetectronRecordsManager, trim_data=True):
"""
Executes the disagreement detection strategy using z-score analysis.
Expand All @@ -269,16 +259,6 @@ def execute(calibration_records: DetectronRecordsManager, test_records: Detectro
if len(cal_counts) < 2 or len(test_counts) == 0:
raise ValueError("Not enough records to perform the statistical test.")

def trim_dataset(data, proportion_to_cut):
if not 0 <= proportion_to_cut < 0.5:
raise ValueError("proportion_to_cut must be between 0 and 0.5")

data_sorted = np.sort(data)
n = len(data)
trim_count = int(n * proportion_to_cut)

return data_sorted[trim_count:n - trim_count]

def remove_outliers_based_on_iqr(arr1, arr2):
# Calculate Q1 (25th percentile) and Q3 (75th percentile)
Q1 = np.percentile(arr1, 25)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

setup(
name="MED3pa",
version="0.1.23",
version="0.1.24",
author="MEDomics consortium",
author_email="medomics.info@gmail.com",
description="Python Open-source package for ensuring robust and reliable ML models deployments",
Expand Down
44 changes: 22 additions & 22 deletions tutorials/detectron_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:13<00:00, 7.42it/s]\n"
"running seeds: 100%|██████████| 100/100 [00:07<00:00, 13.12it/s]\n"
]
},
{
Expand All @@ -108,7 +108,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:12<00:00, 7.74it/s]\n"
"running seeds: 100%|██████████| 100/100 [00:07<00:00, 13.27it/s]\n"
]
},
{
Expand All @@ -122,7 +122,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:13<00:00, 7.62it/s]\n"
"running seeds: 100%|██████████| 100/100 [00:06<00:00, 14.34it/s]\n"
]
},
{
Expand All @@ -136,7 +136,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:14<00:00, 6.82it/s]"
"running seeds: 100%|██████████| 100/100 [00:07<00:00, 12.91it/s]"
]
},
{
Expand Down Expand Up @@ -178,26 +178,26 @@
{
"data": {
"text/plain": [
"[{'shift_probability': 0.9444444444444444,\n",
" 'test_statistic': 9.722222222222221,\n",
" 'baseline_mean': 7.4,\n",
" 'baseline_std': 1.2631530214330944,\n",
" 'significance_description': {'unsignificant shift': 15.555555555555555,\n",
" 'small': 8.88888888888889,\n",
" 'moderate': 13.333333333333334,\n",
" 'large': 62.22222222222222},\n",
"[{'shift_probability': 0.7474747474747475,\n",
" 'test_statistic': 5.111111111111111,\n",
" 'baseline_mean': 3.808080808080808,\n",
" 'baseline_std': 1.315615366073934,\n",
" 'significance_description': {'unsignificant shift': 25.252525252525253,\n",
" 'small': 30.303030303030305,\n",
" 'moderate': 18.181818181818183,\n",
" 'large': 26.262626262626267},\n",
" 'Strategy': 'enhanced_disagreement_strategy'},\n",
" {'p_value': 4.6393173988598416e-17,\n",
" 'u_statistic': 1182.5,\n",
" 'significance_description': {'unsignificant shift': 15.555555555555555,\n",
" 'small': 8.88888888888889,\n",
" 'moderate': 13.333333333333334,\n",
" 'large': 62.22222222222222},\n",
" {'p_value': 9.063705352705725e-07,\n",
" 'u_statistic': 3006.0,\n",
" 'significance_description': {'unsignificant shift': 25.252525252525253,\n",
" 'small': 30.303030303030305,\n",
" 'moderate': 18.181818181818183,\n",
" 'large': 26.262626262626267},\n",
" 'Strategy': 'mannwhitney_strategy'},\n",
" {'p_value': 0.05,\n",
" 'test_statistic': 8,\n",
" 'baseline_mean': 12.47,\n",
" 'baseline_std': 1.920702996301094,\n",
" {'p_value': 0.65,\n",
" 'test_statistic': 16,\n",
" 'baseline_mean': 16.14,\n",
" 'baseline_std': 1.4072668545801823,\n",
" 'Strategy': 'original_disagreement_strategy'}]"
]
},
Expand Down

0 comments on commit fdcb264

Please sign in to comment.