Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pycona/active_algorithms/gquacq.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
self.env.init_state(instance, oracle, verbose, metrics)

if len(self.env.instance.bias) == 0:
self.env.instance.construct_bias()
self.env.instance.construct_bias(X)

while True:
if self.env.verbose > 0:
Expand All @@ -66,6 +66,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
if self.env.verbose >= 1:
print(f"\nLearned {self.env.metrics.cl} constraints in "
f"{self.env.metrics.membership_queries_count} queries.")
self.env.instance.bias = []
return self.env.instance

self.env.metrics.increase_generation_time(gen_end - gen_start)
Expand Down
2 changes: 1 addition & 1 deletion pycona/active_algorithms/growacq.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
Y.append(x)
# add the constraints involving x and other added variables
if len(self.env.instance.bias) == 0:
self.env.instance.construct_bias_for_var(x, Y)
self.env.instance.construct_bias_for_vars(x, Y)
if verbose >= 3:
print(f"Added variable {x} in GrowAcq")
print("size of B in growacq: ", len(self.env.instance.bias))
Expand Down
3 changes: 2 additions & 1 deletion pycona/active_algorithms/mquacq.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
self.env.init_state(instance, oracle, verbose, metrics)

if len(self.env.instance.bias) == 0:
self.env.instance.construct_bias()
self.env.instance.construct_bias(X)

while True:
if self.env.verbose >= 3:
Expand All @@ -61,6 +61,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
if self.env.verbose >= 1:
print(f"\nLearned {self.env.metrics.cl} constraints in "
f"{self.env.metrics.membership_queries_count} queries.")
self.env.instance.bias = []
return self.env.instance

self.env.metrics.increase_generation_time(gen_end - gen_start)
Expand Down
3 changes: 2 additions & 1 deletion pycona/active_algorithms/mquacq2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
self.cl_neighbours = np.zeros((len(self.env.instance.X), len(self.env.instance.X)), dtype=bool)

if len(self.env.instance.bias) == 0:
self.env.instance.construct_bias()
self.env.instance.construct_bias(X)

while True:
gen_start = time.time()
Expand All @@ -67,6 +67,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
if self.env.verbose >= 1:
print(f"\nLearned {self.env.metrics.cl} constraints in "
f"{self.env.metrics.membership_queries_count} queries.")
self.env.instance.bias = []
return self.env.instance

self.env.metrics.increase_generated_queries()
Expand Down
3 changes: 2 additions & 1 deletion pycona/active_algorithms/pquacq.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
self.env.init_state(instance, oracle, verbose, metrics)

if len(self.env.instance.bias) == 0:
self.env.instance.construct_bias()
self.env.instance.construct_bias(X)

while True:
if self.env.verbose > 0:
Expand All @@ -61,6 +61,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
if self.env.verbose >= 1:
print(f"\nLearned {self.env.metrics.cl} constraints in "
f"{self.env.metrics.membership_queries_count} queries.")
self.env.instance.bias = []
return self.env.instance

self.env.metrics.increase_generation_time(gen_end - gen_start)
Expand Down
3 changes: 2 additions & 1 deletion pycona/active_algorithms/quacq.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
self.env.init_state(instance, oracle, verbose, metrics)

if len(self.env.instance.bias) == 0:
self.env.instance.construct_bias()
self.env.instance.construct_bias(X)

while True:
if self.env.verbose > 2:
Expand All @@ -57,6 +57,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
if self.env.verbose >= 1:
print(f"\nLearned {self.env.metrics.cl} constraints in "
f"{self.env.metrics.membership_queries_count} queries.")
self.env.instance.bias = []
return self.env.instance

self.env.metrics.increase_generation_time(gen_end - gen_start)
Expand Down
32 changes: 20 additions & 12 deletions pycona/problem_instance/problem_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,15 @@ def get_cpmpy_model(self):
warnings.warn("The model is empty, as no constraint is learned yet for this instance.")
return cp.Model(self._cl)

def construct_bias(self):
def construct_bias(self, X=None):
"""
Construct the bias (candidate constraints) for the problem instance.
"""
if X is None:
X = self.X

all_cons = []

X = list(self.X)

