diff --git a/pycona/active_algorithms/gquacq.py b/pycona/active_algorithms/gquacq.py index d6b0d55..34bf3da 100644 --- a/pycona/active_algorithms/gquacq.py +++ b/pycona/active_algorithms/gquacq.py @@ -21,7 +21,7 @@ class GQuAcq(AlgorithmCAInteractive): def __init__(self, ca_env: ActiveCAEnv = None, qg_max=10): """ - Initialize the PQuAcq algorithm with an optional constraint acquisition environment. + Initialize the GQuAcq algorithm with an optional constraint acquisition environment. :param ca_env: An instance of ActiveCAEnv, default is None. : param GQmax: maximum number of generalization queries @@ -30,16 +30,21 @@ def __init__(self, ca_env: ActiveCAEnv = None, qg_max=10): self._negativeQ = [] self._qg_max = qg_max - def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None): + def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None): """ - Learn constraints using the QuAcq algorithm by generating queries and analyzing the results. + Learn constraints using the GQuAcq algorithm by generating queries and analyzing the results. :param instance: the problem instance to acquire the constraints for :param oracle: An instance of Oracle, default is to use the user as the oracle. :param verbose: Verbosity level, default is 0. :param metrics: statistics logger during learning + :param X: The set of variables to consider, default is None. :return: the learned instance """ + if X is None: + X = instance.X + assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables" + self.env.init_state(instance, oracle, verbose, metrics) if len(self.env.instance.bias) == 0: @@ -52,8 +57,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos print("Number of Queries: ", self.env.metrics.membership_queries_count) gen_start = time.time() - Y = self.env.run_query_generation() - gen_end = time.time() + Y = self.env.run_query_generation(X) + gen_end = time.time() if len(Y) == 0: # if no query can be generated it means we have (prematurely) converged to the target network ----- diff --git a/pycona/active_algorithms/growacq.py b/pycona/active_algorithms/growacq.py index eabef1e..ff3d55e 100644 --- a/pycona/active_algorithms/growacq.py +++ b/pycona/active_algorithms/growacq.py @@ -26,7 +26,7 @@ def __init__(self, ca_env: ActiveCAEnv = None, inner_algorithm: AlgorithmCAInter super().__init__(env) self.inner_algorithm = inner_algorithm if inner_algorithm is not None else MQuAcq2(ca_env) - def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None): + def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None): """ Learn constraints by incrementally adding variables and using the inner algorithm to learn constraints for each added variable. @@ -34,9 +34,14 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos :param instance: the problem instance to acquire the constraints for :param oracle: An instance of Oracle, default is to use the user as the oracle. :param verbose: Verbosity level, default is 0. + :param X: The set of variables to consider, default is None. :param metrics: statistics logger during learning :return: the learned instance """ + if X is None: + X = instance.X + assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables" + self.env.init_state(instance, oracle, verbose, metrics) if verbose >= 1: @@ -44,21 +49,22 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos self.inner_algorithm.env = copy.copy(self.env) - self.env.instance.X = [] + Y = [] - n_vars = len(self.env.instance.variables.flat) - for x in self.env.instance.variables.flat: + n_vars = len(X) + for x in X: # we 'grow' the inner bias by adding one extra variable at a time - self.env.instance.X.append(x) + Y.append(x) # add the constraints involving x and other added variables - self.env.instance.construct_bias_for_var(x) + if len(self.env.instance.bias) == 0: + self.env.instance.construct_bias_for_var(x, Y) if verbose >= 3: print(f"Added variable {x} in GrowAcq") print("size of B in growacq: ", len(self.env.instance.bias)) if verbose >= 2: - print(f"\nGrowAcq: calling inner_algorithm for {len(self.env.instance.X)}/{n_vars} variables") - self.env.instance = self.inner_algorithm.learn(self.env.instance, oracle, verbose=verbose, metrics=self.env.metrics) + print(f"\nGrowAcq: calling inner_algorithm for {len(Y)}/{n_vars} variables") + self.env.instance = self.inner_algorithm.learn(self.env.instance, oracle, verbose=verbose, X=Y, metrics=self.env.metrics) if verbose >= 3: print("C_L: ", len(self.env.instance.cl)) diff --git a/pycona/active_algorithms/mquacq.py b/pycona/active_algorithms/mquacq.py index 1e398f4..e9db506 100644 --- a/pycona/active_algorithms/mquacq.py +++ b/pycona/active_algorithms/mquacq.py @@ -21,7 +21,7 @@ def __init__(self, ca_env: ActiveCAEnv = None): """ super().__init__(ca_env) - def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None): + def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None): """ Learn constraints using the modified QuAcq algorithm by generating queries and analyzing the results. @@ -29,8 +29,13 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos :param oracle: An instance of Oracle, default is to use the user as the oracle. :param verbose: Verbosity level, default is 0. :param metrics: statistics logger during learning + :param X: The set of variables to consider, default is None. :return: the learned instance """ + if X is None: + X = instance.X + assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables" + self.env.init_state(instance, oracle, verbose, metrics) if len(self.env.instance.bias) == 0: @@ -47,7 +52,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos # generate e in D^X accepted by C_l and rejected by B gen_start = time.time() - Y = self.env.run_query_generation() + Y = self.env.run_query_generation(X) gen_end = time.time() if len(Y) == 0: diff --git a/pycona/active_algorithms/mquacq2.py b/pycona/active_algorithms/mquacq2.py index e5e805b..2890713 100644 --- a/pycona/active_algorithms/mquacq2.py +++ b/pycona/active_algorithms/mquacq2.py @@ -31,7 +31,7 @@ def __init__(self, ca_env: ActiveCAEnv = None, *, perform_analyzeAndLearn: bool self.cl_neighbours = None self.hashX = None - def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None): + def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None): """ Learn constraints using the modified QuAcq algorithm by generating queries and analyzing the results. @@ -39,8 +39,13 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos :param oracle: An instance of Oracle, default is to use the user as the oracle. :param verbose: Verbosity level, default is 0. :param metrics: statistics logger during learning + :param X: The set of variables to consider, default is None. :return: the learned instance """ + if X is None: + X = instance.X + assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables" + self.env.init_state(instance, oracle, verbose, metrics) # Hash the variables @@ -52,7 +57,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos while True: gen_start = time.time() - Y = self.env.run_query_generation() + Y = self.env.run_query_generation(X) gen_end = time.time() self.env.metrics.increase_generation_time(gen_end - gen_start) diff --git a/pycona/active_algorithms/pquacq.py b/pycona/active_algorithms/pquacq.py index 1bbea05..d5ab2d3 100644 --- a/pycona/active_algorithms/pquacq.py +++ b/pycona/active_algorithms/pquacq.py @@ -25,7 +25,7 @@ def __init__(self, ca_env: ActiveCAEnv = None): """ super().__init__(ca_env) - def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None): + def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None): """ Learn constraints using the QuAcq algorithm by generating queries and analyzing the results. @@ -33,8 +33,13 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos :param oracle: An instance of Oracle, default is to use the user as the oracle. :param verbose: Verbosity level, default is 0. :param metrics: statistics logger during learning + :param X: The set of variables to consider, default is None. :return: the learned instance """ + if X is None: + X = instance.X + assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables" + self.env.init_state(instance, oracle, verbose, metrics) if len(self.env.instance.bias) == 0: @@ -47,7 +52,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos print("Number of Queries: ", self.env.metrics.membership_queries_count) gen_start = time.time() - Y = self.env.run_query_generation() + Y = self.env.run_query_generation(X) gen_end = time.time() if len(Y) == 0: diff --git a/pycona/active_algorithms/quacq.py b/pycona/active_algorithms/quacq.py index a71cb67..36838e6 100644 --- a/pycona/active_algorithms/quacq.py +++ b/pycona/active_algorithms/quacq.py @@ -21,7 +21,7 @@ def __init__(self, ca_env: ActiveCAEnv = None): """ super().__init__(ca_env) - def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None): + def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None): """ Learn constraints using the QuAcq algorithm by generating queries and analyzing the results. @@ -29,8 +29,13 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos :param oracle: An instance of Oracle, default is to use the user as the oracle. :param verbose: Verbosity level, default is 0. :param metrics: statistics logger during learning + :param X: The set of variables to consider, default is None. :return: the learned instance """ + if X is None: + X = instance.X + assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables" + self.env.init_state(instance, oracle, verbose, metrics) if len(self.env.instance.bias) == 0: @@ -43,7 +48,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos print("Number of Queries: ", self.env.metrics.membership_queries_count) gen_start = time.time() - Y = self.env.run_query_generation() + Y = self.env.run_query_generation(X) gen_end = time.time() if len(Y) == 0: diff --git a/pycona/benchmarks/exam_timetabling.py b/pycona/benchmarks/exam_timetabling.py index 07cff1f..bb6cca4 100644 --- a/pycona/benchmarks/exam_timetabling.py +++ b/pycona/benchmarks/exam_timetabling.py @@ -2,7 +2,7 @@ from ..answering_queries.constraint_oracle import ConstraintOracle from ..problem_instance import ProblemInstance, absvar - +from cpmpy.transformations.normalize import toplevel_list def day_of_exam(course, slots_per_day): return course // slots_per_day @@ -27,7 +27,7 @@ def construct_examtt_simple(nsemesters=9, courses_per_semester=6, slots_per_day= for row in courses: model += cp.AllDifferent(day_of_exam(row, slots_per_day)).decompose() - C_T = list(model.constraints) + C_T = list(set(toplevel_list(model.constraints))) if model.solve(): courses.clear() diff --git a/pycona/benchmarks/job_shop_scheduling.py b/pycona/benchmarks/job_shop_scheduling.py index 0cb7fe4..7e65b50 100644 --- a/pycona/benchmarks/job_shop_scheduling.py +++ b/pycona/benchmarks/job_shop_scheduling.py @@ -3,7 +3,7 @@ import cpmpy as cp import numpy as np from cpmpy.expressions.utils import all_pairs - +from cpmpy.transformations.normalize import toplevel_list from ..answering_queries.constraint_oracle import ConstraintOracle from ..problem_instance import ProblemInstance, absvar @@ -56,7 +56,7 @@ def construct_job_shop_scheduling_problem(n_jobs, machines, horizon, seed=0): for (j1, t1), (j2, t2) in all_pairs(zip(*tasks_on_mach)): m += (end[j1, t1] <= start[j2, t2]) | (end[j2, t2] <= start[j1, t1]) - C_T = list(model.constraints) + C_T = list(set(toplevel_list(model.constraints))) max_duration = max(duration) diff --git a/pycona/benchmarks/jsudoku.py b/pycona/benchmarks/jsudoku.py index efce234..8bdc195 100644 --- a/pycona/benchmarks/jsudoku.py +++ b/pycona/benchmarks/jsudoku.py @@ -1,5 +1,5 @@ import cpmpy as cp - +from cpmpy.transformations.normalize import toplevel_list from ..answering_queries.constraint_oracle import ConstraintOracle from ..problem_instance import ProblemInstance, absvar @@ -49,6 +49,6 @@ def construct_jsudoku(): instance = ProblemInstance(variables=grid, params=parameters, language=lang, name="jsudoku") - oracle = ConstraintOracle(C_T) + oracle = ConstraintOracle(list(set(toplevel_list(C_T)))) return instance, oracle diff --git a/pycona/benchmarks/murder.py b/pycona/benchmarks/murder.py index 3b97e24..efe283a 100644 --- a/pycona/benchmarks/murder.py +++ b/pycona/benchmarks/murder.py @@ -1,6 +1,6 @@ import cpmpy as cp - +from cpmpy.transformations.normalize import toplevel_list from ..answering_queries.constraint_oracle import ConstraintOracle from ..problem_instance import ProblemInstance, absvar @@ -45,6 +45,6 @@ def construct_murder_problem(): instance = ProblemInstance(variables=grid, language=lang, name="murder") - oracle = ConstraintOracle(C_T) + oracle = ConstraintOracle(list(set(toplevel_list(C_T)))) return instance, oracle diff --git a/pycona/benchmarks/nurse_rostering.py b/pycona/benchmarks/nurse_rostering.py index af9021b..6d52fbf 100644 --- a/pycona/benchmarks/nurse_rostering.py +++ b/pycona/benchmarks/nurse_rostering.py @@ -1,5 +1,5 @@ import cpmpy as cp - +from cpmpy.transformations.normalize import toplevel_list from ..answering_queries.constraint_oracle import ConstraintOracle from ..problem_instance import ProblemInstance, absvar @@ -30,7 +30,7 @@ def construct_nurse_rostering(shifts_per_day=3, num_days=5, num_nurses=8, nurses if not model.solve(): raise Exception("The problem has no solution") - C_T = list(model.constraints) + C_T = list(set(toplevel_list(model.constraints))) # Create the language: AV = absvar(2) # create abstract vars - as many as maximum arity diff --git a/pycona/benchmarks/sudoku.py b/pycona/benchmarks/sudoku.py index 9c4f3e8..93233a8 100644 --- a/pycona/benchmarks/sudoku.py +++ b/pycona/benchmarks/sudoku.py @@ -1,5 +1,5 @@ import cpmpy as cp - +from cpmpy.transformations.normalize import toplevel_list from ..answering_queries.constraint_oracle import ConstraintOracle from ..problem_instance import ProblemInstance, absvar @@ -29,7 +29,7 @@ def construct_sudoku(block_size_row, block_size_col, grid_size): for j in range(0, grid_size, block_size_col): model += cp.AllDifferent(grid[i:i + block_size_row, j:j + block_size_col]).decompose() # python's indexing - C_T = list(model.constraints) + C_T = list(set(toplevel_list(model.constraints))) # Create the language: AV = absvar(2) # create abstract vars - as many as maximum arity diff --git a/pycona/ca_environment/acive_ca_proba.py b/pycona/ca_environment/acive_ca_proba.py index 78af112..c76f8be 100644 --- a/pycona/ca_environment/acive_ca_proba.py +++ b/pycona/ca_environment/acive_ca_proba.py @@ -55,12 +55,12 @@ def init_state(self, instance, oracle, verbose, metrics=None): else: self._bias_proba = {c: 0.01 for c in self.instance.bias} - def run_query_generation(self): + def run_query_generation(self, X=None): """ Run the query generation process. """ if self.training_frequency > 0 and len(set(self.datasetY)) == 2: self._train_classifier() self._predict_bias_proba() - return super().run_query_generation() + return super().run_query_generation(X) def run_find_scope(self, Y): """ Run the find scope process. """ diff --git a/pycona/ca_environment/active_ca.py b/pycona/ca_environment/active_ca.py index b6c0f01..fbabcf3 100644 --- a/pycona/ca_environment/active_ca.py +++ b/pycona/ca_environment/active_ca.py @@ -50,9 +50,9 @@ def init_state(self, instance, oracle, verbose, metrics=None): self.find_scope.ca = self self.findc.ca = self - def run_query_generation(self): + def run_query_generation(self, Y=None): """ Run the query generation process. """ - Y = self.qgen.generate() + Y = self.qgen.generate(Y) return Y def run_find_scope(self, Y): diff --git a/pycona/problem_instance/problem_instance.py b/pycona/problem_instance/problem_instance.py index b9d27a3..d65ab5e 100644 --- a/pycona/problem_instance/problem_instance.py +++ b/pycona/problem_instance/problem_instance.py @@ -241,15 +241,19 @@ def construct_bias(self): self.bias = all_cons - def construct_bias_for_var(self, v1): + def construct_bias_for_var(self, v1, X=None): """ Construct the bias (candidate constraints) for a specific variable. :param v1: The variable for which to construct the bias. + :param X: The set of variables to consider, default is None. """ + if X is None: + X = self.X + assert isinstance(X, list) and set(X).issubset(set(self.X)), "When using .construct_bias_for_var(), set parameter X must be a list of variables. Instead, got: " + str(X) all_cons = [] - X = list(set(self.X) - {v1}) + X = list(set(X) - {v1}) for relation in self.language: abs_vars = get_scope(relation) diff --git a/pycona/query_generation/pqgen.py b/pycona/query_generation/pqgen.py index 20051c0..b68d64d 100644 --- a/pycona/query_generation/pqgen.py +++ b/pycona/query_generation/pqgen.py @@ -67,25 +67,30 @@ def blimit(self, blimit): """ self._blimit = blimit - def generate(self): + def generate(self, Y=None): """ Generate a query using PQGen. :return: A set of variables that form the query. """ - # Start time (for the cutoff t) + + if Y is None: + Y = self.env.instance.X + assert isinstance(Y, list), "When generating a query, Y must be a list of variables" + + # Start time (for the cutoff time) t0 = time.time() # Project down to only vars in scope of B - Y = frozenset(get_variables(self.env.instance.bias)) + Y2 = frozenset(get_variables(self.env.instance.bias)) + + if len(Y2) < len(Y): + Y = Y2 + lY = list(Y) - if len(Y) == len(self.env.instance.X): - B = self.env.instance.bias - Cl = self.env.instance.cl - else: - B = get_con_subset(self.env.instance.bias, Y) - Cl = get_con_subset(self.env.instance.cl, Y) + B = get_con_subset(self.env.instance.bias, Y) + Cl = get_con_subset(self.env.instance.cl, Y) # If no constraints left in B, just return if len(B) == 0: diff --git a/pycona/query_generation/qgen.py b/pycona/query_generation/qgen.py index 7aeddaa..e4d94f7 100644 --- a/pycona/query_generation/qgen.py +++ b/pycona/query_generation/qgen.py @@ -2,7 +2,7 @@ from ..ca_environment.active_ca import ActiveCAEnv import cpmpy as cp from cpmpy.solvers.solver_interface import ExitStatus - +from ..utils import get_con_subset from .qgen_core import QGenBase @@ -22,25 +22,32 @@ def __init__(self, ca_env: ActiveCAEnv = None, time_limit=600): @abstractmethod - def generate(self): + def generate(self, Y=None): """ A basic version of query generation for small problems. May lead to premature convergence, so generally not used. :return: A set of variables that form the query. """ - if len(self.env.instance.bias) == 0: - return False + if Y is None: + Y = self.env.instance.X + assert isinstance(Y, list), "When generating a query, Y must be a list of variables" + + B = get_con_subset(self.env.instance.bias, Y) + Cl = get_con_subset(self.env.instance.cl, Y) + + if len(B) == 0: + return set() # B are taken into account as soft constraints that we do not want to satisfy (i.e., that we want to violate) - m = cp.Model(self.env.instance.cl) # could use to-be-implemented m.copy() here... + m = cp.Model(Cl) # could use to-be-implemented m.copy() here... # Get the amount of satisfied constraints from B - objective = sum([c for c in self.env.instance.bias]) + objective = sum([c for c in B]) # We want at least one constraint to be violated to assure that each answer of the # user will reduce the set of candidates - m += objective < len(self.env.instance.bias) + m += objective < len(B) s = cp.SolverLookup.get("ortools", m) flag = s.solve(time_limit=self.time_limit) @@ -50,4 +57,4 @@ def generate(self): self.env.converged = 0 return set() - return self.env.instance.X + return Y diff --git a/pycona/query_generation/qgen_core.py b/pycona/query_generation/qgen_core.py index 74a7d9e..25870ed 100644 --- a/pycona/query_generation/qgen_core.py +++ b/pycona/query_generation/qgen_core.py @@ -18,7 +18,7 @@ def __init__(self, ca_env: ActiveCAEnv = None, time_limit=2): self._time_limit = time_limit @abstractmethod - def generate(self): + def generate(self, Y=None): """ Method that all QGen implementations must implement to generate a query. """ diff --git a/pycona/query_generation/tqgen.py b/pycona/query_generation/tqgen.py index 3f665f2..e903216 100644 --- a/pycona/query_generation/tqgen.py +++ b/pycona/query_generation/tqgen.py @@ -60,28 +60,31 @@ def lamda(self, lamda): """Set the lamda parameter of TQGen.""" self._lamda = lamda - def generate(self): + def generate(self, Y=None): """ Generate a query using TQGen. :return: A list of variables that form the query. """ + if Y is None: + Y = self.env.instance.X + assert isinstance(Y, list), "When generating a query, Y must be a list of variables" + if self._lamda is None: self._lamda = len(self._env.instance.X) ttime = 0 - - bias = self.env.instance.bias - cl = self.env.instance.cl + bias = get_con_subset(self.env.instance.bias, Y) + cl = get_con_subset(self.env.instance.cl, Y) while (ttime < self.time_limit) and (len(bias) > 0): t = min([self.tau, self.time_limit - ttime]) l = max([self.lamda, get_min_arity(bias)]) - Y = find_suitable_vars_subset2(l, bias, self.env.instance.X) + Y2 = find_suitable_vars_subset2(l, bias, Y) - B = get_con_subset(bias, Y) - Cl = get_con_subset(cl, Y) + B = get_con_subset(bias, Y2) + Cl = get_con_subset(cl, Y2) m = cp.Model(Cl) s = cp.SolverLookup.get("ortools", m) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 88320bf..7a29341 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -75,3 +75,112 @@ def test_proba_growacq(self, bench, inner_alg, classifier): learned_instance = ca_system.learn(instance=instance, oracle=oracle) assert len(learned_instance.cl) > 0 assert learned_instance.get_cpmpy_model().solve() + + @pytest.mark.parametrize(("bench", "algorithm"), _generate_base_inputs(), ids=str) + def test_base_algorithms_with_initial_cl(self, bench, algorithm): + (instance, oracle) = bench + # Create a copy of the instance to avoid modifying the original + instance = instance.copy() + + # Get some constraints from the oracle's constraint set + initial_constraints = oracle.constraints[:len(oracle.constraints)//2] # Take half of the constraints + instance.cl.extend(initial_constraints) + initial_cl_size = len(instance.cl) + + ca_system = algorithm + learned_instance = ca_system.learn(instance=instance, oracle=oracle) + assert len(learned_instance.cl) == initial_cl_size*2 + assert learned_instance.get_cpmpy_model().solve() + + @pytest.mark.parametrize(("bench", "algorithm", "classifier"), _generate_proba_inputs(), ids=str) + def test_proba_with_initial_cl(self, bench, algorithm, classifier): + env = ca.ProbaActiveCAEnv(classifier=classifier) + (instance, oracle) = bench + # Create a copy of the instance to avoid modifying the original + instance = instance.copy() + + # Get some constraints from the oracle's constraint set + initial_constraints = oracle.constraints[:len(oracle.constraints)//2] # Take half of the constraints + instance.cl.extend(initial_constraints) + initial_cl_size = len(instance.cl) + + ca_system = algorithm + ca_system.env = env + learned_instance = ca_system.learn(instance=instance, oracle=oracle) + assert len(learned_instance.cl) == initial_cl_size*2 + assert learned_instance.get_cpmpy_model().solve() + + @pytest.mark.parametrize(("bench", "algorithm"), _generate_base_inputs(), ids=str) + def test_base_algorithms_with_bias(self, bench, algorithm): + (instance, oracle) = bench + # Create a copy of the instance to avoid modifying the original + instance = instance.copy() + + # Generate bias constraints for the instance + instance.construct_bias() + # Separate constraints into those from oracle and others + oracle_constraints = set(oracle.constraints) + other_constraints = [c for c in instance.bias if c not in oracle_constraints] + # Keep all oracle constraints and half of the other constraints + instance.bias = list(oracle_constraints) + other_constraints[:len(other_constraints)//2] + + ca_system = algorithm + learned_instance = ca_system.learn(instance=instance, oracle=oracle) + assert len(learned_instance.cl) > 0 + assert learned_instance.get_cpmpy_model().solve() + + @pytest.mark.parametrize(("bench", "algorithm", "classifier"), _generate_proba_inputs(), ids=str) + def test_proba_with_bias(self, bench, algorithm, classifier): + env = ca.ProbaActiveCAEnv(classifier=classifier) + (instance, oracle) = bench + # Create a copy of the instance to avoid modifying the original + instance = instance.copy() + + # Generate bias constraints for the instance + instance.construct_bias() + # Separate constraints into those from oracle and others + oracle_constraints = set(oracle.constraints) + other_constraints = [c for c in instance.bias if c not in oracle_constraints] + # Keep all oracle constraints and half of the other constraints + instance.bias = list(oracle_constraints) + other_constraints[:len(other_constraints)//2] + + ca_system = algorithm + ca_system.env = env + learned_instance = ca_system.learn(instance=instance, oracle=oracle) + assert len(learned_instance.cl) > 0 + assert learned_instance.get_cpmpy_model().solve() + + @pytest.mark.parametrize(("bench", "inner_alg"), _generate_base_inputs(), ids=str) + def test_growacq_with_initial_cl(self, bench, inner_alg): + (instance, oracle) = bench + # Create a copy of the instance to avoid modifying the original + instance = instance.copy() + + # Get some constraints from the oracle's constraint set + initial_constraints = oracle.constraints[:len(oracle.constraints)//2] # Take half of the constraints + instance.cl.extend(initial_constraints) + initial_cl_size = len(instance.cl) + + ca_system = ca.GrowAcq(inner_algorithm=inner_alg) + learned_instance = ca_system.learn(instance=instance, oracle=oracle) + assert len(learned_instance.cl) == initial_cl_size*2 + assert learned_instance.get_cpmpy_model().solve() + + @pytest.mark.parametrize(("bench", "inner_alg"), _generate_base_inputs(), ids=str) + def test_growacq_with_bias(self, bench, inner_alg): + (instance, oracle) = bench + # Create a copy of the instance to avoid modifying the original + instance = instance.copy() + + # Generate bias constraints for the instance + instance.construct_bias() + # Separate constraints into those from oracle and others + oracle_constraints = set(oracle.constraints) + other_constraints = [c for c in instance.bias if c not in oracle_constraints] + # Keep all oracle constraints and half of the other constraints + instance.bias = list(oracle_constraints) + other_constraints[:len(other_constraints)//2] + + ca_system = ca.GrowAcq(inner_algorithm=inner_alg) + learned_instance = ca_system.learn(instance=instance, oracle=oracle) + assert len(learned_instance.cl) > 0 + assert learned_instance.get_cpmpy_model().solve() \ No newline at end of file