diff --git a/pycona/active_algorithms/gquacq.py b/pycona/active_algorithms/gquacq.py index 6e22c15..99e1361 100644 --- a/pycona/active_algorithms/gquacq.py +++ b/pycona/active_algorithms/gquacq.py @@ -1,6 +1,6 @@ import time +import cpmpy as cp -import networkx as nx from cpmpy.transformations.get_variables import get_variables from .algorithm_core import AlgorithmCAInteractive @@ -90,6 +90,11 @@ def mineAsk(self, r): :param r: The index of a relation in gamma. :return: List of learned constraints. """ + try: + import networkx as nx + except ImportError: + raise ImportError("To use the predictAsk function of PQuAcq, networkx needs to be installed") + gq_counter = 0 C = [c for c in self.env.instance.cl if get_relation(c, self.env.instance.language) == r] @@ -110,19 +115,24 @@ def mineAsk(self, r): gen_flag = False B = [c for c in self.env.instance.bias if get_relation(c, self.env.instance.language) == r and frozenset(get_scope(c)).issubset(Y)] + D = [tuple([v.name for v in get_scope(c)]) for c in B] # missing edges that can be completed (exist in B) - # if already a subset of it was negative, or cannot be completed to a clique, continue to next - if not any(Y2.issubset(Y) for Y2 in self._negativeQ) and can_be_clique(G.subgraph(Y), D): - # if potentially generalizing leads to unsat, continue to next - new_CL = self.env.instance.cl.copy() - new_CL += B - if new_CL.solve() and self.env.ask_generalization_query(r, B): - gen_flag = True - self.env.add_to_cl(B) - else: - gq_counter += 1 - self._negativeQ.append(Y) + # If one of the following conditions is true, continue to next: + # already a subset of it was negative, cannot be completed to a clique, does not add any constraint or + # potentially generalizing leads to UNSAT + new_CL = self.env.instance.cl.copy() + new_CL += B + if any(Y2.issubset(Y) for Y2 in self._negativeQ) or not can_be_clique(G.subgraph(Y), D) or \ + len(B) > 0 or cp.Model(new_CL).solve(): + continue + + if self.env.ask_generalization_query(self.env.instance.language[r], B): + gen_flag = True + self.env.add_to_cl(B) + else: + gq_counter += 1 + self._negativeQ.append(Y) if not gen_flag: communities = nx.community.greedy_modularity_communities(G.subgraph(Y)) diff --git a/pycona/answering_queries/constraint_oracle.py b/pycona/answering_queries/constraint_oracle.py index dde5441..31f5e49 100644 --- a/pycona/answering_queries/constraint_oracle.py +++ b/pycona/answering_queries/constraint_oracle.py @@ -1,3 +1,4 @@ +import cpmpy as cp from cpmpy.transformations.normalize import toplevel_list from .oracle import Oracle @@ -46,19 +47,21 @@ def answer_membership_query(self, Y): # Need the oracle to answer based only on the constraints with a scope that is a subset of Y suboracle = get_con_subset(self.constraints, Y) - # Check if at least one constraint is violated or not return all([check_value(c) for c in suboracle]) def answer_recommendation_query(self, c): """ - Answer a recommendation query by checking if the recommended constraint is part of the constraints. + Answer a recommendation query by checking if the recommended constraint is part of the target set of + constraints, or logically implied by the constraints in the target set of constraints. :param c: The recommended constraint. :return: A boolean indicating if the recommended constraint is in the set of constraints. """ - # Check if the recommended constraint is in the set of constraints - return c in self.constraints + # Check if the recommended constraint is in the set of constraints or implied by them + m = cp.Model(self.constraints) + m += ~c + return not m.solve() def answer_generalization_query(self, C): """ diff --git a/pycona/benchmarks/exam_timetabling.py b/pycona/benchmarks/exam_timetabling.py index ba617be..07cff1f 100644 --- a/pycona/benchmarks/exam_timetabling.py +++ b/pycona/benchmarks/exam_timetabling.py @@ -12,12 +12,11 @@ def construct_examtt_simple(nsemesters=9, courses_per_semester=6, slots_per_day= """ :return: a ProblemInstance object, along with a constraint-based oracle """ - - total_courses = nsemesters * courses_per_semester total_slots = slots_per_day * days_for_exams - parameter_vars = ['nsemesters', 'courses_per_semester', 'slots_per_day', 'days_for_exams'] - parameters = {var_name: locals()[var_name] for var_name in parameter_vars} + parameters = {'nsemesters': nsemesters, 'courses_per_semester': courses_per_semester, + 'slots_per_day': slots_per_day, 'days_for_exams': days_for_exams} + # Variables courses = cp.intvar(1, total_slots, shape=(nsemesters, courses_per_semester), name="courses") @@ -31,7 +30,6 @@ def construct_examtt_simple(nsemesters=9, courses_per_semester=6, slots_per_day= C_T = list(model.constraints) if model.solve(): - solution = courses.value() courses.clear() else: print("no solution") diff --git a/pycona/benchmarks/sudoku.py b/pycona/benchmarks/sudoku.py index 17b4e67..9c4f3e8 100644 --- a/pycona/benchmarks/sudoku.py +++ b/pycona/benchmarks/sudoku.py @@ -41,5 +41,4 @@ def construct_sudoku(block_size_row, block_size_col, grid_size): oracle = ConstraintOracle(C_T) - print(len(C_T)) return instance, oracle diff --git a/pycona/ca_environment/active_ca.py b/pycona/ca_environment/active_ca.py index a240d7b..637267d 100644 --- a/pycona/ca_environment/active_ca.py +++ b/pycona/ca_environment/active_ca.py @@ -194,7 +194,8 @@ def ask_generalization_query(self, c, C): :param C: A list of constraints to which the generalization is applied. :return: The oracle's answer to the generalization query (True/False). """ - assert isinstance(c, Expression), "Generalization queries first input needs to be a constraint" + assert c in set(self.instance.language), "Generalization queries first input needs to be an expression from " \ + "the language" assert isinstance(C, list), "Generalization queries second input needs to be a list of constraints" assert all(isinstance(c1, Expression) for c1 in C), "Generalization queries second input needs to be " \ "a list of constraints" diff --git a/requirements.txt b/requirements.txt index 1af028c..482abbb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ cpmpy>=0.9 -scikit-learn \ No newline at end of file +scikit-learn +networkx \ No newline at end of file