for relation in self.language:

abs_vars = get_scope(relation)
Expand All @@ -239,35 +239,43 @@ def construct_bias(self):
constraint = replace_variables(relation, replace_dict)
all_cons.append(constraint)

self.bias = all_cons
self.bias = list(set(all_cons) - set(self.cl) - set(self.excluded_cons))


def construct_bias_for_var(self, v1, X=None):
def construct_bias_for_vars(self, v1, X=None):
"""
Construct the bias (candidate constraints) for a specific variable.

:param v1: The variable for which to construct the bias.
:param v1: The variable for which to construct the bias. Can also be a list of variables.
:param X: The set of variables to consider, default is None.
"""

if not isinstance(v1, list):
v1 = [v1]

if X is None:
X = self.X
assert isinstance(X, list) and set(X).issubset(set(self.X)), "When using .construct_bias_for_var(), set parameter X must be a list of variables. Instead, got: " + str(X)

# Sort X based on variable names
X = sorted(X, key=lambda var: var.name)

all_cons = []
X = list(set(X) - {v1})

for relation in self.language:

abs_vars = get_scope(relation)

combs = combinations(X, len(abs_vars) - 1)
combs = combinations(X, len(abs_vars))

for comb in combs:
replace_dict = {abs_vars[0]: v1}
replace_dict = dict()
for i, v in enumerate(comb):
replace_dict[abs_vars[i + 1]] = v
replace_dict[abs_vars[i]] = v
constraint = replace_variables(relation, replace_dict)
all_cons.append(constraint)

self.bias = all_cons
self.bias = [c for c in all_cons if any(v in set(get_scope(c)) for v in v1)]
self.bias = list(set(self.bias) - set(self.cl) - set(self.excluded_cons))

def __str__(self):
"""
Expand Down
39 changes: 23 additions & 16 deletions pycona/query_generation/pqgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,32 @@ def blimit(self, blimit):
"""
self._blimit = blimit

def generate(self, Y=None):
def reset_partial(self):
"""
Reset the partial flag to False.
"""
self.partial = False

def generate(self, X=None):
"""
Generate a query using PQGen.

:return: A set of variables that form the query.
"""

if Y is None:
Y = self.env.instance.X
assert isinstance(Y, list), "When generating a query, Y must be a list of variables"

# Start time (for the cutoff time)
if X is None:
X = self.env.instance.X
B = get_con_subset(self.env.instance.bias, X)
# Start time (for the cutoff t)
t0 = time.time()

# Project down to only vars in scope of B
Y2 = frozenset(get_variables(self.env.instance.bias))

if len(Y2) < len(Y):
Y = Y2
Y = frozenset(get_variables(B))

lY = list(Y)

B = get_con_subset(self.env.instance.bias, Y)
Cl = get_con_subset(self.env.instance.cl, Y)

# If no constraints left in B, just return
if len(B) == 0:
return set()
Expand All @@ -104,7 +105,7 @@ def generate(self, Y=None):
if not self.partial and len(B) > self.blimit:

m = cp.Model(Cl)
flag = m.solve() # no time limit to ensure convergence
flag = m.solve(num_workers=8) # no time limit to ensure convergence

if flag and not all([c.value() for c in B]):
return lY
Expand All @@ -113,13 +114,16 @@ def generate(self, Y=None):

m = cp.Model(Cl)
s = cp.SolverLookup.get("ortools", m)

# We want at least one constraint to be violated to assure that each answer of the user
# will lead to new information
s += ~cp.all(B)

if self.env.verbose > 2:
print("Solving first without objective (to find at least one solution)...")

# Solve first without objective (to find at least one solution)
flag = s.solve()
flag = s.solve(num_workers=8)

t1 = time.time() - t0
if not flag or (t1 > self.time_limit):
Expand All @@ -140,7 +144,10 @@ def generate(self, Y=None):
# Run with the objective
s.maximize(objective)

flag2 = s.solve(time_limit=(self.time_limit - t1))
if self.env.verbose > 2:
print("Solving with objective...")

flag2 = s.solve(time_limit=(self.time_limit - t1), num_workers=8)

if flag2:
return lY
Expand Down