
Commit

generalized strategies
lyna1404 committed Aug 12, 2024
1 parent 2566d20 commit 19cdadd
Showing 3 changed files with 135 additions and 24 deletions.
42 changes: 29 additions & 13 deletions MED3pa/detectron/strategies.py
@@ -81,7 +81,7 @@ class MannWhitneyStrategy(DetectronStrategy):
Implements a strategy to detect disagreement based on the Mann-Whitney U test, assessing the dissimilarity of results
from calibration runs and test runs.
"""
def execute(calibration_records: DetectronRecordsManager, test_records:DetectronRecordsManager):
def execute(calibration_records: DetectronRecordsManager, test_records:DetectronRecordsManager, trim_data=True, proportion_to_cut=0.05):
"""
Executes the disagreement detection strategy using the Mann-Whitney U test.
Expand All @@ -97,23 +97,41 @@ def execute(calibration_records: DetectronRecordsManager, test_records:Detectron
cal_counts = calibration_records.rejected_counts()
test_counts = test_records.rejected_counts()

cal_mean = np.mean(cal_counts)
cal_std = np.std(cal_counts)
test_mean = np.mean(test_counts)
# Ensure there are enough records to perform bootstrap
if len(cal_counts) < 2 or len(test_counts) == 0:
raise ValueError("Not enough records to perform the statistical test.")

def trim_dataset(data, proportion_to_cut):
if not 0 <= proportion_to_cut < 0.5:
raise ValueError("proportion_to_cut must be between 0 and 0.5")

data_sorted = np.sort(data)
n = len(data)
trim_count = int(n * proportion_to_cut)

return data_sorted[trim_count:n - trim_count]

if trim_data:
# Trim calibration and test data if trimming is enabled
cal_counts = trim_dataset(cal_counts, proportion_to_cut)
test_counts = trim_dataset(test_counts, proportion_to_cut)

baseline_mean = np.mean(cal_counts)
baseline_std = np.std(cal_counts)

# Perform the Mann-Whitney U test
u_statistic, p_value = stats.mannwhitneyu(cal_counts, test_counts, alternative='less')

# Calculate the z-scores for the test data
z_scores = (test_counts[:, None] - cal_counts) / np.std(cal_counts)
z_scores = (test_counts - baseline_mean) / baseline_std

# Define thresholds for categorizing
def categorize_z_score(z):
if z <= 0:
return 'no significant shift'
elif abs(z) < 1:
elif 0 < z <= 1:
return 'small'
elif abs(z) < 2:
elif 1 < z <= 2:
return 'moderate'
else:
return 'large'
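
The new Mann-Whitney path above can be exercised on plain arrays, outside the Detectron pipeline. The following is a minimal standalone sketch, assuming only NumPy and SciPy; the synthetic cal_counts and test_counts arrays are hypothetical stand-ins for the per-seed rejected counts that DetectronRecordsManager.rejected_counts() would return.

    import numpy as np
    from scipy import stats

    def trim_dataset(data, proportion_to_cut):
        # Symmetric trimming: drop the lowest and highest fraction of values.
        if not 0 <= proportion_to_cut < 0.5:
            raise ValueError("proportion_to_cut must be between 0 and 0.5")
        data_sorted = np.sort(data)
        trim_count = int(len(data) * proportion_to_cut)
        return data_sorted[trim_count:len(data) - trim_count]

    # Hypothetical per-seed rejected counts (100 calibration runs, 100 test runs).
    rng = np.random.default_rng(0)
    cal_counts = rng.poisson(7, size=100)
    test_counts = rng.poisson(10, size=100)

    # Optional trimming, mirroring trim_data=True, proportion_to_cut=0.05.
    cal_trimmed = trim_dataset(cal_counts, 0.05)
    test_trimmed = trim_dataset(test_counts, 0.05)

    # One-sided test: a small p-value says the calibration counts tend to be
    # smaller than the test counts, i.e. evidence of disagreement on the test set.
    u_statistic, p_value = stats.mannwhitneyu(cal_trimmed, test_trimmed, alternative='less')
    print(f"U = {u_statistic:.1f}, p = {p_value:.3g}")

With the deliberately shifted synthetic counts above, the p-value comes out essentially zero; with counts drawn from the same distribution it generally would not.
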
@@ -238,29 +256,27 @@ def trim_dataset(data, proportion_to_cut):

# Calculate the baseline mean and standard deviation on trimmed or full data
baseline_mean = np.mean(cal_counts)
test_mean = np.mean(test_counts)
baseline_std = np.std(cal_counts)
test_std = np.std(test_counts)

# Calculate the test statistic (mean of test data)
test_statistic = np.mean(test_counts)

# Calculate the z-scores for the test data
z_scores = (test_counts[:, None] - cal_counts) / np.std(cal_counts)
z_scores = (test_counts - baseline_mean) / baseline_std

# Define thresholds for categorizing
def categorize_z_score(z):
if z <= 0:
return 'no significant shift'
elif abs(z) < 1:
elif 0 < z <= 1:
return 'small'
elif abs(z) < 2:
elif 1 < z <= 2:
return 'moderate'
else:
return 'large'

# Categorize each test count based on its z-score
categories = np.array([categorize_z_score(z) for z in z_scores.flatten()])
categories = np.array([categorize_z_score(z) for z in z_scores])
# Calculate the percentage of each category
category_counts = pd.Series(categories).value_counts(normalize=True) * 100

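The same generalization appears in both strategies: the old pairwise matrix (test_counts[:, None] - cal_counts) is replaced by one z-score per test run against the calibration baseline, and the boundary values 1 and 2 now fall into the lower band. Below is a minimal sketch of the new per-run categorization, using made-up counts rather than real Detectron records.

    import numpy as np
    import pandas as pd

    def categorize_z_score(z):
        # Bands as in the updated code: z == 1 counts as 'small', z == 2 as 'moderate'.
        if z <= 0:
            return 'no significant shift'
        elif 0 < z <= 1:
            return 'small'
        elif 1 < z <= 2:
            return 'moderate'
        else:
            return 'large'

    # Made-up rejected counts for illustration only.
    cal_counts = np.array([6, 7, 7, 8, 8, 9])
    test_counts = np.array([7, 9, 10, 12])

    baseline_mean = np.mean(cal_counts)
    baseline_std = np.std(cal_counts)
    test_statistic = np.mean(test_counts)

    # One z-score per test run, measured in calibration standard deviations.
    z_scores = (test_counts - baseline_mean) / baseline_std
    categories = [categorize_z_score(z) for z in z_scores]

    # Percentage of test runs falling into each shift category.
    category_counts = pd.Series(categories).value_counts(normalize=True) * 100
    print(test_statistic, dict(category_counts))
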
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@

setup(
name="MED3pa",
version="0.1.21",
version="0.1.22",
author="MEDomics consortium",
author_email="medomics.info@gmail.com",
description="Python Open-source package for ensuring robust and reliable ML models deployments",
115 changes: 105 additions & 10 deletions tutorials/detectron_tutorial.ipynb
@@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -48,7 +48,7 @@
"datasets2.set_from_file(dataset_type=\"training\", file='./data/train_data.csv', target_column_name='Outcome')\n",
"datasets2.set_from_file(dataset_type=\"validation\", file='./data/val_data.csv', target_column_name='Outcome')\n",
"datasets2.set_from_file(dataset_type=\"reference\", file='./data/test_data.csv', target_column_name='Outcome')\n",
"datasets2.set_from_file(dataset_type=\"testing\", file='./data/test_data_shifted_0.1.csv', target_column_name='Outcome')\n",
"datasets2.set_from_file(dataset_type=\"testing\", file='./data/test_data_shifted_1.6.csv', target_column_name='Outcome')\n",
"\n"
]
},
@@ -62,7 +62,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -87,9 +87,73 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:13<00:00, 7.42it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Detectron execution on reference set completed.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:12<00:00, 7.74it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Detectron execution on testing set completed.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:13<00:00, 7.62it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Detectron execution on reference set completed.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:14<00:00, 6.82it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Detectron execution on testing set completed.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from MED3pa.detectron import DetectronExperiment\n",
"\n",
@@ -108,14 +108,45 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"[{'shift_probability': 0.9444444444444444,\n",
" 'test_statistic': 9.722222222222221,\n",
" 'baseline_mean': 7.4,\n",
" 'baseline_std': 1.2631530214330944,\n",
" 'significance_description': {'unsignificant shift': 15.555555555555555,\n",
" 'small': 8.88888888888889,\n",
" 'moderate': 13.333333333333334,\n",
" 'large': 62.22222222222222},\n",
" 'Strategy': 'enhanced_disagreement_strategy'},\n",
" {'p_value': 4.6393173988598416e-17,\n",
" 'u_statistic': 1182.5,\n",
" 'significance_description': {'unsignificant shift': 15.555555555555555,\n",
" 'small': 8.88888888888889,\n",
" 'moderate': 13.333333333333334,\n",
" 'large': 62.22222222222222},\n",
" 'Strategy': 'mannwhitney_strategy'},\n",
" {'p_value': 0.05,\n",
" 'test_statistic': 8,\n",
" 'baseline_mean': 12.47,\n",
" 'baseline_std': 1.920702996301094,\n",
" 'Strategy': 'original_disagreement_strategy'}]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from MED3pa.detectron.strategies import *\n",
"\n",
"# Analyze the results using the disagreement strategies\n",
"test_strategies = [\"enhanced_disagreement_strategy\", \"mannwhitney_strategy\"]\n",
"test_strategies = [\"enhanced_disagreement_strategy\", \"mannwhitney_strategy\", \"original_disagreement_strategy\"]\n",
"experiment_results.analyze_results(test_strategies)\n",
"experiment_results2.analyze_results(test_strategies)\n"
]
@@ -130,7 +225,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -148,7 +243,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
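
As a quick sanity check on the printed results, the headline number from the enhanced disagreement strategy can be reproduced by hand; the values below are copied from the output shown earlier, and the banding follows the updated strategies.py.

    # Values read off the enhanced_disagreement_strategy output above.
    baseline_mean = 7.4
    baseline_std = 1.2631530214330944
    test_statistic = 9.722222222222221

    z = (test_statistic - baseline_mean) / baseline_std
    print(round(z, 2))  # ~1.84: the average test run sits in the 'moderate' band

At the level of individual runs the picture is stronger: the breakdown reports about 62% of test runs in the 'large' band, and the one-sided Mann-Whitney p-value of roughly 4.6e-17 rejects the no-shift hypothesis on the shifted test set.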
