From 809763ecd78111572772f737947430ca01d8d717 Mon Sep 17 00:00:00 2001 From: lyna1404 Date: Sat, 31 Aug 2024 14:12:14 +0000 Subject: [PATCH] fixed detectron conditions --- MED3pa/datasets/masked.py | 1 + MED3pa/detectron/experiment.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/MED3pa/datasets/masked.py b/MED3pa/datasets/masked.py index f9382cd..938f00f 100644 --- a/MED3pa/datasets/masked.py +++ b/MED3pa/datasets/masked.py @@ -156,6 +156,7 @@ def sample_random(self, N: int, seed: int) -> 'MaskedDataset': # Set the seed for reproducibility and generate random indices rng = np.random.RandomState(seed) random_indices = rng.permutation(len(self.__observations))[:N] + self.__sample_counts[random_indices] += 1 # Extract the sampled observations and labels sampled_obs = self.__observations[random_indices, :] diff --git a/MED3pa/detectron/experiment.py b/MED3pa/detectron/experiment.py index 23b696e..5e7a282 100644 --- a/MED3pa/detectron/experiment.py +++ b/MED3pa/detectron/experiment.py @@ -248,7 +248,8 @@ def run(datasets: DatasetsManager, 'num_runs': num_calibration_runs, 'patience': patience, 'allow_margin': allow_margin, - 'margin': margin + 'margin': margin, + 'sampling-method':sampling, } experiment_config = { 'experiment_name': "DetectronExperiment",