Skip to content

Commit

Permalink
fixed an error in tree builing with pretrained apc + ordered rejectio…
Browse files Browse the repository at this point in the history
…n counts in detectron
  • Loading branch information
lyna1404 committed Aug 12, 2024
1 parent 849b3fd commit 2566d20
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 256 deletions.
4 changes: 2 additions & 2 deletions MED3pa/detectron/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def save(self, file_path: str, file_name: str = 'detectron_results', save_config
json.dump(self.test_results, file, indent=4)

counts_dict = {}
counts_dict['reference'] = self.cal_record.rejected_counts().tolist()
counts_dict['test'] = self.test_record.rejected_counts().tolist()
counts_dict['reference'] = np.sort(self.cal_record.rejected_counts()).tolist()
counts_dict['test'] = np.sort(self.test_record.rejected_counts()).tolist()

file_name_path_counts = os.path.join(file_path, 'rejection_counts.json')
with open(file_name_path_counts, 'w') as file:
Expand Down
10 changes: 9 additions & 1 deletion MED3pa/med3pa/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,23 @@ def build_tree(self, dtr: DecisionTreeRegressorModel, X: DataFrame, y: Series, n
curr_node = _TreeNode(value=node_value, value_max=node_max, samples_ratio=node_samples_ratio,
node_id=self.nb_nodes, path=path)
return curr_node

node_thresh = dtr.model.tree_.threshold[node_id]
node_feature_id = dtr.model.tree_.feature[node_id]
node_feature = self.features[node_feature_id]

# Check if the split would result in an empty set, if so, stop the recursion
if y[X[node_feature] <= node_thresh].size == 0 or y[X[node_feature] > node_thresh].size == 0:
print("split would results in an empty data section")
curr_node = _TreeNode(value=node_value, value_max=node_max, samples_ratio=node_samples_ratio,
node_id=self.nb_nodes, path=path)
return curr_node

curr_path = list(path) # Copy the current path to avoid modifying the original list
curr_node = _TreeNode(value=node_value, value_max=node_max, samples_ratio=node_samples_ratio,
threshold=node_thresh, feature=node_feature, feature_id=node_feature_id,
node_id=self.nb_nodes, path=curr_path)


# Update paths for child nodes
left_path = curr_path + [f"{node_feature} <= {node_thresh}"]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

setup(
name="MED3pa",
version="0.1.20",
version="0.1.21",
author="MEDomics consortium",
author_email="medomics.info@gmail.com",
description="Python Open-source package for ensuring robust and reliable ML models deployments",
Expand Down
106 changes: 8 additions & 98 deletions tutorials/detectron_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -62,7 +62,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -87,73 +87,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:13<00:00, 7.57it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Detectron execution on reference set completed.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:14<00:00, 7.01it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Detectron execution on testing set completed.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:24<00:00, 4.15it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Detectron execution on reference set completed.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"running seeds: 100%|██████████| 100/100 [00:14<00:00, 6.89it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Detectron execution on testing set completed.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"outputs": [],
"source": [
"from MED3pa.detectron import DetectronExperiment\n",
"\n",
Expand All @@ -172,35 +108,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'shift_probability': 0.5777777777777777,\n",
" 'test_statistic': 7.6,\n",
" 'baseline_mean': 7.4,\n",
" 'baseline_std': 1.2631530214330944,\n",
" 'significance_description': {'unsignificant shift': 56.148148148148145,\n",
" 'small': 18.432098765432098,\n",
" 'moderate': 12.506172839506172,\n",
" 'large': 12.91358024691358},\n",
" 'Strategy': 'enhanced_disagreement_strategy'},\n",
" {'p_value': 0.22518291651155886,\n",
" 'u_statistic': 4696.5,\n",
" 'significance_description': {'unsignificant shift': 55.33,\n",
" 'small': 15.459999999999999,\n",
" 'moderate': 19.56,\n",
" 'large': 9.65},\n",
" 'Strategy': 'mannwhitney_strategy'}]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"from MED3pa.detectron.strategies import *\n",
"\n",
Expand All @@ -220,7 +130,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -238,7 +148,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down
Loading

0 comments on commit 2566d20

Please sign in to comment.