From 1d9f67cfc9804534d2712a67d5a36e2e58b8787e Mon Sep 17 00:00:00 2001 From: AWbosman Date: Thu, 12 Sep 2024 17:32:02 +0200 Subject: [PATCH 1/3] updated binsearch, such that investigated quarters --- ...d_binary_search_epsilon_value_estimator.py | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 robustness_experiment_box/epsilon_value_estimator/adjusted_binary_search_epsilon_value_estimator.py diff --git a/robustness_experiment_box/epsilon_value_estimator/adjusted_binary_search_epsilon_value_estimator.py b/robustness_experiment_box/epsilon_value_estimator/adjusted_binary_search_epsilon_value_estimator.py new file mode 100644 index 0000000..1649c79 --- /dev/null +++ b/robustness_experiment_box/epsilon_value_estimator/adjusted_binary_search_epsilon_value_estimator.py @@ -0,0 +1,75 @@ +import logging + +import time +logger = logging.getLogger(__name__) +import torch +from robustness_experiment_box.verification_module.verification_module import VerificationModule +from robustness_experiment_box.epsilon_value_estimator.epsilon_value_estimator import BinarySearchEpsilonValueEstimator +from robustness_experiment_box.database.verification_context import VerificationContext +from robustness_experiment_box.database.verification_result import VerificationResult +from robustness_experiment_box.database.epsilon_status import EpsilonStatus + + +class AdjustedBinarySearchEpsilonValueEstimator(BinarySearchEpsilonValueEstimator): + + + def get_next_epsilon(self, midpoint:int, first:int, last:int, epsilon_status_list: list[EpsilonStatus]) -> int: + random_number = torch.rand(1).item() + if random_number<= 0.5: + return (first + midpoint) // 2 + else: + return (midpoint + last) // 2 + + def binary_search(self, verification_context: VerificationContext, epsilon_status_list: list[EpsilonStatus]) -> float: + + if len(epsilon_status_list) == 1: + outcome = self.verifier.verify(verification_context, epsilon_status_list[0].value) + result = outcome.result + epsilon_status_list[0].time = outcome.took + epsilon_status_list[0].result = result + logger.debug(f'current epsilon value: {epsilon_status_list[0].result}, took: {epsilon_status_list[0].time}') + verification_context.save_result(epsilon_status_list[0]) + if result == VerificationResult.UNSAT: + return epsilon_status_list[0].value, self.get_smallest_sat(epsilon_status_list) + else: + return 0, self.get_smallest_sat(epsilon_status_list) + + first = 0 + last = len(epsilon_status_list) - 1 + midpoint = (first + last) // 2 + + while first<=last: + + if not epsilon_status_list[midpoint].result: + + outcome = self.verifier.verify(verification_context, epsilon_status_list[midpoint].value) + epsilon_status_list[midpoint].result = outcome.result + epsilon_status_list[midpoint].time = outcome.took + verification_context.save_result(epsilon_status_list[midpoint]) + logger.debug(f'current epsilon value: {epsilon_status_list[midpoint].result}, took: {epsilon_status_list[midpoint].time}') + + if epsilon_status_list[midpoint].result == VerificationResult.UNSAT: + first = midpoint + 1 + midpoint = (first + last) // 2 + elif epsilon_status_list[midpoint].result == VerificationResult.SAT: + last = midpoint - 1 + midpoint = (first + last) // 2 + else: + if len(epsilon_status_list)>3: + midpoint = self.get_next_epsilon(epsilon_status_list) + epsilon_status_list.pop(midpoint) + last = last - 1 + else: + epsilon_status_list.pop(midpoint) + last = last - 1 + midpoint = (first + last) // 2 + + + + logger.debug(f"epsilon status list: {[(x.value, x.result, x.time) for x in epsilon_status_list]}") + + highest_unsat_value = self.get_highest_unsat(epsilon_status_list) + + smallest_sat_value = self.get_smallest_sat(epsilon_status_list) + + return highest_unsat_value, smallest_sat_value From fb2dfe5eb5a3780a97154de16cd4fadf3576b1ec Mon Sep 17 00:00:00 2001 From: AWbosman Date: Fri, 13 Sep 2024 13:46:29 +0200 Subject: [PATCH 2/3] updated before testing --- ..._binary_search_epsilon_value_estimator.py} | 14 +++-- ...d_binary_search_epsilon_value_estimator.py | 63 +++++++++++++++++++ 2 files changed, 72 insertions(+), 5 deletions(-) rename robustness_experiment_box/epsilon_value_estimator/{adjusted_binary_search_epsilon_value_estimator.py => quartered_binary_search_epsilon_value_estimator.py} (88%) create mode 100644 tests/test_epsilon_value_estimator/test_quartered_binary_search_epsilon_value_estimator.py diff --git a/robustness_experiment_box/epsilon_value_estimator/adjusted_binary_search_epsilon_value_estimator.py b/robustness_experiment_box/epsilon_value_estimator/quartered_binary_search_epsilon_value_estimator.py similarity index 88% rename from robustness_experiment_box/epsilon_value_estimator/adjusted_binary_search_epsilon_value_estimator.py rename to robustness_experiment_box/epsilon_value_estimator/quartered_binary_search_epsilon_value_estimator.py index 1649c79..885b3b9 100644 --- a/robustness_experiment_box/epsilon_value_estimator/adjusted_binary_search_epsilon_value_estimator.py +++ b/robustness_experiment_box/epsilon_value_estimator/quartered_binary_search_epsilon_value_estimator.py @@ -10,15 +10,19 @@ from robustness_experiment_box.database.epsilon_status import EpsilonStatus -class AdjustedBinarySearchEpsilonValueEstimator(BinarySearchEpsilonValueEstimator): +class QuarteredBinarySearchEpsilonValueEstimator(BinarySearchEpsilonValueEstimator): - def get_next_epsilon(self, midpoint:int, first:int, last:int, epsilon_status_list: list[EpsilonStatus]) -> int: + def get_next_epsilon(self, midpoint:int, first:int, last:int) -> int: random_number = torch.rand(1).item() if random_number<= 0.5: - return (first + midpoint) // 2 + next = (first + midpoint) // 2 + if next == midpoint: + return first else: - return (midpoint + last) // 2 + next = (midpoint + last) // 2 + if next == midpoint: + return last def binary_search(self, verification_context: VerificationContext, epsilon_status_list: list[EpsilonStatus]) -> float: @@ -56,7 +60,7 @@ def binary_search(self, verification_context: VerificationContext, epsilon_statu midpoint = (first + last) // 2 else: if len(epsilon_status_list)>3: - midpoint = self.get_next_epsilon(epsilon_status_list) + midpoint = self.get_next_epsilon(midpoint=midpoint, first=first,last=last) epsilon_status_list.pop(midpoint) last = last - 1 else: diff --git a/tests/test_epsilon_value_estimator/test_quartered_binary_search_epsilon_value_estimator.py b/tests/test_epsilon_value_estimator/test_quartered_binary_search_epsilon_value_estimator.py new file mode 100644 index 0000000..fec9d63 --- /dev/null +++ b/tests/test_epsilon_value_estimator/test_quartered_binary_search_epsilon_value_estimator.py @@ -0,0 +1,63 @@ +import pytest + +from robustness_experiment_box.database.verification_result import VerificationResult +from robustness_experiment_box.database.epsilon_value_result import EpsilonValueResult +from robustness_experiment_box.epsilon_value_estimator.binary_search_epsilon_value_estimator import QuarteredBinarySearchEpsilonValueEstimator + +from tests.test_epsilon_value_estimator.conftest import MockVerificationModule + + +#TODO: adjust this for quartered +class TestBinarySearchEpsilonValueEstimator: + + + def test_verifier_gets_called(self, mocker, verification_context): + verification_module = MockVerificationModule(None) + verifier = mocker.Mock(verification_module) + estimator = QuarteredBinarySearchEpsilonValueEstimator([0.1], verifier=verifier) + + epsilon_value_result = estimator.compute_epsilon_value(verification_context) + + verifier.verify.assert_called() + + def test_result_class_returned(self, verification_context): + verifier = MockVerificationModule({0.1 : VerificationResult.SAT}) + estimator = QuarteredBinarySearchEpsilonValueEstimator(epsilon_value_list=[0.1], verifier=verifier) + + epsilon_value_result = estimator.compute_epsilon_value(verification_context) + + assert isinstance(epsilon_value_result, EpsilonValueResult) + + @pytest.mark.parametrize("epsilon_verification_dict, expected_result", [ + ({0.1 : VerificationResult.ERROR, 0.2 : VerificationResult.ERROR, 0.3: VerificationResult.ERROR}, 0.), + ({0.1 : VerificationResult.TIMEOUT, 0.2 : VerificationResult.TIMEOUT, 0.3: VerificationResult.TIMEOUT}, 0.), + ({0.1 : VerificationResult.SAT, 0.2 : VerificationResult.SAT, 0.3: VerificationResult.SAT}, 0), + ({0.1 : VerificationResult.UNSAT, 0.2 : VerificationResult.SAT, 0.3: VerificationResult.SAT}, 0.1), + ({0.1 : VerificationResult.UNSAT, 0.2 : VerificationResult.UNSAT, 0.3: VerificationResult.SAT}, 0.2), + ({0.1 : VerificationResult.ERROR, 0.2 : VerificationResult.ERROR, 0.3: VerificationResult.UNSAT}, 0.3), + ({0.1 : VerificationResult.UNSAT, 0.2 : VerificationResult.ERROR, 0.3: VerificationResult.ERROR}, 0.1), + ]) + def test_compute_epsilon_value(self, verification_context, epsilon_verification_dict, expected_result): + + verifier = MockVerificationModule(epsilon_verification_dict) + estimator = QuarteredBinarySearchEpsilonValueEstimator(epsilon_value_list=list(epsilon_verification_dict.keys()), verifier=verifier) + + epsilon_value_result = estimator.compute_epsilon_value(verification_context) + + assert epsilon_value_result.epsilon == expected_result + + def test_get_next_epsilon(self, epsilon_verification_dict): + verifier = MockVerificationModule(epsilon_verification_dict) + estimator = QuarteredBinarySearchEpsilonValueEstimator(epsilon_value_list=list(epsilon_verification_dict.keys()), verifier=verifier) + + check_normal = estimator.get_next_epsilon(0,5,10) + assert check_normal in {2, 7} + assert estimator.get_next_epsilon(5,5,6) == 6 + check_extreme = estimator.get_next_epsilon(0,1,2) + assert check_extreme in {0,2} + + + + + + \ No newline at end of file From 541a48db315c23785286afe68c29b2e311751b17 Mon Sep 17 00:00:00 2001 From: AWbosman Date: Tue, 17 Sep 2024 11:58:04 +0200 Subject: [PATCH 3/3] tested and tests added --- .../quartered_binary_search_epsilon_value_estimator.py | 2 +- scripts/create_robustness_dist_on_pytorch_dataset.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/robustness_experiment_box/epsilon_value_estimator/quartered_binary_search_epsilon_value_estimator.py b/robustness_experiment_box/epsilon_value_estimator/quartered_binary_search_epsilon_value_estimator.py index 885b3b9..f5fa651 100644 --- a/robustness_experiment_box/epsilon_value_estimator/quartered_binary_search_epsilon_value_estimator.py +++ b/robustness_experiment_box/epsilon_value_estimator/quartered_binary_search_epsilon_value_estimator.py @@ -4,7 +4,7 @@ logger = logging.getLogger(__name__) import torch from robustness_experiment_box.verification_module.verification_module import VerificationModule -from robustness_experiment_box.epsilon_value_estimator.epsilon_value_estimator import BinarySearchEpsilonValueEstimator +from robustness_experiment_box.epsilon_value_estimator.binary_search_epsilon_value_estimator import BinarySearchEpsilonValueEstimator from robustness_experiment_box.database.verification_context import VerificationContext from robustness_experiment_box.database.verification_result import VerificationResult from robustness_experiment_box.database.epsilon_status import EpsilonStatus diff --git a/scripts/create_robustness_dist_on_pytorch_dataset.py b/scripts/create_robustness_dist_on_pytorch_dataset.py index e6a5a96..d87f699 100644 --- a/scripts/create_robustness_dist_on_pytorch_dataset.py +++ b/scripts/create_robustness_dist_on_pytorch_dataset.py @@ -14,6 +14,7 @@ from robustness_experiment_box.dataset_sampler.predictions_based_sampler import PredictionsBasedSampler from robustness_experiment_box.epsilon_value_estimator.epsilon_value_estimator import EpsilonValueEstimator from robustness_experiment_box.epsilon_value_estimator.binary_search_epsilon_value_estimator import BinarySearchEpsilonValueEstimator +from robustness_experiment_box.epsilon_value_estimator.quartered_binary_search_epsilon_value_estimator import QuarteredBinarySearchEpsilonValueEstimator from robustness_experiment_box.verification_module.auto_verify_module import AutoVerifyModule from robustness_experiment_box.database.dataset.experiment_dataset import ExperimentDataset from robustness_experiment_box.database.dataset.pytorch_experiment_dataset import PytorchExperimentDataset @@ -57,7 +58,7 @@ def main(): experiment_name = "nnenum_one2one" verifier = AutoVerifyModule(verifier=Nnenum(), property_generator=One2OnePropertyGenerator(target_class=1),timeout=timeout) - epsilon_value_estimator = BinarySearchEpsilonValueEstimator(epsilon_value_list=epsilon_list.copy(), verifier=verifier) + epsilon_value_estimator = QuarteredBinarySearchEpsilonValueEstimator(epsilon_value_list=epsilon_list.copy(), verifier=verifier) dataset_sampler = PredictionsBasedSampler(sample_correct_predictions=True) experiment_repository.initialize_new_experiment(experiment_name) experiment_repository.save_configuration(dict(