diff --git a/robustness_experiment_box/epsilon_value_estimator/quartered_binary_search_epsilon_value_estimator.py b/robustness_experiment_box/epsilon_value_estimator/quartered_binary_search_epsilon_value_estimator.py new file mode 100644 index 0000000..f5fa651 --- /dev/null +++ b/robustness_experiment_box/epsilon_value_estimator/quartered_binary_search_epsilon_value_estimator.py @@ -0,0 +1,79 @@ +import logging + +import time +logger = logging.getLogger(__name__) +import torch +from robustness_experiment_box.verification_module.verification_module import VerificationModule +from robustness_experiment_box.epsilon_value_estimator.binary_search_epsilon_value_estimator import BinarySearchEpsilonValueEstimator +from robustness_experiment_box.database.verification_context import VerificationContext +from robustness_experiment_box.database.verification_result import VerificationResult +from robustness_experiment_box.database.epsilon_status import EpsilonStatus + + +class QuarteredBinarySearchEpsilonValueEstimator(BinarySearchEpsilonValueEstimator): + + + def get_next_epsilon(self, midpoint:int, first:int, last:int) -> int: + random_number = torch.rand(1).item() + if random_number<= 0.5: + next = (first + midpoint) // 2 + if next == midpoint: + return first + else: + next = (midpoint + last) // 2 + if next == midpoint: + return last + + def binary_search(self, verification_context: VerificationContext, epsilon_status_list: list[EpsilonStatus]) -> float: + + if len(epsilon_status_list) == 1: + outcome = self.verifier.verify(verification_context, epsilon_status_list[0].value) + result = outcome.result + epsilon_status_list[0].time = outcome.took + epsilon_status_list[0].result = result + logger.debug(f'current epsilon value: {epsilon_status_list[0].result}, took: {epsilon_status_list[0].time}') + verification_context.save_result(epsilon_status_list[0]) + if result == VerificationResult.UNSAT: + return epsilon_status_list[0].value, self.get_smallest_sat(epsilon_status_list) + else: + return 0, self.get_smallest_sat(epsilon_status_list) + + first = 0 + last = len(epsilon_status_list) - 1 + midpoint = (first + last) // 2 + + while first<=last: + + if not epsilon_status_list[midpoint].result: + + outcome = self.verifier.verify(verification_context, epsilon_status_list[midpoint].value) + epsilon_status_list[midpoint].result = outcome.result + epsilon_status_list[midpoint].time = outcome.took + verification_context.save_result(epsilon_status_list[midpoint]) + logger.debug(f'current epsilon value: {epsilon_status_list[midpoint].result}, took: {epsilon_status_list[midpoint].time}') + + if epsilon_status_list[midpoint].result == VerificationResult.UNSAT: + first = midpoint + 1 + midpoint = (first + last) // 2 + elif epsilon_status_list[midpoint].result == VerificationResult.SAT: + last = midpoint - 1 + midpoint = (first + last) // 2 + else: + if len(epsilon_status_list)>3: + midpoint = self.get_next_epsilon(midpoint=midpoint, first=first,last=last) + epsilon_status_list.pop(midpoint) + last = last - 1 + else: + epsilon_status_list.pop(midpoint) + last = last - 1 + midpoint = (first + last) // 2 + + + + logger.debug(f"epsilon status list: {[(x.value, x.result, x.time) for x in epsilon_status_list]}") + + highest_unsat_value = self.get_highest_unsat(epsilon_status_list) + + smallest_sat_value = self.get_smallest_sat(epsilon_status_list) + + return highest_unsat_value, smallest_sat_value diff --git a/scripts/create_robustness_dist_on_pytorch_dataset.py b/scripts/create_robustness_dist_on_pytorch_dataset.py index e6a5a96..d87f699 100644 --- a/scripts/create_robustness_dist_on_pytorch_dataset.py +++ b/scripts/create_robustness_dist_on_pytorch_dataset.py @@ -14,6 +14,7 @@ from robustness_experiment_box.dataset_sampler.predictions_based_sampler import PredictionsBasedSampler from robustness_experiment_box.epsilon_value_estimator.epsilon_value_estimator import EpsilonValueEstimator from robustness_experiment_box.epsilon_value_estimator.binary_search_epsilon_value_estimator import BinarySearchEpsilonValueEstimator +from robustness_experiment_box.epsilon_value_estimator.quartered_binary_search_epsilon_value_estimator import QuarteredBinarySearchEpsilonValueEstimator from robustness_experiment_box.verification_module.auto_verify_module import AutoVerifyModule from robustness_experiment_box.database.dataset.experiment_dataset import ExperimentDataset from robustness_experiment_box.database.dataset.pytorch_experiment_dataset import PytorchExperimentDataset @@ -57,7 +58,7 @@ def main(): experiment_name = "nnenum_one2one" verifier = AutoVerifyModule(verifier=Nnenum(), property_generator=One2OnePropertyGenerator(target_class=1),timeout=timeout) - epsilon_value_estimator = BinarySearchEpsilonValueEstimator(epsilon_value_list=epsilon_list.copy(), verifier=verifier) + epsilon_value_estimator = QuarteredBinarySearchEpsilonValueEstimator(epsilon_value_list=epsilon_list.copy(), verifier=verifier) dataset_sampler = PredictionsBasedSampler(sample_correct_predictions=True) experiment_repository.initialize_new_experiment(experiment_name) experiment_repository.save_configuration(dict( diff --git a/tests/test_epsilon_value_estimator/test_quartered_binary_search_epsilon_value_estimator.py b/tests/test_epsilon_value_estimator/test_quartered_binary_search_epsilon_value_estimator.py new file mode 100644 index 0000000..fec9d63 --- /dev/null +++ b/tests/test_epsilon_value_estimator/test_quartered_binary_search_epsilon_value_estimator.py @@ -0,0 +1,63 @@ +import pytest + +from robustness_experiment_box.database.verification_result import VerificationResult +from robustness_experiment_box.database.epsilon_value_result import EpsilonValueResult +from robustness_experiment_box.epsilon_value_estimator.binary_search_epsilon_value_estimator import QuarteredBinarySearchEpsilonValueEstimator + +from tests.test_epsilon_value_estimator.conftest import MockVerificationModule + + +#TODO: adjust this for quartered +class TestBinarySearchEpsilonValueEstimator: + + + def test_verifier_gets_called(self, mocker, verification_context): + verification_module = MockVerificationModule(None) + verifier = mocker.Mock(verification_module) + estimator = QuarteredBinarySearchEpsilonValueEstimator([0.1], verifier=verifier) + + epsilon_value_result = estimator.compute_epsilon_value(verification_context) + + verifier.verify.assert_called() + + def test_result_class_returned(self, verification_context): + verifier = MockVerificationModule({0.1 : VerificationResult.SAT}) + estimator = QuarteredBinarySearchEpsilonValueEstimator(epsilon_value_list=[0.1], verifier=verifier) + + epsilon_value_result = estimator.compute_epsilon_value(verification_context) + + assert isinstance(epsilon_value_result, EpsilonValueResult) + + @pytest.mark.parametrize("epsilon_verification_dict, expected_result", [ + ({0.1 : VerificationResult.ERROR, 0.2 : VerificationResult.ERROR, 0.3: VerificationResult.ERROR}, 0.), + ({0.1 : VerificationResult.TIMEOUT, 0.2 : VerificationResult.TIMEOUT, 0.3: VerificationResult.TIMEOUT}, 0.), + ({0.1 : VerificationResult.SAT, 0.2 : VerificationResult.SAT, 0.3: VerificationResult.SAT}, 0), + ({0.1 : VerificationResult.UNSAT, 0.2 : VerificationResult.SAT, 0.3: VerificationResult.SAT}, 0.1), + ({0.1 : VerificationResult.UNSAT, 0.2 : VerificationResult.UNSAT, 0.3: VerificationResult.SAT}, 0.2), + ({0.1 : VerificationResult.ERROR, 0.2 : VerificationResult.ERROR, 0.3: VerificationResult.UNSAT}, 0.3), + ({0.1 : VerificationResult.UNSAT, 0.2 : VerificationResult.ERROR, 0.3: VerificationResult.ERROR}, 0.1), + ]) + def test_compute_epsilon_value(self, verification_context, epsilon_verification_dict, expected_result): + + verifier = MockVerificationModule(epsilon_verification_dict) + estimator = QuarteredBinarySearchEpsilonValueEstimator(epsilon_value_list=list(epsilon_verification_dict.keys()), verifier=verifier) + + epsilon_value_result = estimator.compute_epsilon_value(verification_context) + + assert epsilon_value_result.epsilon == expected_result + + def test_get_next_epsilon(self, epsilon_verification_dict): + verifier = MockVerificationModule(epsilon_verification_dict) + estimator = QuarteredBinarySearchEpsilonValueEstimator(epsilon_value_list=list(epsilon_verification_dict.keys()), verifier=verifier) + + check_normal = estimator.get_next_epsilon(0,5,10) + assert check_normal in {2, 7} + assert estimator.get_next_epsilon(5,5,6) == 6 + check_extreme = estimator.get_next_epsilon(0,1,2) + assert check_extreme in {0,2} + + + + + + \ No newline at end of file